Update
This commit is contained in:
parent
702b724158
commit
3f4839fa59
@ -6,23 +6,10 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
from explaining_framework.config.explainer_config.xgwt_config import xgwt_cfg
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||
@ -45,6 +32,20 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||
obj_config_to_str, read_json,
|
||||
set_printing, write_json,
|
||||
write_yaml)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from xgwt.xgwt import XGWT
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
@ -81,7 +82,7 @@ all__pyg = [
|
||||
# "GNNExplainer",
|
||||
]
|
||||
|
||||
all__own = ["EIXGNN", "SCGNN"]
|
||||
all__own = ["EIXGNN", "SCGNN", "XGWT"]
|
||||
|
||||
all_fidelity = [
|
||||
"fidelity_plus",
|
||||
@ -218,6 +219,8 @@ class ExplainingOutline(object):
|
||||
self.explainer_cfg = copy.copy(eixgnn_cfg)
|
||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||
self.explainer_cfg = copy.copy(scgnn_cfg)
|
||||
elif self.explaining_cfg.explainer.name == "XGWT":
|
||||
self.explainer_cfg = copy.copy(xgwt_cfg)
|
||||
else:
|
||||
self.explainer_cfg = CN()
|
||||
else:
|
||||
@ -227,6 +230,9 @@ class ExplainingOutline(object):
|
||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||
self.explainer_cfg = scgnn_cfg
|
||||
elif self.explaining_cfg.explainer.name == "XGWT":
|
||||
xgwt_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||
self.explainer_cfg = xgwt_cfg
|
||||
|
||||
def load_model(self):
|
||||
if self.cfg is None:
|
||||
@ -290,6 +296,22 @@ class ExplainingOutline(object):
|
||||
signal_similarity=self.explainer_cfg.signal_similarity,
|
||||
shap_val_approx=self.explainer_cfg.shapley_value_approx,
|
||||
)
|
||||
elif name == "XGWT":
|
||||
explaining_algorithm = XGWT(
|
||||
wav_approx=self.explainer_cfg.wav_approx,
|
||||
wav_passband=self.explainer_cfg.wav_passband,
|
||||
wav_normalization=self.explainer_cfg.wav_normalization,
|
||||
num_candidates=self.explainer_cfg.num_candidates,
|
||||
num_samples=self.explainer_cfg.num_samples,
|
||||
c_procedure=self.explainer_cfg.c_procedure,
|
||||
pred_thres_strat=self.explainer_cfg.pred_thres_strat,
|
||||
CI_threshold=self.explainer_cfg.CI_threshold,
|
||||
mixing=self.explainer_cfg.mixing,
|
||||
pred_thres=self.explainer_cfg.pred_thres,
|
||||
incl_prob=self.explainer_cfg.incl_prob,
|
||||
top_k=self.explainer_cfg.top_k,
|
||||
scales=self.explainer_cfg.scales,
|
||||
)
|
||||
elif name == "SCGNN":
|
||||
explaining_algorithm = SCGNN(
|
||||
depth=self.explainer_cfg.depth,
|
||||
|
Loading…
Reference in New Issue
Block a user