import glob
import itertools
import os
import shutil
import sys
from typing import Optional

from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.utils.io import string_to_python

from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
                                           write_yaml)

# TODO Make a single list for all dataset name or explaining method name, etc
# Each dataset is listed once (duplicates only caused redundant regeneration).
DATASET = [
    "CIFAR10",
    "TRIANGLES",
    "COLORS-3",
    "REDDIT-BINARY",
    "REDDIT-MULTI-5K",
    "REDDIT-MULTI-12K",
    "COLLAB",
    "DBLP_v1",
    "COIL-DEL",
    "COIL-RAG",
    "Fingerprint",
    "Letter-high",
    "Letter-low",
    "Letter-med",
    "MSRC_9",
    "MSRC_21",
    "MSRC_21C",
    "DD",
    "ENZYMES",
    "PROTEINS",
    "QM9",
    "MUTAG",
    "Mutagenicity",
    "AIDS",
    "PATTERN",
    "CLUSTER",
    "MNIST",
    "TSP",
    "CSL",
    "KarateClub",
    "CS",
    "Physics",
    "BBBP",
    "Tox21",
    "HIV",
    "PCBA",
    "MUV",
    "BACE",
    "SIDER",
    "ClinTox",
    "AIFB",
    "AM",
    "BGS",
    "FAUST",
    "DynamicFAUST",
    "ShapeNet",
    "ModelNet10",
    "ModelNet40",
    "PascalVOC-SP",
    "COCO-SP",
]

EXPLAINER = [
    "CAM",
    "GradCAM",
    # "GNN_LRP",
    "GradExplainer",
    "GuidedBackPropagation",
    "IntegratedGradients",
    # "PGExplainer",
    "PGMExplainer",
    "RandomExplainer",
    # "SubgraphX",
    "GraphMASK",
    "GNNExplainer",
    "EIXGNN",
    "SCGNN",
]

# Checkpoint selectors written to ``model.ckpt``.
MODEL_KINDS = ["best", "worst"]

# Number of explaining configs moved into ``explaining/0``; the rest go to
# ``explaining/1``. With the grids above the script writes
# 50 datasets x 2 checkpoints x 161 explainer variants = 16100 files, so this is
# half of them (assuming ``obj_config_to_str`` yields distinct names per config).
FIRST_FOLDER_SIZE = 8050


def explaining_conf(
    dataset: str,
    model_kind: str,
    explainer: str,
    explainer_config: str = "default",
    model_path: Optional[str] = None,
):
    """Build the explaining configuration for one dataset/model/explainer triple.

    Args:
        dataset: Dataset name; stored under ``dataset.name`` and used in ``cfg_dest``.
        model_kind: Checkpoint selector (this script passes ``"best"`` or
            ``"worst"``); stored under ``model.ckpt`` and used in ``cfg_dest``.
        explainer: Explainer name; stored under ``explainer.name`` and used in
            ``cfg_dest``.
        explainer_config: Value stored under ``explainer.cfg``. This script passes
            either ``"default"`` or the path of a written explainer YAML file.
        model_path: Value stored under ``model.path``. When ``None`` it falls back
            to ``sys.argv[1]``.

    Returns:
        A nested ``dict`` ready to be written with ``write_yaml``.

    Raises:
        IndexError: if ``model_path`` is ``None`` and no command-line argument
            was given.
    """
    if model_path is None:
        model_path = sys.argv[1]
    return {
        # Derived from the arguments (not from caller globals) so the function
        # also works when imported from another module.
        "cfg_dest": f"dataset={dataset}-model={model_kind}-explainer={explainer}.yaml",
        "dataset": {"name": dataset},
        "explainer": {"cfg": explainer_config, "name": explainer, "force": False},
        "explanation_type": "phenomenon",
        "model": {"ckpt": model_kind, "path": model_path},
        "threshold": {
            "value": {
                "hard": [0, 0.1, 0.3, 0.5, 0.7, 0.9],
                "topk": [2, 5, 10, 20, 50],
            }
        },
    }


def explainer_conf(explainer: str, **kwargs):
    """Build the explainer-specific hyper-parameter dict for ``explainer``.

    Only ``"SCGNN"`` and ``"EIXGNN"`` take parameters; any other name yields an
    empty dict. Keyword arguments that are not supplied are stored as ``None``,
    except SCGNN's ``depth``, which defaults to ``"all"``.
    """
    explaining_cfg = {}
    if explainer == "SCGNN":
        explaining_cfg["target_baseline"] = kwargs.get("target_baseline")
        # Honour the caller's depth instead of silently forcing "all".
        explaining_cfg["depth"] = kwargs.get("depth", "all")
        explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm")
        explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm")
    elif explainer == "EIXGNN":
        explaining_cfg["L"] = kwargs.get("L")
        explaining_cfg["p"] = kwargs.get("p")
        explaining_cfg["importance_sampling_strategy"] = kwargs.get(
            "importance_sampling_strategy"
        )
        explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity")
        explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity")
        explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx")
    return explaining_cfg


def _eixgnn_variants():
    """Yield every EIXGNN hyper-parameter dict of the search grid."""
    grid = itertools.product(
        ["node", "neighborhood", "no_prior"],  # importance_sampling_strategy
        ["relative_edge_density"],  # domain_similarity
        ["KL", "KL_sym"],  # signal_similarity
        [1000],  # shapley_value_approx
        [5, 10, 15, 20, 30, 50],  # L
        [0.2, 0.3, 0.5, 0.7],  # p
    )
    for imp_str, dom_sim, sig_sim, sh_val, L, p in grid:
        yield explainer_conf(
            "EIXGNN",
            importance_sampling_strategy=imp_str,
            domain_similarity=dom_sim,
            signal_similarity=sig_sim,
            shapley_value_approx=sh_val,
            L=L,
            p=p,
        )


def _scgnn_variants():
    """Yield every SCGNN hyper-parameter dict of the search grid."""
    grid = itertools.product(
        [None, "inference"],  # target_baseline
        ["all"],  # depth
        [True, False],  # score_map_norm
        [True, False],  # interest_map_norm
    )
    for target_baseline, depth, sc_map, in_map in grid:
        yield explainer_conf(
            "SCGNN",
            target_baseline=target_baseline,
            depth=depth,
            score_map_norm=sc_map,
            interest_map_norm=in_map,
        )


# Explainers that have a hyper-parameter grid; all others use "default".
_EXPLAINER_GRIDS = {"EIXGNN": _eixgnn_variants, "SCGNN": _scgnn_variants}


def _write_explainer_variants(explainer_name, explainer_folder):
    """Return ``(path, config)`` pairs for ``explainer_name``.

    For explainers with a grid, each config is written to
    ``<explainer_folder>/<name>_<obj_config_to_str(config)>.yaml``. Other
    explainers get a single ``("default", None)`` pair and nothing is written.
    """
    variants = _EXPLAINER_GRIDS.get(explainer_name)
    if variants is None:
        return [("default", None)]
    pairs = []
    for config in variants():
        path = os.path.join(
            explainer_folder, f"{explainer_name}_{obj_config_to_str(config)}.yaml"
        )
        write_yaml(config, path)
        pairs.append((path, config))
    return pairs


def _explaining_path(explaining_folder, dataset_name, model_kind, explainer_name, explainer_c):
    """Return the YAML path of one explaining config (suffixed by the explainer config if any)."""
    stem = f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}"
    if explainer_c is not None:
        stem += f"_{obj_config_to_str(explainer_c)}"
    return os.path.join(explaining_folder, stem + ".yaml")


def _split_into_two_folders(folder, first_size=FIRST_FOLDER_SIZE):
    """Move the top-level ``*.yaml`` files of ``folder`` into ``folder/0`` and ``folder/1``.

    The first ``first_size`` files in sorted order go to ``0``; the rest go to ``1``.
    """
    for sub in ("0", "1"):
        os.makedirs(os.path.join(folder, sub), exist_ok=True)
    # glob.glob order is filesystem-dependent; sort so the split is reproducible.
    paths = sorted(glob.glob(os.path.join(folder, "*.yaml")))
    for index, path in enumerate(paths):
        sub = "0" if index < first_size else "1"
        # os.replace overwrites an existing destination on every OS (reruns).
        os.replace(path, os.path.join(os.path.dirname(path), sub, os.path.basename(path)))


def main():
    """Write explainer and explaining YAML configs under ``./configs``.

    Usage: ``python <script> MODEL_PATH``; ``MODEL_PATH`` is stored under
    ``model.path`` in every explaining config.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_PATH")
    model_path = sys.argv[1]

    # Relative to the current working directory (what the former
    # ``os.path.dirname(__name__)`` resolved to).
    # NOTE(review): the script's own directory (``__file__``) may have been intended.
    config_folder = os.path.abspath("configs")
    makedirs(config_folder)
    explaining_folder = os.path.join(config_folder, "explaining")
    makedirs(explaining_folder)
    explainer_folder = os.path.join(config_folder, "explainer")
    makedirs(explainer_folder)

    # Explainer configs do not depend on dataset or checkpoint: build them once.
    variants = {
        name: _write_explainer_variants(name, explainer_folder) for name in EXPLAINER
    }

    for dataset_name, model_kind, explainer_name in itertools.product(
        DATASET, MODEL_KINDS, EXPLAINER
    ):
        for explainer_p, explainer_c in variants[explainer_name]:
            explaining_cfg = explaining_conf(
                dataset=dataset_name,
                model_kind=model_kind,
                explainer=explainer_name,
                explainer_config=explainer_p,
                model_path=model_path,
            )
            write_yaml(
                explaining_cfg,
                _explaining_path(
                    explaining_folder, dataset_name, model_kind, explainer_name, explainer_c
                ),
            )

    _split_into_two_folders(explaining_folder)


if __name__ == "__main__":
    main()