# explaining_framework/explaining_framework/utils/explaining/outline.py
# 2022-12-29 22:12:48 +01:00
#
# 229 lines
# 7.7 KiB
# Python

import copy
from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \
GraphXAIWrapper
from explaining_framework.metric.accuracy import Accuracy
from explaining_framework.metric.fidelity import Fidelity
from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
_load_ckpt)
# Explainer names handled by the Captum wrapper (CaptumWrapper).
# NOTE: "IntegratedGradients" is also listed in all__graphxai; since
# ExplainingOutline.load_explainer tests all__captum first, the Captum
# implementation is the one that gets used for that name.
all__captum = [
    "LRP",
    "DeepLift",
    "DeepLiftShap",
    "FeatureAblation",
    "FeaturePermutation",
    "GradientShap",
    "GuidedBackprop",
    "GuidedGradCam",
    "InputXGradient",
    "IntegratedGradients",
    "Lime",
    "Occlusion",
    "Saliency",
]

# Explainer names handled by the GraphXAI wrapper (GraphXAIWrapper).
all__graphxai = [
    "CAM",
    "GradCAM",
    "GNN_LRP",
    "GradExplainer",
    "GuidedBackPropagation",
    "IntegratedGradients",
    "PGExplainer",
    "PGMExplainer",
    "RandomExplainer",
    "SubgraphX",
    "GraphMASK",
]

# Explainers implemented in-house (EiXGNN and SCGNN packages).
all__own = ["EIXGNN", "SCGNN"]

# Metric names passed to Fidelity(name).
all_fidelity = [
    "fidelity_plus",
    "fidelity_minus",
    "fidelity_plus_prob",
    "fidelity_minus_prob",
    "infidelity_KL",
]

# Metric names passed to Accuracy(name).
all_accuracy = [
    "precision_score",
    "jaccard_score",
    "roc_auc_score",
    "f1_score",
    "accuracy_score",
]

# Metric names passed to Sparsity(name). ExplainingOutline.load_metric uses
# this list, but it was never defined, which caused a NameError there.
# TODO: populate with the metric names accepted by
# explaining_framework.metric.sparsity.Sparsity.
all_sparsity = []

# Attack names passed to Attack(name).
all_robust = [
    "gaussian_noise",
    "add_edge",
    "remove_edge",
    "remove_node",
    "pgd",
    "fgsm",
]
class ExplainingOutline(object):
    """Load every object needed to run one explanation experiment.

    Construction merges the explaining config found at
    ``explaining_cfg_path`` and then, in order, resolves the model info
    (config and checkpoint paths), merges the model's config into the
    GraphGym ``cfg``, builds the dataset, builds and restores the model,
    selects the explainer config and instantiates the explaining algorithm.

    Each ``load_*`` method can also be called on its own: it first loads any
    prerequisite attribute that is still ``None``.

    NOTE(review): configs are merged into module-level objects (``cfg``,
    ``explaining_cfg`` and, for a non-default explainer config,
    ``eixgnn_cfg`` / ``scgnn_cfg``), so they are shared with any other code
    importing them.
    """

    def __init__(self, explaining_cfg_path: str):
        self.explaining_cfg_path = explaining_cfg_path
        self.explaining_cfg = None
        self.explainer_cfg_path = None
        self.explainer_cfg = None
        self.explaining_algorithm = None
        self.cfg = None
        self.model = None
        self.dataset = None
        self.model_info = None
        # Populated by load_metric() / load_attack(), which __init__ does not call.
        self.metrics = None
        self.attacks = None
        self.load_explaining_cfg()
        self.load_model_info()
        self.load_cfg()
        self.load_dataset()
        self.load_model()
        self.load_explainer_cfg()
        self.load_explainer()

    def load_model_info(self):
        """Resolve the trained model's info via ``LoadModelInfo``.

        The resulting mapping provides the ``"cfg_path"`` entry used by
        :meth:`load_cfg` and the ``"ckpt_path"`` entry used by
        :meth:`load_model`.
        """
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        info = LoadModelInfo(
            dataset_name=self.explaining_cfg.dataset.name,
            model_dir=self.explaining_cfg.model.path,
            which=self.explaining_cfg.model.ckpt,
        )
        self.model_info = info.set_info()

    def load_cfg(self):
        """Merge the trained model's config file into the GraphGym ``cfg``."""
        if self.model_info is None:
            self.load_model_info()
        cfg.set_new_allowed(True)
        cfg.merge_from_file(self.model_info["cfg_path"])
        self.cfg = cfg

    def load_explaining_cfg(self):
        """Merge the file at ``explaining_cfg_path`` into ``explaining_cfg``."""
        explaining_cfg.set_new_allowed(True)
        explaining_cfg.merge_from_file(self.explaining_cfg_path)
        self.explaining_cfg = explaining_cfg

    def load_explainer_cfg(self):
        """Select the config of the explainer named in the explaining config.

        Only EIXGNN and SCGNN have a dedicated config; any other explainer
        gets ``None``. If ``explainer.cfg == "default"``, a deep copy of the
        packaged default (``eixgnn_cfg`` / ``scgnn_cfg``) is used. Otherwise,
        the file at ``explainer.cfg`` is merged into that module-level config,
        which is then used.
        """
        # Load the prerequisite first and then carry on. The original put the
        # selection logic in an else-branch, so it never ran after this load.
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        name = self.explaining_cfg.explainer.name
        source = self.explaining_cfg.explainer.cfg

        if name == "EIXGNN":
            base_cfg = eixgnn_cfg
        elif name == "SCGNN":
            base_cfg = scgnn_cfg
        else:
            self.explainer_cfg = None
            return

        if source == "default":
            # Deep copy: copy.copy would share nested config nodes with the
            # module-level defaults, letting later edits leak into them.
            self.explainer_cfg = copy.deepcopy(base_cfg)
        else:
            base_cfg.set_new_allowed(True)
            base_cfg.merge_from_file(source)
            self.explainer_cfg = base_cfg

    def load_model(self):
        """Build the model from ``cfg`` and restore its checkpoint weights.

        Raises:
            ValueError: if ``_load_ckpt`` returns ``None``.
        """
        if self.cfg is None:
            self.load_cfg()
        auto_select_device()
        self.model = create_model()
        self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
        if self.model is None:
            raise ValueError("Model ckpt has not been loaded, ckpt file not found")

    def load_dataset(self):
        """Build the dataset the explanations are computed on.

        If ``dataset.specific_items`` in the explaining config is an int, the
        dataset is reduced to that single item (kept as a length-1 slice).

        Raises:
            ValueError: if the explaining config and the model config name
                different datasets.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
            raise ValueError(
                f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
            )
        self.dataset = create_dataset()
        if isinstance(self.explaining_cfg.dataset.specific_items, int):
            ind = self.explaining_cfg.dataset.specific_items
            self.dataset = self.dataset[ind : ind + 1]

    def load_explainer(self):
        """Instantiate the explaining algorithm named in the explaining config.

        Captum names go to ``CaptumWrapper``, GraphXAI names to
        ``GraphXAIWrapper``, and EIXGNN / SCGNN are built from
        ``self.explainer_cfg``. The result is stored in
        ``self.explaining_algorithm``.

        Raises:
            ValueError: if the explainer name is not a known explainer.
        """
        self.load_explainer_cfg()
        if self.model is None:
            self.load_model()
        if self.dataset is None:
            self.load_dataset()
        name = self.explaining_cfg.explainer.name
        # all__captum is tested first, so a name present in both lists
        # (e.g. "IntegratedGradients") uses the Captum implementation.
        if name in all__captum:
            explaining_algorithm = CaptumWrapper(name)
        elif name in all__graphxai:
            # NOTE(review): in_channels receives the number of classes, not the
            # node feature dimension; confirm this is what GraphXAIWrapper expects.
            explaining_algorithm = GraphXAIWrapper(
                name,
                in_channels=self.dataset.num_classes,
                criterion=self.cfg.model.loss_fun,
            )
        elif name == "EIXGNN":
            explaining_algorithm = EiXGNN(
                L=self.explainer_cfg.L,
                p=self.explainer_cfg.p,
                importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
                domain_similarity=self.explainer_cfg.domain_similarity,
                signal_similarity=self.explainer_cfg.signal_similarity,
                shap_val_approx=self.explainer_cfg.shapley_value_approx,
            )
        elif name == "SCGNN":
            explaining_algorithm = SCGNN(
                depth=self.explainer_cfg.depth,
                interest_map_norm=self.explainer_cfg.interest_map_norm,
                score_map_norm=self.explainer_cfg.score_map_norm,
            )
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError(
                f"Unknown explainer {name!r}; expected one of "
                f"{all__captum + all__graphxai + all__own}"
            )
        self.explaining_algorithm = explaining_algorithm

    def load_metric(self):
        """Instantiate the evaluation metrics requested by the explaining config.

        Only ``metrics.type == "all"`` on the ``BASHAPES`` dataset is handled
        so far. In that case every accuracy, fidelity and sparsity metric is
        built and stored in ``self.metrics``. For any other configuration,
        ``self.metrics`` is left unchanged.

        Returns:
            ``self.metrics``.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        if self.explaining_cfg.metrics.type == "all":
            if self.explaining_cfg.dataset.name == "BASHAPES":
                all_acc_metrics = [Accuracy(name) for name in all_accuracy]
                all_fid_metrics = [Fidelity(name) for name in all_fidelity]
                # NOTE(review): requires a module-level ``all_sparsity`` list
                # of names accepted by Sparsity.
                all_spa_metrics = [Sparsity(name) for name in all_sparsity]
                # Keep the metrics; they used to be built into dead locals.
                self.metrics = all_acc_metrics + all_fid_metrics + all_spa_metrics
        return self.metrics

    def load_attack(self):
        """Instantiate every robustness attack in ``all_robust``.

        The attacks are stored in ``self.attacks``.

        Returns:
            ``self.attacks``.
        """
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        self.attacks = [Attack(name) for name in all_robust]
        return self.attacks
class FileManager(object):
    """Stub for a helper that saves objects to a path.

    Nothing is implemented yet: :meth:`save` accepts an object and a path,
    does nothing and returns ``None``.
    """

    def __init__(self):
        pass

    @staticmethod
    def save(obj: Any, path: str) -> None:
        """Save *obj* to *path* (not implemented; currently a no-op).

        Declared as a static method because it takes no ``self``. Without
        that, ``FileManager().save(obj, path)`` raised ``TypeError``, while
        ``FileManager.save(obj, path)`` keeps working as before.

        Args:
            obj: Object to save.
            path: Destination path.
        """
        # TODO: implement serialization of ``obj`` to ``path``.
        return None
# Default explaining-config file used when this module is run as a script.
PATH = "config_exp.yaml"

if __name__ == "__main__":
    # Guarded so that importing this module does not trigger config loading,
    # dataset creation and model restoration as a side effect.
    test = ExplainingOutline(explaining_cfg_path=PATH)