diff --git a/explaining_framework/explainers/wrappers/from_captum.py b/explaining_framework/explainers/wrappers/from_captum.py index 4f847d8..cd29fda 100644 --- a/explaining_framework/explainers/wrappers/from_captum.py +++ b/explaining_framework/explainers/wrappers/from_captum.py @@ -187,22 +187,24 @@ class CaptumWrapper(ExplainerAlgorithm): raise ValueError(f"{self.name} is not a supported Captum method yet !") def _parse_attr(self, attr): + for i in range(len(attr)): + attr[i] = attr[i].squeeze() if self.mask_type == "node": - node_feat_mask = attr[0].squeeze(0) + node_mask = attr[0] edge_mask = None if self.mask_type == "edge": - node_feat_mask = None + node_mask = None edge_mask = attr[0] if self.mask_type == "node_and_edge": - node_feat_mask = attr[0].squeeze(0) + node_mask = attr[0] edge_mask = attr[1] else: raise ValueError edge_feat_mask = None - node_mask = None + node_feat_mask = None return node_mask, edge_mask, node_feat_mask, edge_feat_mask diff --git a/explaining_framework/metric/robust.py b/explaining_framework/metric/robust.py index 9901032..72b1e0d 100644 --- a/explaining_framework/metric/robust.py +++ b/explaining_framework/metric/robust.py @@ -2,9 +2,12 @@ import copy import torch import torch.nn.functional as F -from explaining_framework.metric.base import Metric from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node +from troch.nn import CrossEntropyLoss, MSELoss + +from explaining_framework.metric.base import Metric def compute_gradient(model, inp, target, loss): @@ -142,7 +145,15 @@ class Attack(Metric): "fgsm", ] self.dropout = dropout - self.loss = loss + if loss is None: + if cfg.model.loss_fun == "cross-entropy": + self.loss = CrossEntropyLoss() + if cfg.model.loss_fun == "mse": + self.loss = MSELoss() + else: + raise ValueError + else: + self.loss = loss self.load_metric(name) def _gaussian_noise(self, exp) -> Explanation: @@ -194,7 +205,6 @@ class Attack(Metric): if name == "remove_node": self.metric = self._load_remove_node() if name == "pgd": - print("set LOSS with cfg ") pgd = PGD(model=self.model, loss=self.loss) self.metric = lambda exp: pgd.forward( input=exp, @@ -206,7 +216,6 @@ class Attack(Metric): norm="inf", ) if name == "fgsm": - print("set LOSS with cfg ") fgsm = FGSM(model=self.model, loss=self.loss) self.metric = lambda exp: fgsm.forward( input=exp, target=exp.y, epsilon=1 diff --git a/explaining_framework/metric/sparsity.py b/explaining_framework/metric/sparsity.py index f54fb2a..ea0ee77 100644 --- a/explaining_framework/metric/sparsity.py +++ b/explaining_framework/metric/sparsity.py @@ -4,7 +4,21 @@ from explaining_framework.metric.base import Metric class Sparsity(Metric): def __init__(self, name): - super().__init__(name=name, model=None) + super().__init__(name=name) + self.authorized_metric = ['l0'] + self.metric = self.load_metric(name) + + def load_metric(self,name): + if name in self.authorized_metric: + if name == 'l0': + metric = lambda x : torch.mean(mask.float()).item() + else: + raise ValueError(f'{name} is not supported yet') + + + def forward(self, exp:Explanation) -> float: + out = {} + for k,v in exp.to_dict(): + if 'mask' in - def forward(self, mask): return torch.mean(mask.float()).item() diff --git a/explaining_framework/utils/explaining/explaining_outline.py b/explaining_framework/utils/explaining/outline.py similarity index 82% rename from explaining_framework/utils/explaining/explaining_outline.py rename to explaining_framework/utils/explaining/outline.py index b8459c4..af70e19 100644 --- a/explaining_framework/utils/explaining/explaining_outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -1,6 +1,15 @@ import copy +from typing import Any from eixgnn.eixgnn import EiXGNN +from scgnn.scgnn import SCGNN +from torch_geometric.data import Batch, Data +from torch_geometric.explain import Explainer +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device + from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -14,13 +23,6 @@ from explaining_framework.metric.robust import Attack from explaining_framework.metric.sparsity import Sparsity from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo, _load_ckpt) -from scgnn.scgnn import SCGNN -from torch_geometric.data import Batch, Data -from torch_geometric.explain import Explainer -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device all__captum = [ "LRP", @@ -54,6 +56,30 @@ all__graphxai = [ all__own = ["EIXGNN", "SCGNN"] +all_fidelity = [ + "fidelity_plus", + "fidelity_minus", + "fidelity_plus_prob", + "fidelity_minus_prob", + "infidelity_KL", +] +all_accuracy = [ + "precision_score", + "jaccard_score", + "roc_auc_score", + "f1_score", + "accuracy_score", +] + +all_robust = [ + "gaussian_noise", + "add_edge", + "remove_edge", + "remove_node", + "pgd", + "fgsm", +] + class ExplainingOutline(object): def __init__(self, explaining_cfg_path: str): @@ -95,7 +121,7 @@ class ExplainingOutline(object): def load_explainer_cfg(self): if self.explaining_cfg is None: - self.explaining_cfg() + self.load_explaining_cfg() else: if self.explaining_cfg.explainer.cfg == "default": if self.explaining_cfg.explainer.name == "EIXGNN": @@ -127,7 +153,7 @@ class ExplainingOutline(object): if self.cfg is None: self.load_cfg() if self.explaining_cfg is None: - self.explaining_cfg() + self.load_explaining_cfg() if self.explaining_cfg.dataset.name != self.cfg.dataset.name: raise ValueError( f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." @@ -167,7 +193,33 @@ class ExplainingOutline(object): score_map_norm=self.explainer_cfg.score_map_norm, ) self.explaining_algorithm = explaining_algorithm - print(self.explaining_algorithm.__dict__) + + def load_metric(self): + if self.cfg is None: + self.load_cfg() + if self.explaining_cfg is None: + self.load_explaining_cfg() + + if self.explaining_cfg.metrics.type == "all": + if self.explaining_cfg.dataset.name == 'BASHAPES': + all_acc_metrics = [Accuracy(name) for name in all_accuracy] + all_fid_metrics = [Fidelity(name) for name in all_fidelity] + all_spa_metrics = [Sparsity(name) for name in all_sparsity] + + def load_attack(self): + if self.cfg is None: + self.load_cfg() + if self.explaining_cfg is None: + self.load_explaining_cfg() + all_rob_metrics = [Attack(name) for name in all_robust] + + +class FileManager(object): + def __init__(self): + pass + + def save(obj: Any, path: str) -> None: + pass PATH = "config_exp.yaml" diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py index 0e641ee..cc76ec7 100644 --- a/explaining_framework/utils/explanation/adjust.py +++ b/explaining_framework/utils/explanation/adjust.py @@ -3,14 +3,63 @@ import copy from torch import FloatTensor from torch.nn import ReLU +class Adjust(object): + def __init__( + self, + apply_relu: bool = True, + apply_normalize: bool = True, + apply_project: bool = True, + apply_absolute: bool = False, + ): + self.apply_relu = apply_relu + self.apply_normalize = apply_normalize + self.apply_project = apply_project + self.apply_absolute = apply_absolute -def relu_mask(explanation: Explanation) -> Explanation: - relu = ReLU() - explanation_store = explanation._store - raw_data = copy.copy(explanation._store) - for k, v in explanation_store.items(): - if "mask" in k: - explanation_store[k] = relu(v) - explanation.__setattr__("raw_explanation", raw_data) - explanation.__setattr__("raw_explanation_transform", "relu") - return explanation + if self.apply_absolute and self.apply_relu: + self.apply_relu = False + + def forward(self, exp: Explanation) -> Explanation: + exp_ = exp.copy() + _store = exp_.to_dict() + for k, v in _store.items(): + if "mask" in k: + if self.apply_relu: + _store[k] = self.relu(v) + elif self.apply_absolute: + _store[k] = self.absolute(v) + elif self.apply_project: + if "edge" in k: + pass + else: + _store[k] = self.project(v) + elif self.apply_normalize: + _store[k] = self.normalize(v) + else: + continue + + return exp_ + + def relu(self, mask: FloatTensor) -> FloatTensor: + relu = ReLU() + mask_ = relu(mask) + return mask_ + + def normalize(self, mask: FloatTensor) -> FloatTensor: + norm = torch.norm(mask, p="inf") + if norm.item() > 0: + mask_ = mask / norm.item() + return mask_ + else: + return mask + + def project(self, mask: FloatTensor) -> FloatTensor: + if mask.ndim >= 2: + mask_ = torch.sum(mask, dim=1) + return mask_ + else: + return mask + + def absolute(self, mask: FloatTensor) -> FloatTensor: + mask_ = torch.abs(mask) + return mask_ diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index 9af1a57..e468a88 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -12,8 +12,9 @@ def explanation_verification(exp: Explanation) -> bool: for mask in masks: is_nan = mask.isnan().any().item() is_inf = mask.isinf().any().item() + is_const = mask.max()==mask.min() is_ok = exp.validate() - if is_nan or is_inf or not is_ok: + if is_nan or is_inf or not is_ok or is_const: is_good = False return is_good else: @@ -47,5 +48,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation data = exp.to_dict() for k, v in data.items(): if "_mask" in k and isinstance(v, torch.FloatTensor): - data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item() + norm =torch.norm(input=data[k], p=p, dim=None).item() + if norm.item()>0: + data[k] = data[k] / norm + return exp diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py index 3ce2f2d..b74c2e5 100644 --- a/explaining_framework/utils/io.py +++ b/explaining_framework/utils/io.py @@ -1,5 +1,5 @@ import json - +import os import yaml @@ -23,3 +23,7 @@ def read_yaml(path: str) -> dict: def write_yaml(data: dict, path: str) -> None: with open(path, "w") as f: data = yaml.dump(data, f) + +def is_exists(path:str)-> bool: + return os.path.exists(path) +