361 lines
17 KiB
Python
361 lines
17 KiB
Python
import glob
import itertools
import os
import shutil
import sys

from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.utils.io import string_to_python

from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
                                           write_yaml)
|
|
|
|
|
|
def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    Args:
        lst: Sequence supporting ``len()`` and slicing.
        n: Maximum number of items per chunk; must be a positive integer.

    Yields:
        Consecutive slices ``lst[i:i + n]``; the last one may be shorter
        when ``len(lst)`` is not a multiple of ``n``.

    Raises:
        ValueError: If ``n`` is not positive. Because this is a generator,
            the error is raised when iteration starts, not at call time.
    """
    # A negative step would make ``range`` empty and silently drop every
    # item; a zero step fails with an unrelated-looking ``range()`` message.
    if n <= 0:
        raise ValueError(f"chunk size must be a positive integer, got {n!r}")
    for start in range(0, len(lst), n):
        yield lst[start : start + n]
|
|
|
|
|
|
# class BaseConfigGenerator(object):
#     def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
#         self.dataset_name=dataset_name
#         self.explainer_name=explainer_name
#         self.explainer_config = explainer_config
#         self.model_folder = model_folder


# class ExplainingConfigGenerator(object):
#     def __init__(
#         self,
#         dataset_name: str,
#         explainer_name: str,
#         model_folder: str,
#         explainer_config: str = "default",
#         item: list = None,
#         ckpt: str = "best",
#     ):
#         self.dataset_name = dataset_name
#         self.explainer_name = explainer_name
#         self.explainer_config = explainer_config
#         self.item = item
#         self.ckpt = ckpt
#         self.model_folder = model_folder
|
|
|
|
|
|
def explaining_conf(
    dataset: str,
    model_kind: str,
    explainer: str,
    explainer_config: str = "default",
    model_path=None,
):
    """Build the explaining configuration for one dataset/model/explainer.

    Args:
        dataset: Dataset name, stored under ``dataset.name``.
        model_kind: Checkpoint selector, stored under ``model.ckpt``.
        explainer: Explainer name, stored under ``explainer.name``.
        explainer_config: Value stored under ``explainer.cfg``; defaults to
            the literal string ``"default"``.
        model_path: Value stored under ``model.path``. When ``None`` the
            first command-line argument (``sys.argv[1]``) is used.

    Returns:
        A freshly built nested ``dict`` describing the run.

    Raises:
        IndexError: If ``model_path`` is ``None`` and the process was started
            without a command-line argument.
    """
    if model_path is None:
        # Fall back to the script's first command-line argument.
        model_path = sys.argv[1]

    return {
        # Built from this function's own arguments, so the function does not
        # depend on same-named globals existing in the calling module.
        "cfg_dest": f"dataset={dataset}-model={model_kind}-explainer={explainer}.yaml",
        "dataset": {"name": dataset},
        "explainer": {
            "cfg": explainer_config,
            "name": explainer,
            "force": False,
        },
        "explanation_type": "phenomenon",
        "model": {"ckpt": model_kind, "path": model_path},
        # explaining_cfg['out_dir']='./explanation'
        "threshold": {
            "value": {
                "hard": [0, 0.1, 0.3, 0.5, 0.7, 0.9],
                "topk": [2, 5, 10, 20, 50],
            },
        },
    }
|
|
|
|
|
|
# Keyword arguments copied into the configuration of each supported
# explainer, listed in the order in which they are inserted into the result.
_EXPLAINER_KEYS = {
    "SCGNN": (
        "target_baseline",
        "depth",
        "score_map_norm",
        "interest_map_norm",
    ),
    "EIXGNN": (
        "L",
        "p",
        "importance_sampling_strategy",
        "domain_similarity",
        "signal_similarity",
        "shapley_value_approx",
    ),
    "XGWT": (
        "wav_approx",
        "wav_passband",
        "wav_norm",
        "candidates",
        "samples",
        "c_proc",
        "pred_thres_strat",
        "CI_thres",
        "mix",
        "scales",
        "pred_thres",
        "incl_prob",
        "top_k",
        "get_DAG",
    ),
}


def explainer_conf(explainer: str, **kwargs):
    """Build the hyper-parameter dictionary for one explainer.

    Args:
        explainer: Explainer name. ``"SCGNN"``, ``"EIXGNN"`` and ``"XGWT"``
            are recognised; any other name produces an empty dictionary.
        **kwargs: Hyper-parameter values. Only the keys that belong to the
            selected explainer are read; a missing key is stored as ``None``.

    Returns:
        A new ``dict`` with a fixed key order per explainer. For ``"SCGNN"``
        the ``depth`` entry is ``"all"`` unless a non-``None`` ``depth`` is
        supplied.
    """
    config = {key: kwargs.get(key) for key in _EXPLAINER_KEYS.get(explainer, ())}
    # ``depth`` used to be hard-coded to "all" even when the caller passed a
    # value; honour the caller's value and keep "all" as the fallback.
    if explainer == "SCGNN" and config["depth"] is None:
        config["depth"] = "all"
    return config
|
|
|
|
|
|
def _expand_grid(explainer, grid, out_folder):
    """Write one explainer YAML file per point of a hyper-parameter grid.

    Args:
        explainer: Explainer name passed to ``explainer_conf`` and used as the
            file-name prefix.
        grid: Mapping ``keyword -> list of candidate values``. Combinations
            are enumerated with the first key varying slowest, i.e. in the
            same order as nested ``for`` loops written in key order.
        out_folder: Directory receiving the explainer YAML files.

    Returns:
        A list of ``(path, tag)`` pairs, one per combination, where ``tag`` is
        ``obj_config_to_str(config)`` and ``path`` is the YAML file for that
        combination. A pair is returned whether or not the file already
        existed, so paths and tags can never get out of step.
    """
    variants = []
    names = list(grid)
    for values in itertools.product(*grid.values()):
        config = explainer_conf(explainer, **dict(zip(names, values)))
        tag = obj_config_to_str(config)
        path = os.path.join(out_folder, explainer + "_" + tag + ".yaml")
        # An explainer file that is already on disk is left untouched.
        if not os.path.exists(path):
            write_yaml(config, path)
        variants.append((path, tag))
    return variants


if "__main__" == __name__:
    # The model location is taken from the first command-line argument when
    # the explaining configs are built; fail before writing any file.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_PATH")

    # The output tree lives in ``configs`` under the current working
    # directory (the previous ``os.path.dirname(__name__)`` evaluated to an
    # empty string, which resolves to the same location).
    config_folder = os.path.abspath("configs")
    makedirs(config_folder)
    explaining_folder = os.path.join(config_folder, "explaining")
    makedirs(explaining_folder)
    explainer_folder = os.path.join(config_folder, "explainer")
    makedirs(explainer_folder)

    # TODO Make a single list for all dataset name or explaining method name, etc

    DATASET = [
        "CIFAR10",
        "TRIANGLES",
        "COLORS-3",
        "REDDIT-BINARY",
        "REDDIT-MULTI-5K",
        "REDDIT-MULTI-12K",
        "COLLAB",
        "DBLP_v1",
        "COIL-DEL",
        "COIL-RAG",
        "Fingerprint",
        "Letter-high",
        "Letter-low",
        "Letter-med",
        "MSRC_9",
        "MSRC_21",
        "MSRC_21C",
        "DD",
        "ENZYMES",
        "PROTEINS",
        "QM9",
        "MUTAG",
        "Mutagenicity",
        "AIDS",
        "PATTERN",
        "CLUSTER",
        "MNIST",
        "CIFAR10",
        "TSP",
        "CSL",
        "KarateClub",
        "CS",
        "Physics",
        "BBBP",
        "Tox21",
        "HIV",
        "PCBA",
        "MUV",
        "BACE",
        "SIDER",
        "ClinTox",
        "AIFB",
        "AM",
        "MUTAG",
        "BGS",
        "FAUST",
        "DynamicFAUST",
        "ShapeNet",
        "ModelNet10",
        "ModelNet40",
        "PascalVOC-SP",
        "COCO-SP",
    ]
    # Some names are listed twice above; a repeated name would only rewrite
    # the very same files, so keep the first occurrence of each.
    DATASET = list(dict.fromkeys(DATASET))

    EXPLAINER = [
        "CAM",
        "GradCAM",
        # "GNN_LRP",
        "GradExplainer",
        "GuidedBackPropagation",
        "IntegratedGradients",
        # "PGExplainer",
        "PGMExplainer",
        "RandomExplainer",
        # "SubgraphX",
        "GraphMASK",
        "GNNExplainer",
        "EIXGNN",
        "SCGNN",
        "XGWT",
    ]

    # Hyper-parameter grids of the explainers that are swept; every other
    # explainer gets the single "default" configuration.
    EXPLAINER_GRIDS = {
        "EIXGNN": {
            "importance_sampling_strategy": ["node", "neighborhood", "no_prior"],
            "domain_similarity": ["relative_edge_density"],
            "signal_similarity": ["KL", "KL_sym"],
            "shapley_value_approx": [1000],
            "L": [5, 10, 15, 20, 30, 50],
            "p": [0.2, 0.3, 0.5, 0.7],
        },
        "SCGNN": {
            "target_baseline": [None, "inference"],
            "depth": ["all"],
            "score_map_norm": [True, False],
            "interest_map_norm": [True, False],
        },
        "XGWT": {
            "wav_approx": [False],
            "wav_passband": ["heat"],
            "wav_norm": [True],
            "candidates": [10, 15, 30, 50],
            "samples": [10, 25, 50],
            "c_proc": ["auto"],
            "pred_thres_strat": ["regular"],
            "CI_thres": [0.05],
            "mix": ["uniform"],
            "scales": [
                [2],
                [3],
                [5],
                [9],
                [2, 3, 5],
                [2, 3, 5, 9],
                [5, 9],
                [2, 3],
            ],
            "pred_thres": [0.1, 0.25, 0.5],
            "incl_prob": [0.2, 0.4, 0.6],
            "top_k": [2, 5, 10],
            "get_DAG": [False],
        },
    }

    # The explainer variants depend neither on the dataset nor on the model
    # kind, so they are generated once instead of once per (dataset, model).
    explainer_variants = {}
    for explainer_name in EXPLAINER:
        if explainer_name in EXPLAINER_GRIDS:
            explainer_variants[explainer_name] = _expand_grid(
                explainer_name, EXPLAINER_GRIDS[explainer_name], explainer_folder
            )
        else:
            explainer_variants[explainer_name] = [("default", None)]

    # NOTE: the loop variables are deliberately module-level names with this
    # exact spelling, so helpers that look them up as globals keep working.
    for dataset_name in DATASET:
        for model_kind in ["best", "worst"]:
            for explainer_name in EXPLAINER:
                for explainer_p, explainer_tag in explainer_variants[explainer_name]:
                    explaining_cfg = explaining_conf(
                        dataset=dataset_name,
                        model_kind=model_kind,
                        explainer=explainer_name,
                        explainer_config=explainer_p,
                    )
                    stem = f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}"
                    if explainer_tag is not None:
                        stem += f"_{explainer_tag}"
                    write_yaml(
                        explaining_cfg,
                        os.path.join(explaining_folder, stem + ".yaml"),
                    )

    # Spread the generated explaining configs over one sub-folder per GPU.
    pending = sorted(glob.glob(os.path.join(explaining_folder, "*.yaml")))

    num_GPU = 4
    for i in range(num_GPU):
        os.makedirs(os.path.join(explaining_folder, str(i)), exist_ok=True)

    # Ceiling division guarantees at most ``num_GPU`` chunks, so no file is
    # routed to a sub-folder that was never created; ``max`` keeps the chunk
    # size valid when there are fewer files than GPUs.
    split_size = max(1, (len(pending) + num_GPU - 1) // num_GPU)
    for i, shard in enumerate(chunks(pending, split_size)):
        for path in shard:
            os.rename(
                path,
                os.path.join(explaining_folder, str(i), os.path.basename(path)),
            )
|