# explaining_framework/config_gen.py
import glob
import os
import shutil
import sys
from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.utils.io import string_to_python
from explaining_framework.utils.io import (obj_config_to_str, read_yaml,
write_yaml)
# class BaseConfigGenerator(object):
# def __init__(self,dataset_name:str,explainer_name:str, explainer_config:str, model_folder:str):
# self.dataset_name=dataset_name
# self.explainer_name=explainer_name
# self.explainer_config = explainer_config
# self.model_folder = model_folder
# class ExplainingConfigGenerator(object):
# def __init__(
# self,
# dataset_name: str,
# explainer_name: str,
# model_folder: str,
# explainer_config: str = "default",
# item: list = None,
# ckpt: str = "best",
# ):
# self.dataset_name = dataset_name
# self.explainer_name = explainer_name
# self.explainer_config = explainer_config
# self.item = item
# self.ckpt = ckpt
# self.model_folder = model_folder
def explaining_conf(
    dataset: str,
    model_kind: str,
    explainer: str,
    explainer_config: str = "default",
    *,
    model_path=None,
):
    """Build the explaining-run configuration dictionary for one experiment.

    Args:
        dataset: Dataset name; stored under ``dataset.name`` and used in
            ``cfg_dest``.
        model_kind: Checkpoint selector (the generator passes ``"best"`` or
            ``"worst"``); stored under ``model.ckpt`` and used in ``cfg_dest``.
        explainer: Explainer name; stored under ``explainer.name`` and used in
            ``cfg_dest``.
        explainer_config: Value stored under ``explainer.cfg`` (the generator
            passes ``"default"`` or the path of an explainer YAML file).
        model_path: Value stored under ``model.path``. When ``None`` (the
            default) the first command-line argument, ``sys.argv[1]``, is used.

    Returns:
        A nested ``dict`` ready to be dumped to YAML.

    Raises:
        IndexError: if ``model_path`` is ``None`` and the script was started
            without a command-line argument.
    """
    if model_path is None:
        model_path = sys.argv[1]
    return {
        # Built from this function's own arguments, not from caller globals.
        "cfg_dest": f"dataset={dataset}-model={model_kind}-explainer={explainer}.yaml",
        "dataset": {"name": dataset},
        "explainer": {"cfg": explainer_config, "name": explainer, "force": False},
        "explanation_type": "phenomenon",
        "model": {"ckpt": model_kind, "path": model_path},
        "threshold": {
            "value": {
                "hard": [0, 0.1, 0.3, 0.5, 0.7, 0.9],
                "topk": [2, 5, 10, 20, 50],
            }
        },
    }
def explainer_conf(explainer: str, **kwargs):
    """Build the hyper-parameter dictionary for a configurable explainer.

    Only ``"SCGNN"`` and ``"EIXGNN"`` have tunable options; any other
    explainer name yields an empty dict. Options absent from ``kwargs`` are
    stored as ``None``, except SCGNN's ``depth``, which defaults to ``"all"``.

    Args:
        explainer: Explainer name selecting which option keys are produced.
        **kwargs: Option values, looked up by key name.

    Returns:
        A ``dict`` of option name -> value.
    """
    explaining_cfg = {}
    if explainer == "SCGNN":
        explaining_cfg["target_baseline"] = kwargs.get("target_baseline")
        # Honour a caller-supplied depth (the generator passes ``depth=``);
        # "all" remains the default.
        explaining_cfg["depth"] = kwargs.get("depth", "all")
        explaining_cfg["score_map_norm"] = kwargs.get("score_map_norm")
        explaining_cfg["interest_map_norm"] = kwargs.get("interest_map_norm")
    elif explainer == "EIXGNN":
        explaining_cfg["L"] = kwargs.get("L")
        explaining_cfg["p"] = kwargs.get("p")
        explaining_cfg["importance_sampling_strategy"] = kwargs.get(
            "importance_sampling_strategy"
        )
        explaining_cfg["domain_similarity"] = kwargs.get("domain_similarity")
        explaining_cfg["signal_similarity"] = kwargs.get("signal_similarity")
        explaining_cfg["shapley_value_approx"] = kwargs.get("shapley_value_approx")
    return explaining_cfg
def _write_explainer_sweep(explainer, explainer_folder, grid):
    """Write one explainer YAML file per point of a hyper-parameter grid.

    Args:
        explainer: Explainer name passed to ``explainer_conf`` and used as the
            file-name prefix (``<explainer>_<config-string>.yaml``).
        explainer_folder: Directory the YAML files are written to.
        grid: Ordered mapping of ``explainer_conf`` keyword -> candidate
            values; every combination is written, in ``itertools.product``
            order (the last key varies fastest).

    Returns:
        ``(paths, configs)``: parallel lists of the written file paths and the
        config dicts they contain.
    """
    import itertools

    names = list(grid)
    paths, configs = [], []
    for values in itertools.product(*grid.values()):
        config = explainer_conf(explainer, **dict(zip(names, values)))
        path = os.path.join(
            explainer_folder, explainer + "_" + obj_config_to_str(config) + ".yaml"
        )
        write_yaml(config, path)
        paths.append(path)
        configs.append(config)
    return paths, configs


def _shard_yaml_files(folder, first_shard_size=8050):
    """Move the top-level ``*.yaml`` files of ``folder`` into ``folder/0`` and ``folder/1``.

    Files are taken in sorted name order so the split is reproducible
    (``glob`` order is filesystem-dependent). The first ``first_shard_size``
    files go to ``0`` and the rest to ``1``.
    """
    shard_dirs = [os.path.join(folder, "0"), os.path.join(folder, "1")]
    for shard_dir in shard_dirs:
        os.makedirs(shard_dir, exist_ok=True)
    yaml_files = sorted(glob.glob(os.path.join(folder, "*.yaml")))
    for index, path in enumerate(yaml_files):
        target_dir = shard_dirs[0] if index < first_shard_size else shard_dirs[1]
        # os.replace overwrites a file left in the shard by a previous run on
        # every platform (os.rename refuses to on Windows).
        os.replace(path, os.path.join(target_dir, os.path.basename(path)))


if "__main__" == __name__:
    # configs/ is created under the current working directory.
    config_folder = os.path.abspath("configs")
    makedirs(config_folder)
    explaining_folder = os.path.join(config_folder, "explaining")
    makedirs(explaining_folder)
    explainer_folder = os.path.join(config_folder, "explainer")
    makedirs(explainer_folder)

    # TODO Make a single list for all dataset name or explaining method name, etc
    DATASET = [
        "CIFAR10",
        "TRIANGLES",
        "COLORS-3",
        "REDDIT-BINARY",
        "REDDIT-MULTI-5K",
        "REDDIT-MULTI-12K",
        "COLLAB",
        "DBLP_v1",
        "COIL-DEL",
        "COIL-RAG",
        "Fingerprint",
        "Letter-high",
        "Letter-low",
        "Letter-med",
        "MSRC_9",
        "MSRC_21",
        "MSRC_21C",
        "DD",
        "ENZYMES",
        "PROTEINS",
        "QM9",
        "MUTAG",
        "Mutagenicity",
        "AIDS",
        "PATTERN",
        "CLUSTER",
        "MNIST",
        "TSP",
        "CSL",
        "KarateClub",
        "CS",
        "Physics",
        "BBBP",
        "Tox21",
        "HIV",
        "PCBA",
        "MUV",
        "BACE",
        "SIDER",
        "ClinTox",
        "AIFB",
        "AM",
        "BGS",
        "FAUST",
        "DynamicFAUST",
        "ShapeNet",
        "ModelNet10",
        "ModelNet40",
        "PascalVOC-SP",
        "COCO-SP",
    ]
    EXPLAINER = [
        "CAM",
        "GradCAM",
        # "GNN_LRP",
        "GradExplainer",
        "GuidedBackPropagation",
        "IntegratedGradients",
        # "PGExplainer",
        "PGMExplainer",
        "RandomExplainer",
        # "SubgraphX",
        "GraphMASK",
        "GNNExplainer",
        "EIXGNN",
        "SCGNN",
    ]
    # Hyper-parameter grids for the explainers that are swept; every other
    # explainer runs with its "default" configuration.
    EXPLAINER_GRIDS = {
        "EIXGNN": {
            "importance_sampling_strategy": ["node", "neighborhood", "no_prior"],
            "domain_similarity": ["relative_edge_density"],
            "signal_similarity": ["KL", "KL_sym"],
            "shapley_value_approx": [1000],
            "L": [5, 10, 15, 20, 30, 50],
            "p": [0.2, 0.3, 0.5, 0.7],
        },
        "SCGNN": {
            "target_baseline": [None, "inference"],
            "depth": ["all"],
            "score_map_norm": [True, False],
            "interest_map_norm": [True, False],
        },
    }

    # The explainer sweeps do not depend on dataset or checkpoint, so their
    # YAML files are written once rather than once per (dataset, model) pair.
    explainer_sweeps = {
        name: _write_explainer_sweep(name, explainer_folder, grid)
        for name, grid in EXPLAINER_GRIDS.items()
        if name in EXPLAINER
    }

    for dataset_name in DATASET:
        for model_kind in ["best", "worst"]:
            for explainer_name in EXPLAINER:
                explainer_path, explainer_config = explainer_sweeps.get(
                    explainer_name, (["default"], [None])
                )
                for explainer_p, explainer_c in zip(explainer_path, explainer_config):
                    explaining_cfg = explaining_conf(
                        dataset=dataset_name,
                        model_kind=model_kind,
                        explainer=explainer_name,
                        explainer_config=explainer_p,
                    )
                    # Swept explainers get their config string appended so
                    # each grid point has its own explaining file.
                    suffix = (
                        "" if explainer_c is None else "_" + obj_config_to_str(explainer_c)
                    )
                    explaining_path = os.path.join(
                        explaining_folder,
                        f"dataset={dataset_name}-model={model_kind}"
                        f"-explainer={explainer_name}{suffix}.yaml",
                    )
                    write_yaml(explaining_cfg, explaining_path)

    _shard_yaml_files(explaining_folder)