Update
This commit is contained in:
parent
702b724158
commit
3f4839fa59
|
@ -6,23 +6,10 @@ import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric import seed_everything
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.data.makedirs import makedirs
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
|
||||||
from torch_geometric.explain.explanation import Explanation
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
from yacs.config import CfgNode as CN
|
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
from explaining_framework.config.explainer_config.xgwt_config import xgwt_cfg
|
||||||
from explaining_framework.config.explaining_config import explaining_cfg
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||||
from explaining_framework.explainers.wrappers.from_graphxai import \
|
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||||
|
@ -45,6 +32,20 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
obj_config_to_str, read_json,
|
obj_config_to_str, read_json,
|
||||||
set_printing, write_json,
|
set_printing, write_json,
|
||||||
write_yaml)
|
write_yaml)
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset, create_dataset2
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
from xgwt.xgwt import XGWT
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -81,7 +82,7 @@ all__pyg = [
|
||||||
# "GNNExplainer",
|
# "GNNExplainer",
|
||||||
]
|
]
|
||||||
|
|
||||||
all__own = ["EIXGNN", "SCGNN"]
|
all__own = ["EIXGNN", "SCGNN", "XGWT"]
|
||||||
|
|
||||||
all_fidelity = [
|
all_fidelity = [
|
||||||
"fidelity_plus",
|
"fidelity_plus",
|
||||||
|
@ -218,6 +219,8 @@ class ExplainingOutline(object):
|
||||||
self.explainer_cfg = copy.copy(eixgnn_cfg)
|
self.explainer_cfg = copy.copy(eixgnn_cfg)
|
||||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||||
self.explainer_cfg = copy.copy(scgnn_cfg)
|
self.explainer_cfg = copy.copy(scgnn_cfg)
|
||||||
|
elif self.explaining_cfg.explainer.name == "XGWT":
|
||||||
|
self.explainer_cfg = copy.copy(xgwt_cfg)
|
||||||
else:
|
else:
|
||||||
self.explainer_cfg = CN()
|
self.explainer_cfg = CN()
|
||||||
else:
|
else:
|
||||||
|
@ -227,6 +230,9 @@ class ExplainingOutline(object):
|
||||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||||
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
self.explainer_cfg = scgnn_cfg
|
self.explainer_cfg = scgnn_cfg
|
||||||
|
elif self.explaining_cfg.explainer.name == "XGWT":
|
||||||
|
xgwt_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
|
self.explainer_cfg = xgwt_cfg
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
|
@ -290,6 +296,22 @@ class ExplainingOutline(object):
|
||||||
signal_similarity=self.explainer_cfg.signal_similarity,
|
signal_similarity=self.explainer_cfg.signal_similarity,
|
||||||
shap_val_approx=self.explainer_cfg.shapley_value_approx,
|
shap_val_approx=self.explainer_cfg.shapley_value_approx,
|
||||||
)
|
)
|
||||||
|
elif name == "XGWT":
|
||||||
|
explaining_algorithm = XGWT(
|
||||||
|
wav_approx=self.explainer_cfg.wav_approx,
|
||||||
|
wav_passband=self.explainer_cfg.wav_passband,
|
||||||
|
wav_normalization=self.explainer_cfg.wav_normalization,
|
||||||
|
num_candidates=self.explainer_cfg.num_candidates,
|
||||||
|
num_samples=self.explainer_cfg.num_samples,
|
||||||
|
c_procedure=self.explainer_cfg.c_procedure,
|
||||||
|
pred_thres_strat=self.explainer_cfg.pred_thres_strat,
|
||||||
|
CI_threshold=self.explainer_cfg.CI_threshold,
|
||||||
|
mixing=self.explainer_cfg.mixing,
|
||||||
|
pred_thres=self.explainer_cfg.pred_thres,
|
||||||
|
incl_prob=self.explainer_cfg.incl_prob,
|
||||||
|
top_k=self.explainer_cfg.top_k,
|
||||||
|
scales=self.explainer_cfg.scales,
|
||||||
|
)
|
||||||
elif name == "SCGNN":
|
elif name == "SCGNN":
|
||||||
explaining_algorithm = SCGNN(
|
explaining_algorithm = SCGNN(
|
||||||
depth=self.explainer_cfg.depth,
|
depth=self.explainer_cfg.depth,
|
||||||
|
|
Loading…
Reference in New Issue