import glob
import itertools
import math
import os
import shutil
import sys

from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.utils.io import string_to_python

from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
                                           write_yaml)


def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    The last chunk is shorter when ``len(lst)`` is not a multiple of ``n``.
    ``n`` must be non-zero (it is used as the ``range`` step).
    """
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


# class BaseConfigGenerator(object):
#     def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
#         self.dataset_name=dataset_name
#         self.explainer_name=explainer_name
#         self.explainer_config = explainer_config
#         self.model_folder = model_folder

# class ExplainingConfigGenerator(object):
#     def __init__(
#         self,
#         dataset_name: str,
#         explainer_name: str,
#         model_folder: str,
#         explainer_config: str = "default",
#         item: list = None,
#         ckpt: str = "best",
#     ):
#         self.dataset_name = dataset_name
#         self.explainer_name = explainer_name
#         self.explainer_config = explainer_config
#         self.item = item
#         self.ckpt = ckpt
#         self.model_folder = model_folder


def explaining_conf(
    dataset: str, model_kind: str, explainer: str, explainer_config: str = "default"
):
    """Build the "explaining" configuration dictionary for one run.

    Args:
        dataset: Dataset name, stored under ``dataset.name``.
        model_kind: Checkpoint selector, stored under ``model.ckpt``.
        explainer: Explainer name, stored under ``explainer.name``.
        explainer_config: Value stored under ``explainer.cfg``; the script
            passes either ``"default"`` or the path of an explainer YAML file.

    Returns:
        A nested ``dict`` ready to be serialised to YAML.

    Raises:
        IndexError: If the script was started without a first command-line
            argument, because ``model.path`` is read from ``sys.argv[1]``.
    """
    explaining_cfg = {}
    # Use the function's own parameters here. The previous version read the
    # module-level loop variables `dataset_name` / `explainer_name`, which only
    # exist while the script's main loop is running.
    explaining_cfg[
        "cfg_dest"
    ] = f"dataset={dataset}-model={model_kind}-explainer={explainer}.yaml"
    explaining_cfg["dataset"] = {}
    explaining_cfg["dataset"]["name"] = dataset
    explaining_cfg["explainer"] = {}
    explaining_cfg["explainer"]["cfg"] = explainer_config
    explaining_cfg["explainer"]["name"] = explainer
    explaining_cfg["explainer"]["force"] = False
    explaining_cfg["explanation_type"] = "phenomenon"
    explaining_cfg["model"] = {}
    explaining_cfg["model"]["ckpt"] = model_kind
    explaining_cfg["model"]["path"] = sys.argv[1]
    # explaining_cfg['out_dir']='./explanation'
    explaining_cfg["threshold"] = {}
    explaining_cfg["threshold"]["value"] = {}
    explaining_cfg["threshold"]["value"]["hard"] = [0, 0.1, 0.3, 0.5, 0.7, 0.9]
    explaining_cfg["threshold"]["value"]["topk"] = [2, 5, 10, 20, 50]
    return explaining_cfg


# Keys emitted by `explainer_conf` for each explainer that has its own
# hyper-parameters. The tuple order is the key order of the resulting dict.
_EXPLAINER_KEYS = {
    "SCGNN": (
        "target_baseline",
        "depth",
        "score_map_norm",
        "interest_map_norm",
    ),
    "EIXGNN": (
        "L",
        "p",
        "importance_sampling_strategy",
        "domain_similarity",
        "signal_similarity",
        "shapley_value_approx",
    ),
    "XGWT": (
        "wav_approx",
        "wav_passband",
        "wav_norm",
        "candidates",
        "samples",
        "c_proc",
        "pred_thres_strat",
        "CI_thres",
        "mix",
        "scales",
        "pred_thres",
        "incl_prob",
        "top_k",
        "get_DAG",
    ),
}


def explainer_conf(explainer: str, **kwargs):
    """Build the hyper-parameter dictionary of one explainer.

    Args:
        explainer: Explainer name. Only ``"SCGNN"``, ``"EIXGNN"`` and
            ``"XGWT"`` produce entries; any other name yields an empty dict.
        **kwargs: Hyper-parameter values. Keys that are not supplied are
            stored as ``None``; keys the explainer does not use are ignored.

    Returns:
        A ``dict`` mapping hyper-parameter names to values.
    """
    explaining_cfg = {
        key: kwargs.get(key) for key in _EXPLAINER_KEYS.get(explainer, ())
    }
    if explainer == "SCGNN":
        # "depth" is always "all" for SCGNN, whatever the caller passes.
        explaining_cfg["depth"] = "all"
    return explaining_cfg


# TODO Make a single list for all dataset name or explaining method name, etc
DATASET = [
    "CIFAR10",
    "TRIANGLES",
    "COLORS-3",
    "REDDIT-BINARY",
    "REDDIT-MULTI-5K",
    "REDDIT-MULTI-12K",
    "COLLAB",
    "DBLP_v1",
    "COIL-DEL",
    "COIL-RAG",
    "Fingerprint",
    "Letter-high",
    "Letter-low",
    "Letter-med",
    "MSRC_9",
    "MSRC_21",
    "MSRC_21C",
    "DD",
    "ENZYMES",
    "PROTEINS",
    "QM9",
    "MUTAG",
    "Mutagenicity",
    "AIDS",
    "PATTERN",
    "CLUSTER",
    "MNIST",
    "CIFAR10",
    "TSP",
    "CSL",
    "KarateClub",
    "CS",
    "Physics",
    "BBBP",
    "Tox21",
    "HIV",
    "PCBA",
    "MUV",
    "BACE",
    "SIDER",
    "ClinTox",
    "AIFB",
    "AM",
    "MUTAG",
    "BGS",
    "FAUST",
    "DynamicFAUST",
    "ShapeNet",
    "ModelNet10",
    "ModelNet40",
    "PascalVOC-SP",
    "COCO-SP",
]

EXPLAINER = [
    "CAM",
    "GradCAM",
    # "GNN_LRP",
    "GradExplainer",
    "GuidedBackPropagation",
    "IntegratedGradients",
    # "PGExplainer",
    "PGMExplainer",
    "RandomExplainer",
    # "SubgraphX",
    "GraphMASK",
    "GNNExplainer",
    "EIXGNN",
    "SCGNN",
    "XGWT",
]

# Hyper-parameter grids. Dict order is the loop nesting order (first key is
# the outermost loop), which fixes the order in which configs are generated.
_HYPERPARAMETER_GRIDS = {
    "EIXGNN": {
        "importance_sampling_strategy": ["node", "neighborhood", "no_prior"],
        "domain_similarity": ["relative_edge_density"],
        "signal_similarity": ["KL", "KL_sym"],
        "shapley_value_approx": [1000],
        "L": [5, 10, 15, 20, 30, 50],
        "p": [0.2, 0.3, 0.5, 0.7],
    },
    "SCGNN": {
        "target_baseline": [None, "inference"],
        "depth": ["all"],
        "score_map_norm": [True, False],
        "interest_map_norm": [True, False],
    },
    "XGWT": {
        "wav_approx": [False],
        "wav_passband": ["heat"],
        "wav_norm": [True],
        "candidates": [10, 15, 30, 50],
        "samples": [10, 25, 50],
        "c_proc": ["auto"],
        "pred_thres_strat": ["regular"],
        "CI_thres": [0.05],
        "mix": ["uniform"],
        "scales": [
            [2],
            [3],
            [5],
            [9],
            [2, 3, 5],
            [2, 3, 5, 9],
            [5, 9],
            [2, 3],
        ],
        "pred_thres": [0.1, 0.25, 0.5],
        "incl_prob": [0.2, 0.4, 0.6],
        "top_k": [2, 5, 10],
        "get_DAG": [False],
    },
}


def _build_explainer_variants(explainer_name, explainer_folder):
    """Return ``(explainer_cfg, file_suffix)`` pairs for one explainer.

    Explainers without a hyper-parameter grid get the single pair
    ``("default", "")``. For the others, one YAML file per grid point is
    written into ``explainer_folder`` (existing files are left untouched) and
    the pair holds the YAML path plus the ``"_<config string>"`` suffix used
    in the explaining-config file name.
    """
    grid = _HYPERPARAMETER_GRIDS.get(explainer_name)
    if grid is None:
        return [("default", "")]

    names = list(grid)
    variants = []
    # itertools.product varies the last iterable fastest, i.e. it walks the
    # grid exactly like nested for-loops written in dict order.
    for values in itertools.product(*grid.values()):
        config = explainer_conf(explainer_name, **dict(zip(names, values)))
        config_str = obj_config_to_str(config)
        path_explainer = os.path.join(
            explainer_folder, explainer_name + "_" + config_str + ".yaml"
        )
        if not os.path.exists(path_explainer):
            write_yaml(config, path_explainer)
        # Record the variant whether or not the file already existed, so each
        # path stays paired with its own config on every run.
        variants.append((path_explainer, "_" + config_str))
    return variants


def _split_into_gpu_folders(explaining_folder, num_gpu=4):
    """Move the YAML files of ``explaining_folder`` into ``0`` .. ``num_gpu - 1``.

    Files are taken in sorted order and dealt out in contiguous batches of
    ``ceil(n_files / num_gpu)``, so at most ``num_gpu`` batches are produced
    and each one lands in an existing sub-folder.
    """
    yaml_files = sorted(glob.glob(explaining_folder + "/*.yaml"))
    for gpu_id in range(num_gpu):
        os.makedirs(os.path.join(explaining_folder, str(gpu_id)), exist_ok=True)
    if not yaml_files:
        # Nothing to distribute; also avoids a zero chunk size.
        return

    # Ceiling division: flooring gives a chunk size of 0 for fewer files than
    # GPUs, and one batch too many whenever the count is not a multiple of
    # `num_gpu`.
    split_size = math.ceil(len(yaml_files) / num_gpu)
    for gpu_id, batch in enumerate(chunks(yaml_files, split_size)):
        for path in batch:
            os.rename(
                path,
                os.path.join(
                    os.path.dirname(path), str(gpu_id), os.path.basename(path)
                ),
            )


def main():
    """Generate every explainer / explaining YAML config, then shard them."""
    # os.path.dirname(__name__) is "" (a module name has no path separator),
    # so this resolves to "<current working directory>/configs".
    # NOTE(review): `__file__` may have been intended here - confirm before
    # changing, since that would move the output directory.
    config_folder = os.path.abspath(
        os.path.join(os.path.abspath(os.path.dirname(__name__)), "configs")
    )
    makedirs(config_folder)
    explaining_folder = os.path.join(config_folder, "explaining")
    makedirs(explaining_folder)
    explainer_folder = os.path.join(config_folder, "explainer")
    makedirs(explainer_folder)

    # The grids do not depend on the dataset or the checkpoint, so build them
    # (and write their YAML files) once per explainer.
    variants_by_explainer = {
        explainer_name: _build_explainer_variants(explainer_name, explainer_folder)
        for explainer_name in EXPLAINER
    }

    # DATASET lists a few names twice; dict.fromkeys drops the repeats while
    # keeping first-seen order, so the same files are not written twice.
    for dataset_name in dict.fromkeys(DATASET):
        for model_kind in ["best", "worst"]:
            for explainer_name in EXPLAINER:
                for explainer_p, suffix in variants_by_explainer[explainer_name]:
                    explaining_cfg = explaining_conf(
                        dataset=dataset_name,
                        model_kind=model_kind,
                        explainer=explainer_name,
                        explainer_config=explainer_p,
                    )
                    file_name = (
                        f"dataset={dataset_name}-model={model_kind}"
                        f"-explainer={explainer_name}{suffix}.yaml"
                    )
                    write_yaml(
                        explaining_cfg, os.path.join(explaining_folder, file_name)
                    )

    _split_into_gpu_folders(explaining_folder, num_gpu=4)


if "__main__" == __name__:
    main()