229 lines
7.7 KiB
Python
229 lines
7.7 KiB
Python
import copy
|
|
from typing import Any
|
|
|
|
from eixgnn.eixgnn import EiXGNN
|
|
from scgnn.scgnn import SCGNN
|
|
from torch_geometric.data import Batch, Data
|
|
from torch_geometric.explain import Explainer
|
|
from torch_geometric.graphgym.config import cfg
|
|
from torch_geometric.graphgym.loader import create_dataset
|
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
|
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
|
eixgnn_cfg
|
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
|
from explaining_framework.config.explaining_config import explaining_cfg
|
|
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
|
from explaining_framework.explainers.wrappers.from_graphxai import \
|
|
GraphXAIWrapper
|
|
from explaining_framework.metric.accuracy import Accuracy
|
|
from explaining_framework.metric.fidelity import Fidelity
|
|
from explaining_framework.metric.robust import Attack
|
|
from explaining_framework.metric.sparsity import Sparsity
|
|
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
|
_load_ckpt)
|
|
|
|
# Explainer names routed to CaptumWrapper.
all__captum = [
    "LRP", "DeepLift", "DeepLiftShap",
    "FeatureAblation", "FeaturePermutation",
    "GradientShap", "GuidedBackprop", "GuidedGradCam",
    "InputXGradient", "IntegratedGradients",
    "Lime", "Occlusion", "Saliency",
]

# Explainer names routed to GraphXAIWrapper.
# NOTE: "IntegratedGradients" also appears in ``all__captum``; any lookup
# that tests the captum list first will never reach the entry here.
all__graphxai = [
    "CAM", "GradCAM", "GNN_LRP", "GradExplainer",
    "GuidedBackPropagation", "IntegratedGradients",
    "PGExplainer", "PGMExplainer", "RandomExplainer",
    "SubgraphX", "GraphMASK",
]

# Explainer names handled directly (EiXGNN / SCGNN).
all__own = [
    "EIXGNN",
    "SCGNN",
]

# Metric names passed to Fidelity(...).
all_fidelity = [
    "fidelity_plus", "fidelity_minus",
    "fidelity_plus_prob", "fidelity_minus_prob",
    "infidelity_KL",
]
# Metric names passed to Accuracy(...).
all_accuracy = [
    "precision_score", "jaccard_score", "roc_auc_score",
    "f1_score", "accuracy_score",
]

# Perturbation names passed to Attack(...).
all_robust = [
    "gaussian_noise",
    "add_edge", "remove_edge", "remove_node",
    "pgd", "fgsm",
]
|
|
|
|
|
|
class ExplainingOutline(object):
    """Assemble the pieces needed to run an explainer from one config file.

    The constructor loads, in order: the explaining configuration, the model
    info, the GraphGym ``cfg``, the dataset, the model checkpoint, the
    explainer configuration and finally the explainer itself.  Each
    ``load_*`` method can also be called on its own; the loaders lazily
    trigger the prerequisites they read when those are still ``None``.
    """

    def __init__(self, explaining_cfg_path: str):
        """Load everything described by the file at ``explaining_cfg_path``.

        Args:
            explaining_cfg_path: Path of the explaining configuration file
                merged into ``explaining_cfg``.
        """
        self.explaining_cfg_path = explaining_cfg_path
        self.explaining_cfg = None
        self.explainer_cfg_path = None
        self.explainer_cfg = None
        self.explaining_algorithm = None
        self.cfg = None
        self.model = None
        self.dataset = None
        self.model_info = None

        self.load_explaining_cfg()
        self.load_model_info()
        self.load_cfg()
        self.load_dataset()
        self.load_model()
        self.load_explainer_cfg()
        self.load_explainer()

    def load_model_info(self):
        """Populate ``self.model_info`` from ``LoadModelInfo(...).set_info()``."""
        # Same lazy-prerequisite pattern as the other loaders: this method
        # reads self.explaining_cfg, so make sure it exists first.
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        info = LoadModelInfo(
            dataset_name=self.explaining_cfg.dataset.name,
            model_dir=self.explaining_cfg.model.path,
            which=self.explaining_cfg.model.ckpt,
        )
        self.model_info = info.set_info()

    def load_cfg(self):
        """Merge the model's config file into the imported ``cfg`` object.

        The file path is read from ``self.model_info["cfg_path"]``.  ``cfg``
        is the module-level object imported at the top of the file, so the
        merge mutates it in place; ``self.cfg`` is a reference to it.
        """
        # model_info supplies the path; load it on demand if missing.
        if self.model_info is None:
            self.load_model_info()
        cfg.set_new_allowed(True)
        cfg.merge_from_file(self.model_info["cfg_path"])
        self.cfg = cfg

    def load_explaining_cfg(self):
        """Merge ``self.explaining_cfg_path`` into the shared ``explaining_cfg``."""
        explaining_cfg.set_new_allowed(True)
        explaining_cfg.merge_from_file(self.explaining_cfg_path)
        self.explaining_cfg = explaining_cfg

    def load_explainer_cfg(self):
        """Select the configuration object for the chosen explainer.

        With ``explainer.cfg == "default"`` a shallow copy of the matching
        built-in config (``eixgnn_cfg`` / ``scgnn_cfg``) is stored, or
        ``None`` for any other explainer name.  Otherwise ``explainer.cfg``
        is treated as a file to merge into the matching built-in config,
        which is modified in place (no copy); for other explainer names
        ``self.explainer_cfg`` is left untouched.
        """
        # Previously the selection below lived in an ``else`` branch, so a
        # call that had to load the explaining config first returned without
        # ever choosing an explainer config.
        if self.explaining_cfg is None:
            self.load_explaining_cfg()

        if self.explaining_cfg.explainer.cfg == "default":
            if self.explaining_cfg.explainer.name == "EIXGNN":
                self.explainer_cfg = copy.copy(eixgnn_cfg)
            elif self.explaining_cfg.explainer.name == "SCGNN":
                self.explainer_cfg = copy.copy(scgnn_cfg)
            else:
                self.explainer_cfg = None
        else:
            if self.explaining_cfg.explainer.name == "EIXGNN":
                eixgnn_cfg.set_new_allowed(True)
                eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
                self.explainer_cfg = eixgnn_cfg
            elif self.explaining_cfg.explainer.name == "SCGNN":
                scgnn_cfg.set_new_allowed(True)
                scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
                self.explainer_cfg = scgnn_cfg

    def load_model(self):
        """Build the model and load its checkpoint into ``self.model``.

        Raises:
            ValueError: If ``_load_ckpt`` returns ``None``.
        """
        if self.cfg is None:
            self.load_cfg()
        auto_select_device()
        self.model = create_model()
        self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
        if self.model is None:
            raise ValueError("Model ckpt has not been loaded, ckpt file not found")

    def load_dataset(self):
        """Create the dataset, optionally restricted to one item.

        When ``explaining_cfg.dataset.specific_items`` is an ``int``, the
        dataset is sliced down to that single index.

        Raises:
            ValueError: If the explaining config and the model config name
                different datasets.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
            raise ValueError(
                f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
            )
        self.dataset = create_dataset()
        if isinstance(self.explaining_cfg.dataset.specific_items, int):
            ind = self.explaining_cfg.dataset.specific_items
            # Slice (not index) so the result stays a one-element dataset.
            self.dataset = self.dataset[ind : ind + 1]

    def load_explainer(self):
        """Instantiate the explainer named by ``explaining_cfg.explainer.name``.

        The name is looked up in ``all__captum``, then ``all__graphxai``,
        then ``all__own``; the first match decides which class is built.

        Raises:
            ValueError: If the name matches none of the supported explainers.
        """
        self.load_explainer_cfg()
        if self.model is None:
            self.load_model()
        if self.dataset is None:
            self.load_dataset()

        name = self.explaining_cfg.explainer.name
        explaining_algorithm = None
        if name in all__captum:
            explaining_algorithm = CaptumWrapper(name)
        elif name in all__graphxai:
            # NOTE(review): in_channels is fed the dataset's number of
            # classes -- confirm this is what GraphXAIWrapper expects.
            explaining_algorithm = GraphXAIWrapper(
                name,
                in_channels=self.dataset.num_classes,
                criterion=self.cfg.model.loss_fun,
            )
        elif name in all__own:
            if name == "EIXGNN":
                explaining_algorithm = EiXGNN(
                    L=self.explainer_cfg.L,
                    p=self.explainer_cfg.p,
                    importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
                    domain_similarity=self.explainer_cfg.domain_similarity,
                    signal_similarity=self.explainer_cfg.signal_similarity,
                    shap_val_approx=self.explainer_cfg.shapley_value_approx,
                )
            elif name == "SCGNN":
                explaining_algorithm = SCGNN(
                    depth=self.explainer_cfg.depth,
                    interest_map_norm=self.explainer_cfg.interest_map_norm,
                    score_map_norm=self.explainer_cfg.score_map_norm,
                )

        # An unrecognised name used to fall through to the assignment below
        # with the local unbound (UnboundLocalError); fail explicitly instead.
        if explaining_algorithm is None:
            raise ValueError(f"Unknown explainer name: {name!r}")
        self.explaining_algorithm = explaining_algorithm

    def load_metric(self):
        """Build the metric objects for the ``"all"`` / ``"BASHAPES"`` case.

        The constructed lists are local only; nothing is stored or returned.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()

        if self.explaining_cfg.metrics.type == "all":
            if self.explaining_cfg.dataset.name == "BASHAPES":
                all_acc_metrics = [Accuracy(name) for name in all_accuracy]
                all_fid_metrics = [Fidelity(name) for name in all_fidelity]
                # NOTE(review): ``all_sparsity`` is neither defined nor
                # imported in the visible part of this file, so this line
                # raises NameError when reached -- define the list of
                # sparsity metric names at module level.
                all_spa_metrics = [Sparsity(name) for name in all_sparsity]

    def load_attack(self):
        """Build one ``Attack`` per name in ``all_robust``.

        The constructed list is local only; nothing is stored or returned.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        all_rob_metrics = [Attack(name) for name in all_robust]
|
|
|
|
|
|
class FileManager(object):
    """Placeholder for persisting objects; ``save`` is currently a no-op."""

    def __init__(self):
        """Create a manager; no state is kept."""
        pass

    # ``save`` takes no ``self``.  Declaring it a staticmethod keeps
    # ``FileManager.save(obj, path)`` working and makes the call valid on
    # instances too (it would otherwise raise TypeError for the extra
    # implicit argument).
    @staticmethod
    def save(obj: Any, path: str) -> None:
        """Save ``obj`` at ``path``.  Not implemented yet: does nothing.

        Args:
            obj: Object to persist.
            path: Destination path.
        """
        pass
|
|
|
|
|
|
# Runs at import time: builds the whole explaining pipeline from the
# configuration file named below.
PATH = "config_exp.yaml"

test = ExplainingOutline(
    explaining_cfg_path=PATH,
)
|