diff --git a/config_gen.py b/config_gen.py index b1d05ba..52bc21e 100644 --- a/config_gen.py +++ b/config_gen.py @@ -10,6 +10,13 @@ from torch_geometric.graphgym.utils.io import string_to_python from explaining_framework.utils.io import (obj_config_to_str, read_yaml, write_yaml) + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + # class BaseConfigGenerator(object): # def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str): # self.dataset_name=dataset_name @@ -68,6 +75,7 @@ def explainer_conf(explainer: str, **kwargs): explaining_cfg["depth"] = "all" explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm") explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm") + elif explainer == "EIXGNN": explaining_cfg["L"] = kwargs.get("L") explaining_cfg["p"] = kwargs.get("p") @@ -77,6 +85,23 @@ def explainer_conf(explainer: str, **kwargs): explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity") explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity") explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx") + + elif explainer == "XGWT": + explaining_cfg["wav_approx"] = kwargs.get("wav_approx") + explaining_cfg["wav_passband"] = kwargs.get("wav_passband") + explaining_cfg["wav_norm"] = kwargs.get("wav_norm") + explaining_cfg["candidates"] = kwargs.get("candidates") + explaining_cfg["samples"] = kwargs.get("samples") + explaining_cfg["c_proc"] = kwargs.get("c_proc") + explaining_cfg["pred_thres_strat"] = kwargs.get("pred_thres_strat") + explaining_cfg["CI_thres"] = kwargs.get("CI_thres") + explaining_cfg["mix"] = kwargs.get("mix") + explaining_cfg["scales"] = kwargs.get("scales") + explaining_cfg["pred_thres"] = kwargs.get("pred_thres") + explaining_cfg["incl_prob"] = kwargs.get("incl_prob") + explaining_cfg["top_k"] = kwargs.get("top_k") + explaining_cfg["get_DAG"] = kwargs.get("get_DAG") + return explaining_cfg @@ -161,6 +186,7 @@ if "__main__" == __name__: "GNNExplainer", "EIXGNN", "SCGNN", + "XGWT", ] for dataset_name in DATASET: @@ -193,6 +219,8 @@ if "__main__" == __name__: + ".yaml", ) explainer_path.append(path_explainer) + if os.path.exists(path_explainer): + continue explainer_config.append(config) write_yaml(config, path_explainer) if explainer_name == "SCGNN": @@ -215,7 +243,89 @@ if "__main__" == __name__: ) explainer_path.append(path_explainer) explainer_config.append(config) + if os.path.exists(path_explainer): + continue write_yaml(config, path_explainer) + if explainer_name == "XGWT": + explainer_config = [] + explainer_path = [] + for wav_approx in [False]: + for wav_passband in ["heat"]: + for wav_norm in [True]: + for candidates in [10, 15, 30, 50]: + for samples in [10, 25, 50]: + for c_proc in ["auto"]: + for pred_thres_strat in ["regular"]: + for CI_thres in [0.05]: + for mix in ["uniform"]: + for scales in [ + [2], + [3], + [5], + [9], + [2, 3, 5], + [2, 3, 5, 9], + [5, 9], + [2, 3], + ]: + for pred_thres in [ + 0.1, + 0.25, + 0.5, + ]: + for incl_prob in [ + 0.2, + 0.4, + 0.6, + ]: + for top_k in [ + 2, + 5, + 10, + ]: + for get_DAG in [ + False + ]: + config = explainer_conf( + "XGWT", + wav_approx=wav_approx, + wav_passband=wav_passband, + wav_norm=wav_norm, + candidates=candidates, + samples=samples, + c_proc=c_proc, + pred_thres_strat=pred_thres_strat, + CI_thres=CI_thres, + mix=mix, + scales=scales, + pred_thres=pred_thres, + incl_prob=incl_prob, + top_k=top_k, + get_DAG=get_DAG, + ) + path_explainer = os.path.join( + explainer_folder, + "XGWT_" + + obj_config_to_str( + config + ) + + ".yaml", + ) + explainer_path.append( + path_explainer + ) + explainer_config.append( + config + ) + if os.path.exists( + path_explainer + ): + continue + write_yaml( + config, + path_explainer, + ) + for explainer_p, explainer_c in zip(explainer_path, explainer_config): explaining_cfg = explaining_conf( dataset=dataset_name, @@ -236,16 +346,15 @@ if "__main__" == __name__: + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}_{obj_config_to_str(explainer_c)}.yaml" ) write_yaml(explaining_cfg, PATH) - os.makedirs(explaining_folder + "/0", exist_ok=True) - os.makedirs(explaining_folder + "/1", exist_ok=True) - a = glob.glob(explaining_folder + "/*.yaml") + a = sorted(glob.glob(explaining_folder + "/*.yaml")) - for path in a[:8050]: - basename = os.path.basename(path) - dirname = os.path.dirname(path) - os.rename(path, dirname + "/0/" + basename) - - for path in a[8050:]: - basename = os.path.basename(path) - dirname = os.path.dirname(path) - os.rename(path, dirname + "/1/" + basename) + num_GPU = 4 + for i in range(num_GPU): + os.makedirs(explaining_folder + f"/{i}", exist_ok=True) + split_size = int(len(a) / num_GPU) + data = chunks(a, split_size) + for i, d in enumerate(data): + for path in d: + basename = os.path.basename(path) + dirname = os.path.dirname(path) + os.rename(path, dirname + f"/{i}/" + basename) diff --git a/explaining_framework/config/explainer_config/xgwt_config.py b/explaining_framework/config/explainer_config/xgwt_config.py index a62b257..96c45a3 100644 --- a/explaining_framework/config/explainer_config/xgwt_config.py +++ b/explaining_framework/config/explainer_config/xgwt_config.py @@ -38,17 +38,18 @@ def set_xgwt_cfg(xgwt_cfg): xgwt_cfg.wav_approx = False xgwt_cfg.wav_passband = "heat" - xgwt_cfg.wav_normalization = True - xgwt_cfg.num_candidates = 30 - xgwt_cfg.num_samples = 10 - xgwt_cfg.c_procedure = "auto" + xgwt_cfg.wav_norm = True + xgwt_cfg.candidates = 30 + xgwt_cfg.samples = 10 + xgwt_cfg.c_proc = "auto" xgwt_cfg.pred_thres_strat = "regular" - xgwt_cfg.CI_threshold = 0.05 - xgwt_cfg.mixing = "uniform" + xgwt_cfg.CI_thres = 0.05 + xgwt_cfg.mix = "uniform" xgwt_cfg.scales = [3] xgwt_cfg.pred_thres = 0.1 xgwt_cfg.incl_prob = 0.4 xgwt_cfg.top_k = 5 + xgwt_cfg.get_DAG = False def assert_cfg(xgwt_cfg): diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 70bd2da..d32e2ec 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -300,16 +300,17 @@ class ExplainingOutline(object): explaining_algorithm = XGWT( wav_approx=self.explainer_cfg.wav_approx, wav_passband=self.explainer_cfg.wav_passband, - wav_normalization=self.explainer_cfg.wav_normalization, - num_candidates=self.explainer_cfg.num_candidates, - num_samples=self.explainer_cfg.num_samples, - c_procedure=self.explainer_cfg.c_procedure, + wav_norm=self.explainer_cfg.wav_norm, + candidates=self.explainer_cfg.candidates, + samples=self.explainer_cfg.samples, + c_proc=self.explainer_cfg.c_proc, pred_thres_strat=self.explainer_cfg.pred_thres_strat, - CI_threshold=self.explainer_cfg.CI_threshold, - mixing=self.explainer_cfg.mixing, + CI_thres=self.explainer_cfg.CI_thres, + mix=self.explainer_cfg.mix, pred_thres=self.explainer_cfg.pred_thres, incl_prob=self.explainer_cfg.incl_prob, top_k=self.explainer_cfg.top_k, + get_DAG=self.explainer_cfg.get_DAG, scales=self.explainer_cfg.scales, ) elif name == "SCGNN":