From 45ea7e42552266571c691a58bfcbae270b5a806f Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 16 Jan 2023 01:41:24 +0100 Subject: [PATCH] Adding config_gen.py --- config_gen.py | 238 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 config_gen.py diff --git a/config_gen.py b/config_gen.py new file mode 100644 index 0000000..43633d3 --- /dev/null +++ b/config_gen.py @@ -0,0 +1,238 @@ +import glob +import os +import shutil +import sys + +from torch_geometric.data.makedirs import makedirs +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.utils.io import string_to_python + +from explaining_framework.utils.io import (obj_config_to_str, read_yaml, + write_yaml) + +# class BaseConfigGenerator(object): +# def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str): +# self.dataset_name=dataset_name +# self.explainer_name=explainer_name +# self.explainer_config = explainer_config +# self.model_folder = model_folder + + +# class ExplainingConfigGenerator(object): +# def __init__( +# self, +# dataset_name: str, +# explainer_name: str, +# model_folder: str, +# explainer_config: str = "default", +# item: list = None, +# ckpt: str = "best", +# ): +# self.dataset_name = dataset_name +# self.explainer_name = explainer_name +# self.explainer_config = explainer_config +# self.item = item +# self.ckpt = ckpt +# self.model_folder = model_folder + + +def explaining_conf( + dataset: str, model_kind: str, explainer: str, explainer_config: str = "default" +): + explaining_cfg = {} + explaining_cfg[ + "cfg_dest" + ] = f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml" + explaining_cfg["dataset"] = {} + explaining_cfg["dataset"]["name"] = dataset + explaining_cfg["explainer"] = {} + explaining_cfg["explainer"]["cfg"] = explainer_config + explaining_cfg["explainer"]["name"] = explainer + explaining_cfg["explainer"]["force"] = True + explaining_cfg["explanation_type"] = "phenomenon" + explaining_cfg["model"] = {} + explaining_cfg["model"]["ckpt"] = model_kind + explaining_cfg["model"]["path"] = sys.argv[1] + # explaining_cfg['out_dir']='./explanation' + explaining_cfg["threshold"] = {} + explaining_cfg["threshold"]["value"] = {} + explaining_cfg["threshold"]["value"]["hard"] = [0, 0.1, 0.3, 0.5, 0.7, 0.9] + explaining_cfg["threshold"]["value"]["topk"] = [2, 5, 10, 20, 50] + return explaining_cfg + + +def explainer_conf(explainer: str, **kwargs): + explaining_cfg = {} + if explainer == "SCGNN": + explaining_cfg["target_baseline"] = kwargs.get("target_baseline") + explaining_cfg["depth"] = "all" + explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm") + explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm") + elif explainer == "EIXGNN": + explaining_cfg["L"] = kwargs.get("L") + explaining_cfg["p"] = kwargs.get("p") + explaining_cfg["importance_sampling_strategy"] = kwargs.get( + "importance_sampling_strategy" + ) + explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity") + explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity") + explaining_cfg["shap_val_approx"] = kwargs.get("shap_val_approx") + return explaining_cfg + + +if "__main__" == __name__: + config_folder = os.path.abspath( + os.path.join(os.path.abspath(os.path.dirname(__name__)), "configs") + ) + makedirs(config_folder) + explaining_folder = os.path.join(config_folder, "explaining") + makedirs(explaining_folder) + explainer_folder = os.path.join(config_folder, "explainer") + makedirs(explainer_folder) + + # TODO Make a single list for all dataset name or explaining method name, etc + + DATASET = [ + "CIFAR10", + "TRIANGLES", + "COLORS-3", + "REDDIT-BINARY", + "REDDIT-MULTI-5K", + "REDDIT-MULTI-12K", + "COLLAB", + "DBLP_v1", + "COIL-DEL", + "COIL-RAG", + "Fingerprint", + "Letter-high", + "Letter-low", + "Letter-med", + "MSRC_9", + "MSRC_21", + "MSRC_21C", + "DD", + "ENZYMES", + "PROTEINS", + "QM9", + "MUTAG", + "Mutagenicity", + "AIDS", + "PATTERN", + "CLUSTER", + "MNIST", + "CIFAR10", + "TSP", + "CSL", + "KarateClub", + "CS", + "Physics", + "BBBP", + "Tox21", + "HIV", + "PCBA", + "MUV", + "BACE", + "SIDER", + "ClinTox", + "AIFB", + "AM", + "MUTAG", + "BGS", + "FAUST", + "DynamicFAUST", + "ShapeNet", + "ModelNet10", + "ModelNet40", + "PascalVOC-SP", + "COCO-SP", + ] + EXPLAINER = [ + "CAM", + "GradCAM", + # "GNN_LRP", + "GradExplainer", + "GuidedBackPropagation", + "IntegratedGradients", + # "PGExplainer", + "PGMExplainer", + "RandomExplainer", + # "SubgraphX", + "GraphMASK", + "GNNExplainer", + "EIXGNN", + "SCGNN", + ] + + for dataset_name in DATASET: + for model_kind in ["best", "worst"]: + for explainer_name in EXPLAINER: + explainer_path = ["default"] + explainer_config = [None] + if explainer_name == "EIXGNN": + explainer_config = [] + explainer_path = [] + for imp_str in ["node", "neighborhood", "no_prior"]: + for dom_sim in ["relative_edge_density"]: + for sig_sim in ["KL", "KL_sym"]: + for sh_val in [1000]: + for L in [5, 10, 15, 20, 30, 50]: + for p in [0.2, 0.3, 0.5, 0.7]: + config = explainer_conf( + "EIXGNN", + importance_sampling_strategy=imp_str, + domain_similarity=dom_sim, + signal_similarity=sig_sim, + shap_val_approx=sh_val, + L=L, + p=p, + ) + path_explainer = os.path.join( + explainer_folder, + "EIXGNN_" + + obj_config_to_str(config) + + ".yaml", + ) + explainer_path.append(path_explainer) + explainer_config.append(config) + write_yaml(config, path_explainer) + if explainer_name == "SCGNN": + explainer_config = [] + explainer_path = [] + for target_baseline in [None, "inference"]: + for depth in ["all"]: + for sc_map in [True, False]: + for in_map in [True, False]: + config = explainer_conf( + "SCGNN", + target_baseline=target_baseline, + depth=depth, + score_map_norm=sc_map, + interest_map_norm=in_map, + ) + path_explainer = os.path.join( + explainer_folder, + "SCGNN_" + obj_config_to_str(config) + ".yaml", + ) + explainer_path.append(path_explainer) + explainer_config.append(config) + write_yaml(config, path_explainer) + for explainer_p, explainer_c in zip(explainer_path, explainer_config): + explaining_cfg = explaining_conf( + dataset=dataset_name, + model_kind=model_kind, + explainer=explainer_name, + explainer_config=explainer_p, + ) + if explainer_c is None: + PATH = os.path.join( + explaining_folder + + "/" + + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml" + ) + else: + PATH = os.path.join( + explaining_folder + + "/" + + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}_{obj_config_to_str(explainer_c)}.yaml" + ) + write_yaml(explaining_cfg, PATH)