Adding XGWT configs generation
This commit is contained in:
		
							parent
							
								
									27e8a8a4d8
								
							
						
					
					
						commit
						408bab4bc4
					
				
					 3 changed files with 135 additions and 24 deletions
				
			
		
							
								
								
									
										133
									
								
								config_gen.py
									
										
									
									
									
								
							
							
						
						
									
										133
									
								
								config_gen.py
									
										
									
									
									
								
							| 
						 | 
					@ -10,6 +10,13 @@ from torch_geometric.graphgym.utils.io import string_to_python
 | 
				
			||||||
from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
 | 
					from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
 | 
				
			||||||
                                           write_yaml)
 | 
					                                           write_yaml)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def chunks(lst, n):
 | 
				
			||||||
 | 
					    """Yield successive n-sized chunks from lst."""
 | 
				
			||||||
 | 
					    for i in range(0, len(lst), n):
 | 
				
			||||||
 | 
					        yield lst[i : i + n]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# class BaseConfigGenerator(object):
 | 
					# class BaseConfigGenerator(object):
 | 
				
			||||||
# def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
 | 
					# def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
 | 
				
			||||||
# self.dataset_name=dataset_name
 | 
					# self.dataset_name=dataset_name
 | 
				
			||||||
| 
						 | 
					@ -68,6 +75,7 @@ def explainer_conf(explainer: str, **kwargs):
 | 
				
			||||||
        explaining_cfg["depth"] = "all"
 | 
					        explaining_cfg["depth"] = "all"
 | 
				
			||||||
        explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm")
 | 
					        explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm")
 | 
				
			||||||
        explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm")
 | 
					        explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    elif explainer == "EIXGNN":
 | 
					    elif explainer == "EIXGNN":
 | 
				
			||||||
        explaining_cfg["L"] = kwargs.get("L")
 | 
					        explaining_cfg["L"] = kwargs.get("L")
 | 
				
			||||||
        explaining_cfg["p"] = kwargs.get("p")
 | 
					        explaining_cfg["p"] = kwargs.get("p")
 | 
				
			||||||
| 
						 | 
					@ -77,6 +85,23 @@ def explainer_conf(explainer: str, **kwargs):
 | 
				
			||||||
        explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity")
 | 
					        explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity")
 | 
				
			||||||
        explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity")
 | 
					        explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity")
 | 
				
			||||||
        explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx")
 | 
					        explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif explainer == "XGWT":
 | 
				
			||||||
 | 
					        explaining_cfg["wav_approx"] = kwargs.get("wav_approx")
 | 
				
			||||||
 | 
					        explaining_cfg["wav_passband"] = kwargs.get("wav_passband")
 | 
				
			||||||
 | 
					        explaining_cfg["wav_norm"] = kwargs.get("wav_norm")
 | 
				
			||||||
 | 
					        explaining_cfg["candidates"] = kwargs.get("candidates")
 | 
				
			||||||
 | 
					        explaining_cfg["samples"] = kwargs.get("samples")
 | 
				
			||||||
 | 
					        explaining_cfg["c_proc"] = kwargs.get("c_proc")
 | 
				
			||||||
 | 
					        explaining_cfg["pred_thres_strat"] = kwargs.get("pred_thres_strat")
 | 
				
			||||||
 | 
					        explaining_cfg["CI_thres"] = kwargs.get("CI_thres")
 | 
				
			||||||
 | 
					        explaining_cfg["mix"] = kwargs.get("mix")
 | 
				
			||||||
 | 
					        explaining_cfg["scales"] = kwargs.get("scales")
 | 
				
			||||||
 | 
					        explaining_cfg["pred_thres"] = kwargs.get("pred_thres")
 | 
				
			||||||
 | 
					        explaining_cfg["incl_prob"] = kwargs.get("incl_prob")
 | 
				
			||||||
 | 
					        explaining_cfg["top_k"] = kwargs.get("top_k")
 | 
				
			||||||
 | 
					        explaining_cfg["get_DAG"] = kwargs.get("get_DAG")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return explaining_cfg
 | 
					    return explaining_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -161,6 +186,7 @@ if "__main__" == __name__:
 | 
				
			||||||
        "GNNExplainer",
 | 
					        "GNNExplainer",
 | 
				
			||||||
        "EIXGNN",
 | 
					        "EIXGNN",
 | 
				
			||||||
        "SCGNN",
 | 
					        "SCGNN",
 | 
				
			||||||
 | 
					        "XGWT",
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for dataset_name in DATASET:
 | 
					    for dataset_name in DATASET:
 | 
				
			||||||
| 
						 | 
					@ -193,6 +219,8 @@ if "__main__" == __name__:
 | 
				
			||||||
                                                + ".yaml",
 | 
					                                                + ".yaml",
 | 
				
			||||||
                                            )
 | 
					                                            )
 | 
				
			||||||
                                            explainer_path.append(path_explainer)
 | 
					                                            explainer_path.append(path_explainer)
 | 
				
			||||||
 | 
					                                            if os.path.exists(path_explainer):
 | 
				
			||||||
 | 
					                                                continue
 | 
				
			||||||
                                            explainer_config.append(config)
 | 
					                                            explainer_config.append(config)
 | 
				
			||||||
                                            write_yaml(config, path_explainer)
 | 
					                                            write_yaml(config, path_explainer)
 | 
				
			||||||
                if explainer_name == "SCGNN":
 | 
					                if explainer_name == "SCGNN":
 | 
				
			||||||
| 
						 | 
					@ -215,7 +243,89 @@ if "__main__" == __name__:
 | 
				
			||||||
                                    )
 | 
					                                    )
 | 
				
			||||||
                                    explainer_path.append(path_explainer)
 | 
					                                    explainer_path.append(path_explainer)
 | 
				
			||||||
                                    explainer_config.append(config)
 | 
					                                    explainer_config.append(config)
 | 
				
			||||||
 | 
					                                    if os.path.exists(path_explainer):
 | 
				
			||||||
 | 
					                                        continue
 | 
				
			||||||
                                    write_yaml(config, path_explainer)
 | 
					                                    write_yaml(config, path_explainer)
 | 
				
			||||||
 | 
					                if explainer_name == "XGWT":
 | 
				
			||||||
 | 
					                    explainer_config = []
 | 
				
			||||||
 | 
					                    explainer_path = []
 | 
				
			||||||
 | 
					                    for wav_approx in [False]:
 | 
				
			||||||
 | 
					                        for wav_passband in ["heat"]:
 | 
				
			||||||
 | 
					                            for wav_norm in [True]:
 | 
				
			||||||
 | 
					                                for candidates in [10, 15, 30, 50]:
 | 
				
			||||||
 | 
					                                    for samples in [10, 25, 50]:
 | 
				
			||||||
 | 
					                                        for c_proc in ["auto"]:
 | 
				
			||||||
 | 
					                                            for pred_thres_strat in ["regular"]:
 | 
				
			||||||
 | 
					                                                for CI_thres in [0.05]:
 | 
				
			||||||
 | 
					                                                    for mix in ["uniform"]:
 | 
				
			||||||
 | 
					                                                        for scales in [
 | 
				
			||||||
 | 
					                                                            [2],
 | 
				
			||||||
 | 
					                                                            [3],
 | 
				
			||||||
 | 
					                                                            [5],
 | 
				
			||||||
 | 
					                                                            [9],
 | 
				
			||||||
 | 
					                                                            [2, 3, 5],
 | 
				
			||||||
 | 
					                                                            [2, 3, 5, 9],
 | 
				
			||||||
 | 
					                                                            [5, 9],
 | 
				
			||||||
 | 
					                                                            [2, 3],
 | 
				
			||||||
 | 
					                                                        ]:
 | 
				
			||||||
 | 
					                                                            for pred_thres in [
 | 
				
			||||||
 | 
					                                                                0.1,
 | 
				
			||||||
 | 
					                                                                0.25,
 | 
				
			||||||
 | 
					                                                                0.5,
 | 
				
			||||||
 | 
					                                                            ]:
 | 
				
			||||||
 | 
					                                                                for incl_prob in [
 | 
				
			||||||
 | 
					                                                                    0.2,
 | 
				
			||||||
 | 
					                                                                    0.4,
 | 
				
			||||||
 | 
					                                                                    0.6,
 | 
				
			||||||
 | 
					                                                                ]:
 | 
				
			||||||
 | 
					                                                                    for top_k in [
 | 
				
			||||||
 | 
					                                                                        2,
 | 
				
			||||||
 | 
					                                                                        5,
 | 
				
			||||||
 | 
					                                                                        10,
 | 
				
			||||||
 | 
					                                                                    ]:
 | 
				
			||||||
 | 
					                                                                        for get_DAG in [
 | 
				
			||||||
 | 
					                                                                            False
 | 
				
			||||||
 | 
					                                                                        ]:
 | 
				
			||||||
 | 
					                                                                            config = explainer_conf(
 | 
				
			||||||
 | 
					                                                                                "XGWT",
 | 
				
			||||||
 | 
					                                                                                wav_approx=wav_approx,
 | 
				
			||||||
 | 
					                                                                                wav_passband=wav_passband,
 | 
				
			||||||
 | 
					                                                                                wav_norm=wav_norm,
 | 
				
			||||||
 | 
					                                                                                candidates=candidates,
 | 
				
			||||||
 | 
					                                                                                samples=samples,
 | 
				
			||||||
 | 
					                                                                                c_proc=c_proc,
 | 
				
			||||||
 | 
					                                                                                pred_thres_strat=pred_thres_strat,
 | 
				
			||||||
 | 
					                                                                                CI_thres=CI_thres,
 | 
				
			||||||
 | 
					                                                                                mix=mix,
 | 
				
			||||||
 | 
					                                                                                scales=scales,
 | 
				
			||||||
 | 
					                                                                                pred_thres=pred_thres,
 | 
				
			||||||
 | 
					                                                                                incl_prob=incl_prob,
 | 
				
			||||||
 | 
					                                                                                top_k=top_k,
 | 
				
			||||||
 | 
					                                                                                get_DAG=get_DAG,
 | 
				
			||||||
 | 
					                                                                            )
 | 
				
			||||||
 | 
					                                                                            path_explainer = os.path.join(
 | 
				
			||||||
 | 
					                                                                                explainer_folder,
 | 
				
			||||||
 | 
					                                                                                "XGWT_"
 | 
				
			||||||
 | 
					                                                                                + obj_config_to_str(
 | 
				
			||||||
 | 
					                                                                                    config
 | 
				
			||||||
 | 
					                                                                                )
 | 
				
			||||||
 | 
					                                                                                + ".yaml",
 | 
				
			||||||
 | 
					                                                                            )
 | 
				
			||||||
 | 
					                                                                            explainer_path.append(
 | 
				
			||||||
 | 
					                                                                                path_explainer
 | 
				
			||||||
 | 
					                                                                            )
 | 
				
			||||||
 | 
					                                                                            explainer_config.append(
 | 
				
			||||||
 | 
					                                                                                config
 | 
				
			||||||
 | 
					                                                                            )
 | 
				
			||||||
 | 
					                                                                            if os.path.exists(
 | 
				
			||||||
 | 
					                                                                                path_explainer
 | 
				
			||||||
 | 
					                                                                            ):
 | 
				
			||||||
 | 
					                                                                                continue
 | 
				
			||||||
 | 
					                                                                            write_yaml(
 | 
				
			||||||
 | 
					                                                                                config,
 | 
				
			||||||
 | 
					                                                                                path_explainer,
 | 
				
			||||||
 | 
					                                                                            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for explainer_p, explainer_c in zip(explainer_path, explainer_config):
 | 
					                for explainer_p, explainer_c in zip(explainer_path, explainer_config):
 | 
				
			||||||
                    explaining_cfg = explaining_conf(
 | 
					                    explaining_cfg = explaining_conf(
 | 
				
			||||||
                        dataset=dataset_name,
 | 
					                        dataset=dataset_name,
 | 
				
			||||||
| 
						 | 
					@ -236,16 +346,15 @@ if "__main__" == __name__:
 | 
				
			||||||
                            + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}_{obj_config_to_str(explainer_c)}.yaml"
 | 
					                            + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}_{obj_config_to_str(explainer_c)}.yaml"
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
                    write_yaml(explaining_cfg, PATH)
 | 
					                    write_yaml(explaining_cfg, PATH)
 | 
				
			||||||
    os.makedirs(explaining_folder + "/0", exist_ok=True)
 | 
					    a = sorted(glob.glob(explaining_folder + "/*.yaml"))
 | 
				
			||||||
    os.makedirs(explaining_folder + "/1", exist_ok=True)
 | 
					 | 
				
			||||||
    a = glob.glob(explaining_folder + "/*.yaml")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for path in a[:8050]:
 | 
					    num_GPU = 4
 | 
				
			||||||
        basename = os.path.basename(path)
 | 
					    for i in range(num_GPU):
 | 
				
			||||||
        dirname = os.path.dirname(path)
 | 
					        os.makedirs(explaining_folder + f"/{i}", exist_ok=True)
 | 
				
			||||||
        os.rename(path, dirname + "/0/" + basename)
 | 
					    split_size = int(len(a) / num_GPU)
 | 
				
			||||||
 | 
					    data = chunks(a, split_size)
 | 
				
			||||||
    for path in a[8050:]:
 | 
					    for i, d in enumerate(data):
 | 
				
			||||||
        basename = os.path.basename(path)
 | 
					        for path in d:
 | 
				
			||||||
        dirname = os.path.dirname(path)
 | 
					            basename = os.path.basename(path)
 | 
				
			||||||
        os.rename(path, dirname + "/1/" + basename)
 | 
					            dirname = os.path.dirname(path)
 | 
				
			||||||
 | 
					            os.rename(path, dirname + f"/{i}/" + basename)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,17 +38,18 @@ def set_xgwt_cfg(xgwt_cfg):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    xgwt_cfg.wav_approx = False
 | 
					    xgwt_cfg.wav_approx = False
 | 
				
			||||||
    xgwt_cfg.wav_passband = "heat"
 | 
					    xgwt_cfg.wav_passband = "heat"
 | 
				
			||||||
    xgwt_cfg.wav_normalization = True
 | 
					    xgwt_cfg.wav_norm = True
 | 
				
			||||||
    xgwt_cfg.num_candidates = 30
 | 
					    xgwt_cfg.candidates = 30
 | 
				
			||||||
    xgwt_cfg.num_samples = 10
 | 
					    xgwt_cfg.samples = 10
 | 
				
			||||||
    xgwt_cfg.c_procedure = "auto"
 | 
					    xgwt_cfg.c_proc = "auto"
 | 
				
			||||||
    xgwt_cfg.pred_thres_strat = "regular"
 | 
					    xgwt_cfg.pred_thres_strat = "regular"
 | 
				
			||||||
    xgwt_cfg.CI_threshold = 0.05
 | 
					    xgwt_cfg.CI_thres = 0.05
 | 
				
			||||||
    xgwt_cfg.mixing = "uniform"
 | 
					    xgwt_cfg.mix = "uniform"
 | 
				
			||||||
    xgwt_cfg.scales = [3]
 | 
					    xgwt_cfg.scales = [3]
 | 
				
			||||||
    xgwt_cfg.pred_thres = 0.1
 | 
					    xgwt_cfg.pred_thres = 0.1
 | 
				
			||||||
    xgwt_cfg.incl_prob = 0.4
 | 
					    xgwt_cfg.incl_prob = 0.4
 | 
				
			||||||
    xgwt_cfg.top_k = 5
 | 
					    xgwt_cfg.top_k = 5
 | 
				
			||||||
 | 
					    xgwt_cfg.get_DAG = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def assert_cfg(xgwt_cfg):
 | 
					def assert_cfg(xgwt_cfg):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -300,16 +300,17 @@ class ExplainingOutline(object):
 | 
				
			||||||
                explaining_algorithm = XGWT(
 | 
					                explaining_algorithm = XGWT(
 | 
				
			||||||
                    wav_approx=self.explainer_cfg.wav_approx,
 | 
					                    wav_approx=self.explainer_cfg.wav_approx,
 | 
				
			||||||
                    wav_passband=self.explainer_cfg.wav_passband,
 | 
					                    wav_passband=self.explainer_cfg.wav_passband,
 | 
				
			||||||
                    wav_normalization=self.explainer_cfg.wav_normalization,
 | 
					                    wav_norm=self.explainer_cfg.wav_norm,
 | 
				
			||||||
                    num_candidates=self.explainer_cfg.num_candidates,
 | 
					                    candidates=self.explainer_cfg.candidates,
 | 
				
			||||||
                    num_samples=self.explainer_cfg.num_samples,
 | 
					                    samples=self.explainer_cfg.samples,
 | 
				
			||||||
                    c_procedure=self.explainer_cfg.c_procedure,
 | 
					                    c_proc=self.explainer_cfg.c_proc,
 | 
				
			||||||
                    pred_thres_strat=self.explainer_cfg.pred_thres_strat,
 | 
					                    pred_thres_strat=self.explainer_cfg.pred_thres_strat,
 | 
				
			||||||
                    CI_threshold=self.explainer_cfg.CI_threshold,
 | 
					                    CI_thres=self.explainer_cfg.CI_thres,
 | 
				
			||||||
                    mixing=self.explainer_cfg.mixing,
 | 
					                    mix=self.explainer_cfg.mix,
 | 
				
			||||||
                    pred_thres=self.explainer_cfg.pred_thres,
 | 
					                    pred_thres=self.explainer_cfg.pred_thres,
 | 
				
			||||||
                    incl_prob=self.explainer_cfg.incl_prob,
 | 
					                    incl_prob=self.explainer_cfg.incl_prob,
 | 
				
			||||||
                    top_k=self.explainer_cfg.top_k,
 | 
					                    top_k=self.explainer_cfg.top_k,
 | 
				
			||||||
 | 
					                    get_DAG=self.explainer_cfg.get_DAG,
 | 
				
			||||||
                    scales=self.explainer_cfg.scales,
 | 
					                    scales=self.explainer_cfg.scales,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            elif name == "SCGNN":
 | 
					            elif name == "SCGNN":
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue