Adding XGWT configs generation
This commit is contained in:
parent
27e8a8a4d8
commit
408bab4bc4
133
config_gen.py
133
config_gen.py
|
@ -10,6 +10,13 @@ from torch_geometric.graphgym.utils.io import string_to_python
|
|||
from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
|
||||
write_yaml)
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
# class BaseConfigGenerator(object):
|
||||
# def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
|
||||
# self.dataset_name=dataset_name
|
||||
|
@ -68,6 +75,7 @@ def explainer_conf(explainer: str, **kwargs):
|
|||
explaining_cfg["depth"] = "all"
|
||||
explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm")
|
||||
explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm")
|
||||
|
||||
elif explainer == "EIXGNN":
|
||||
explaining_cfg["L"] = kwargs.get("L")
|
||||
explaining_cfg["p"] = kwargs.get("p")
|
||||
|
@ -77,6 +85,23 @@ def explainer_conf(explainer: str, **kwargs):
|
|||
explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity")
|
||||
explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity")
|
||||
explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx")
|
||||
|
||||
elif explainer == "XGWT":
|
||||
explaining_cfg["wav_approx"] = kwargs.get("wav_approx")
|
||||
explaining_cfg["wav_passband"] = kwargs.get("wav_passband")
|
||||
explaining_cfg["wav_norm"] = kwargs.get("wav_norm")
|
||||
explaining_cfg["candidates"] = kwargs.get("candidates")
|
||||
explaining_cfg["samples"] = kwargs.get("samples")
|
||||
explaining_cfg["c_proc"] = kwargs.get("c_proc")
|
||||
explaining_cfg["pred_thres_strat"] = kwargs.get("pred_thres_strat")
|
||||
explaining_cfg["CI_thres"] = kwargs.get("CI_thres")
|
||||
explaining_cfg["mix"] = kwargs.get("mix")
|
||||
explaining_cfg["scales"] = kwargs.get("scales")
|
||||
explaining_cfg["pred_thres"] = kwargs.get("pred_thres")
|
||||
explaining_cfg["incl_prob"] = kwargs.get("incl_prob")
|
||||
explaining_cfg["top_k"] = kwargs.get("top_k")
|
||||
explaining_cfg["get_DAG"] = kwargs.get("get_DAG")
|
||||
|
||||
return explaining_cfg
|
||||
|
||||
|
||||
|
@ -161,6 +186,7 @@ if "__main__" == __name__:
|
|||
"GNNExplainer",
|
||||
"EIXGNN",
|
||||
"SCGNN",
|
||||
"XGWT",
|
||||
]
|
||||
|
||||
for dataset_name in DATASET:
|
||||
|
@ -193,6 +219,8 @@ if "__main__" == __name__:
|
|||
+ ".yaml",
|
||||
)
|
||||
explainer_path.append(path_explainer)
|
||||
if os.path.exists(path_explainer):
|
||||
continue
|
||||
explainer_config.append(config)
|
||||
write_yaml(config, path_explainer)
|
||||
if explainer_name == "SCGNN":
|
||||
|
@ -215,7 +243,89 @@ if "__main__" == __name__:
|
|||
)
|
||||
explainer_path.append(path_explainer)
|
||||
explainer_config.append(config)
|
||||
if os.path.exists(path_explainer):
|
||||
continue
|
||||
write_yaml(config, path_explainer)
|
||||
if explainer_name == "XGWT":
|
||||
explainer_config = []
|
||||
explainer_path = []
|
||||
for wav_approx in [False]:
|
||||
for wav_passband in ["heat"]:
|
||||
for wav_norm in [True]:
|
||||
for candidates in [10, 15, 30, 50]:
|
||||
for samples in [10, 25, 50]:
|
||||
for c_proc in ["auto"]:
|
||||
for pred_thres_strat in ["regular"]:
|
||||
for CI_thres in [0.05]:
|
||||
for mix in ["uniform"]:
|
||||
for scales in [
|
||||
[2],
|
||||
[3],
|
||||
[5],
|
||||
[9],
|
||||
[2, 3, 5],
|
||||
[2, 3, 5, 9],
|
||||
[5, 9],
|
||||
[2, 3],
|
||||
]:
|
||||
for pred_thres in [
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
]:
|
||||
for incl_prob in [
|
||||
0.2,
|
||||
0.4,
|
||||
0.6,
|
||||
]:
|
||||
for top_k in [
|
||||
2,
|
||||
5,
|
||||
10,
|
||||
]:
|
||||
for get_DAG in [
|
||||
False
|
||||
]:
|
||||
config = explainer_conf(
|
||||
"XGWT",
|
||||
wav_approx=wav_approx,
|
||||
wav_passband=wav_passband,
|
||||
wav_norm=wav_norm,
|
||||
candidates=candidates,
|
||||
samples=samples,
|
||||
c_proc=c_proc,
|
||||
pred_thres_strat=pred_thres_strat,
|
||||
CI_thres=CI_thres,
|
||||
mix=mix,
|
||||
scales=scales,
|
||||
pred_thres=pred_thres,
|
||||
incl_prob=incl_prob,
|
||||
top_k=top_k,
|
||||
get_DAG=get_DAG,
|
||||
)
|
||||
path_explainer = os.path.join(
|
||||
explainer_folder,
|
||||
"XGWT_"
|
||||
+ obj_config_to_str(
|
||||
config
|
||||
)
|
||||
+ ".yaml",
|
||||
)
|
||||
explainer_path.append(
|
||||
path_explainer
|
||||
)
|
||||
explainer_config.append(
|
||||
config
|
||||
)
|
||||
if os.path.exists(
|
||||
path_explainer
|
||||
):
|
||||
continue
|
||||
write_yaml(
|
||||
config,
|
||||
path_explainer,
|
||||
)
|
||||
|
||||
for explainer_p, explainer_c in zip(explainer_path, explainer_config):
|
||||
explaining_cfg = explaining_conf(
|
||||
dataset=dataset_name,
|
||||
|
@ -236,16 +346,15 @@ if "__main__" == __name__:
|
|||
+ f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}_{obj_config_to_str(explainer_c)}.yaml"
|
||||
)
|
||||
write_yaml(explaining_cfg, PATH)
|
||||
os.makedirs(explaining_folder + "/0", exist_ok=True)
|
||||
os.makedirs(explaining_folder + "/1", exist_ok=True)
|
||||
a = glob.glob(explaining_folder + "/*.yaml")
|
||||
a = sorted(glob.glob(explaining_folder + "/*.yaml"))
|
||||
|
||||
for path in a[:8050]:
|
||||
basename = os.path.basename(path)
|
||||
dirname = os.path.dirname(path)
|
||||
os.rename(path, dirname + "/0/" + basename)
|
||||
|
||||
for path in a[8050:]:
|
||||
basename = os.path.basename(path)
|
||||
dirname = os.path.dirname(path)
|
||||
os.rename(path, dirname + "/1/" + basename)
|
||||
num_GPU = 4
|
||||
for i in range(num_GPU):
|
||||
os.makedirs(explaining_folder + f"/{i}", exist_ok=True)
|
||||
split_size = int(len(a) / num_GPU)
|
||||
data = chunks(a, split_size)
|
||||
for i, d in enumerate(data):
|
||||
for path in d:
|
||||
basename = os.path.basename(path)
|
||||
dirname = os.path.dirname(path)
|
||||
os.rename(path, dirname + f"/{i}/" + basename)
|
||||
|
|
|
@ -38,17 +38,18 @@ def set_xgwt_cfg(xgwt_cfg):
|
|||
|
||||
xgwt_cfg.wav_approx = False
|
||||
xgwt_cfg.wav_passband = "heat"
|
||||
xgwt_cfg.wav_normalization = True
|
||||
xgwt_cfg.num_candidates = 30
|
||||
xgwt_cfg.num_samples = 10
|
||||
xgwt_cfg.c_procedure = "auto"
|
||||
xgwt_cfg.wav_norm = True
|
||||
xgwt_cfg.candidates = 30
|
||||
xgwt_cfg.samples = 10
|
||||
xgwt_cfg.c_proc = "auto"
|
||||
xgwt_cfg.pred_thres_strat = "regular"
|
||||
xgwt_cfg.CI_threshold = 0.05
|
||||
xgwt_cfg.mixing = "uniform"
|
||||
xgwt_cfg.CI_thres = 0.05
|
||||
xgwt_cfg.mix = "uniform"
|
||||
xgwt_cfg.scales = [3]
|
||||
xgwt_cfg.pred_thres = 0.1
|
||||
xgwt_cfg.incl_prob = 0.4
|
||||
xgwt_cfg.top_k = 5
|
||||
xgwt_cfg.get_DAG = False
|
||||
|
||||
|
||||
def assert_cfg(xgwt_cfg):
|
||||
|
|
|
@ -300,16 +300,17 @@ class ExplainingOutline(object):
|
|||
explaining_algorithm = XGWT(
|
||||
wav_approx=self.explainer_cfg.wav_approx,
|
||||
wav_passband=self.explainer_cfg.wav_passband,
|
||||
wav_normalization=self.explainer_cfg.wav_normalization,
|
||||
num_candidates=self.explainer_cfg.num_candidates,
|
||||
num_samples=self.explainer_cfg.num_samples,
|
||||
c_procedure=self.explainer_cfg.c_procedure,
|
||||
wav_norm=self.explainer_cfg.wav_norm,
|
||||
candidates=self.explainer_cfg.candidates,
|
||||
samples=self.explainer_cfg.samples,
|
||||
c_proc=self.explainer_cfg.c_proc,
|
||||
pred_thres_strat=self.explainer_cfg.pred_thres_strat,
|
||||
CI_threshold=self.explainer_cfg.CI_threshold,
|
||||
mixing=self.explainer_cfg.mixing,
|
||||
CI_thres=self.explainer_cfg.CI_thres,
|
||||
mix=self.explainer_cfg.mix,
|
||||
pred_thres=self.explainer_cfg.pred_thres,
|
||||
incl_prob=self.explainer_cfg.incl_prob,
|
||||
top_k=self.explainer_cfg.top_k,
|
||||
get_DAG=self.explainer_cfg.get_DAG,
|
||||
scales=self.explainer_cfg.scales,
|
||||
)
|
||||
elif name == "SCGNN":
|
||||
|
|
Loading…
Reference in New Issue