This commit is contained in:
araison 2023-02-12 12:54:19 +01:00
parent 702b724158
commit 3f4839fa59
1 changed files with 37 additions and 15 deletions

View File

@ -6,23 +6,10 @@ import os
from typing import Any from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from yacs.config import CfgNode as CN
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
from explaining_framework.config.explainer_config.xgwt_config import xgwt_cfg
from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \ from explaining_framework.explainers.wrappers.from_graphxai import \
@ -45,6 +32,20 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
obj_config_to_str, read_json, obj_config_to_str, read_json,
set_printing, write_json, set_printing, write_json,
write_yaml) write_yaml)
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from xgwt.xgwt import XGWT
from yacs.config import CfgNode as CN
all__captum = [ all__captum = [
"LRP", "LRP",
@ -81,7 +82,7 @@ all__pyg = [
# "GNNExplainer", # "GNNExplainer",
] ]
all__own = ["EIXGNN", "SCGNN"] all__own = ["EIXGNN", "SCGNN", "XGWT"]
all_fidelity = [ all_fidelity = [
"fidelity_plus", "fidelity_plus",
@ -218,6 +219,8 @@ class ExplainingOutline(object):
self.explainer_cfg = copy.copy(eixgnn_cfg) self.explainer_cfg = copy.copy(eixgnn_cfg)
elif self.explaining_cfg.explainer.name == "SCGNN": elif self.explaining_cfg.explainer.name == "SCGNN":
self.explainer_cfg = copy.copy(scgnn_cfg) self.explainer_cfg = copy.copy(scgnn_cfg)
elif self.explaining_cfg.explainer.name == "XGWT":
self.explainer_cfg = copy.copy(xgwt_cfg)
else: else:
self.explainer_cfg = CN() self.explainer_cfg = CN()
else: else:
@ -227,6 +230,9 @@ class ExplainingOutline(object):
elif self.explaining_cfg.explainer.name == "SCGNN": elif self.explaining_cfg.explainer.name == "SCGNN":
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
self.explainer_cfg = scgnn_cfg self.explainer_cfg = scgnn_cfg
elif self.explaining_cfg.explainer.name == "XGWT":
xgwt_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
self.explainer_cfg = xgwt_cfg
def load_model(self): def load_model(self):
if self.cfg is None: if self.cfg is None:
@ -290,6 +296,22 @@ class ExplainingOutline(object):
signal_similarity=self.explainer_cfg.signal_similarity, signal_similarity=self.explainer_cfg.signal_similarity,
shap_val_approx=self.explainer_cfg.shapley_value_approx, shap_val_approx=self.explainer_cfg.shapley_value_approx,
) )
elif name == "XGWT":
explaining_algorithm = XGWT(
wav_approx=self.explainer_cfg.wav_approx,
wav_passband=self.explainer_cfg.wav_passband,
wav_normalization=self.explainer_cfg.wav_normalization,
num_candidates=self.explainer_cfg.num_candidates,
num_samples=self.explainer_cfg.num_samples,
c_procedure=self.explainer_cfg.c_procedure,
pred_thres_strat=self.explainer_cfg.pred_thres_strat,
CI_threshold=self.explainer_cfg.CI_threshold,
mixing=self.explainer_cfg.mixing,
pred_thres=self.explainer_cfg.pred_thres,
incl_prob=self.explainer_cfg.incl_prob,
top_k=self.explainer_cfg.top_k,
scales=self.explainer_cfg.scales,
)
elif name == "SCGNN": elif name == "SCGNN":
explaining_algorithm = SCGNN( explaining_algorithm = SCGNN(
depth=self.explainer_cfg.depth, depth=self.explainer_cfg.depth,