diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 590f4c4..70bd2da 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -6,23 +6,10 @@ import os from typing import Any from eixgnn.eixgnn import EiXGNN -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.data.makedirs import makedirs -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.explain.explanation import Explanation -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset, create_dataset2 -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader -from yacs.config import CfgNode as CN - from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg +from explaining_framework.config.explainer_config.xgwt_config import xgwt_cfg from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper from explaining_framework.explainers.wrappers.from_graphxai import \ @@ -45,6 +32,20 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, obj_config_to_str, read_json, set_printing, write_json, write_yaml) +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.data.makedirs import makedirs +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset, create_dataset2 +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader +from xgwt.xgwt import XGWT +from yacs.config import CfgNode as CN all__captum = [ "LRP", @@ -81,7 +82,7 @@ all__pyg = [ # "GNNExplainer", ] -all__own = ["EIXGNN", "SCGNN"] +all__own = ["EIXGNN", "SCGNN", "XGWT"] all_fidelity = [ "fidelity_plus", @@ -218,6 +219,8 @@ class ExplainingOutline(object): self.explainer_cfg = copy.copy(eixgnn_cfg) elif self.explaining_cfg.explainer.name == "SCGNN": self.explainer_cfg = copy.copy(scgnn_cfg) + elif self.explaining_cfg.explainer.name == "XGWT": + self.explainer_cfg = copy.copy(xgwt_cfg) else: self.explainer_cfg = CN() else: @@ -227,6 +230,9 @@ class ExplainingOutline(object): elif self.explaining_cfg.explainer.name == "SCGNN": scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) self.explainer_cfg = scgnn_cfg + elif self.explaining_cfg.explainer.name == "XGWT": + xgwt_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) + self.explainer_cfg = xgwt_cfg def load_model(self): if self.cfg is None: @@ -290,6 +296,22 @@ class ExplainingOutline(object): signal_similarity=self.explainer_cfg.signal_similarity, shap_val_approx=self.explainer_cfg.shapley_value_approx, ) + elif name == "XGWT": + explaining_algorithm = XGWT( + wav_approx=self.explainer_cfg.wav_approx, + wav_passband=self.explainer_cfg.wav_passband, + wav_normalization=self.explainer_cfg.wav_normalization, + num_candidates=self.explainer_cfg.num_candidates, + num_samples=self.explainer_cfg.num_samples, + c_procedure=self.explainer_cfg.c_procedure, + pred_thres_strat=self.explainer_cfg.pred_thres_strat, + CI_threshold=self.explainer_cfg.CI_threshold, + mixing=self.explainer_cfg.mixing, + pred_thres=self.explainer_cfg.pred_thres, + incl_prob=self.explainer_cfg.incl_prob, + top_k=self.explainer_cfg.top_k, + scales=self.explainer_cfg.scales, + ) elif name == "SCGNN": explaining_algorithm = SCGNN( depth=self.explainer_cfg.depth,