From 9ad5adb33e6e059186baec8f497f35b73bd3765e Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 2 Jan 2023 23:37:40 +0100 Subject: [PATCH] Fixing bugs and adding new features --- .../config/explaining_config.py | 5 +- .../explainers/wrappers/test.py | 68 +++++------ explaining_framework/metric/base.py | 9 +- explaining_framework/metric/fidelity.py | 33 +++++- explaining_framework/metric/robust.py | 17 ++- explaining_framework/metric/sparsity.py | 2 +- .../stats/graph/graph_stat.py | 5 +- .../utils/explaining/load_ckpt.py | 2 +- .../utils/explaining/outline.py | 30 +++-- .../utils/explanation/adjust.py | 7 +- explaining_framework/utils/explanation/io.py | 9 +- explaining_framework/utils/io.py | 5 +- main.py | 111 +++++++++++------- 13 files changed, 183 insertions(+), 120 deletions(-) diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index 140362c..6fdd3d8 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -114,7 +114,7 @@ def set_cfg(explaining_cfg): explaining_cfg.threshold_config.threshold_type = None - explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7] + explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)] explaining_cfg.threshold_config.relu_and_normalize = True @@ -128,8 +128,7 @@ def set_cfg(explaining_cfg): explaining_cfg.metrics.force = False explaining_cfg.attack = CN() - explaining_cfg.attack.name = 'all' - + explaining_cfg.attack.name = "all" explaining_cfg.accelerator = "auto" diff --git a/explaining_framework/explainers/wrappers/test.py b/explaining_framework/explainers/wrappers/test.py index 4ed0970..3b0d3aa 100644 --- a/explaining_framework/explainers/wrappers/test.py +++ b/explaining_framework/explainers/wrappers/test.py @@ -31,16 +31,16 @@ __all__captum = [ __all__graphxai = [ "CAM", - # "GradCAM", - # "GNN_LRP", - # "GradExplainer", - # "GuidedBackPropagation", - # "IntegratedGradients", - # "PGExplainer", - # "PGMExplainer", - # "RandomExplainer", - # "SubgraphX", - # "GraphMASK", + "GradCAM", + "GNN_LRP", + "GradExplainer", + "GuidedBackPropagation", + "IntegratedGradients", + "PGExplainer", + "PGMExplainer", + "RandomExplainer", + "SubgraphX", + "GraphMASK", ] @@ -82,12 +82,12 @@ for epoch in range(1, 2): target = torch.LongTensor([[0]]) for kind in ["graph"]: - for name in __all__graphxai: + for name in __all__graphxai + __all__captum: if name in __all__captum: explaining_algorithm = CaptumWrapper(name) elif name in __all__graphxai: explaining_algorithm = GraphXAIWrapper( - name, in_channels=in_channels, criterion="cross-entropy" + name, in_channels=in_channels, criterion="cross_entropy" ) print(name) @@ -105,7 +105,7 @@ for kind in ["graph"]: task_level=kind, return_type="raw", ), - threshold_config=dict(threshold_type="hard", value=0.5), + # threshold_config=dict(threshold_type=None, value=0.5), ) explanation = explainer( x=batch.x, @@ -117,29 +117,29 @@ for kind in ["graph"]: # explanation.__setattr__( # "model_prediction", explainer.get_prediction(x, edge_index) # ) - explanation_threshold = explanation._apply_masks( - node_mask=torch.ones_like(explanation.node_mask).bool() - ) + # explanation_threshold = explanation._apply_masks( + # node_mask=torch.ones_like(explanation.node_mask).bool() + # ) - print(explanation_threshold.__dict__) + # print(explanation_threshold.__dict__) - for f_name in [ - "gaussian_noise", - "add_edge", - "remove_edge", - "remove_node", - "pgd", - "fgsm", - ]: - print(f_name) - acc = Attack(name=f_name, model=model, loss=loss) - # gt = torch.ones_like(explanation_threshold.node_mask) / 2 - # mask = explanation_threshold.node_mask.bool() - # target = (1 - gt).bool() - # target[1] = False - # print(mask, target) - out = acc.forward(explanation) - print(out) + # for f_name in [ + # "gaussian_noise", + # "add_edge", + # "remove_edge", + # "remove_node", + # "pgd", + # "fgsm", + # ]: + # print(f_name) + # acc = Attack(name=f_name, model=model, loss=loss) + # # gt = torch.ones_like(explanation_threshold.node_mask) / 2 + # # mask = explanation_threshold.node_mask.bool() + # # target = (1 - gt).bool() + # # target[1] = False + # # print(mask, target) + # out = acc.forward(explanation) + # print(out) except Exception as e: traceback.print_exc() diff --git a/explaining_framework/metric/base.py b/explaining_framework/metric/base.py index 5569ec6..fc18ed5 100644 --- a/explaining_framework/metric/base.py +++ b/explaining_framework/metric/base.py @@ -39,14 +39,9 @@ class Metric(ABC): **kwargs (optional): Additional keyword arguments passed to the model. """ - training = self.model.training - self.model.eval() + print(args, kwargs) with torch.no_grad(): - out = self.model(*args, **kwargs) - - self.model.train(training) + out = self.model(*args, **kwargs)[0] return out - - diff --git a/explaining_framework/metric/fidelity.py b/explaining_framework/metric/fidelity.py index 96cf944..48e8392 100644 --- a/explaining_framework/metric/fidelity.py +++ b/explaining_framework/metric/fidelity.py @@ -1,9 +1,9 @@ import torch import torch.nn.functional as F +from torch import Tensor from torch.nn import KLDivLoss, Softmax from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg -from torch import Tensor from explaining_framework.metric.base import Metric @@ -118,9 +118,19 @@ class Fidelity(Metric): ) pos_fidelity = self._fidelity_plus_prob(exp) neg_fidelity = self._fidelity_minus_prob(exp) - - denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) - return 1.0 / denom + if ( + pos_fidelity == 0 + or pos_fidelity == 1 + or neg_fidelity == 0 + or neg_fidelity == 1 + ): + return None + else: + denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) + if denom == 0: + return None + else: + return 1.0 / denom def _characterization( self, @@ -136,8 +146,19 @@ class Fidelity(Metric): pos_fidelity = self._fidelity_plus(exp) neg_fidelity = self._fidelity_minus(exp) - denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) - return 1.0 / denom + if ( + pos_fidelity == 0 + or pos_fidelity == 1 + or neg_fidelity == 0 + or neg_fidelity == 1 + ): + return None + else: + denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) + if denom == 0: + return None + else: + return 1.0 / denom def score(self, exp): self.exp_sub = exp.get_explanation_subgraph() diff --git a/explaining_framework/metric/robust.py b/explaining_framework/metric/robust.py index b4348d9..5c7aaee 100644 --- a/explaining_framework/metric/robust.py +++ b/explaining_framework/metric/robust.py @@ -26,7 +26,7 @@ class FGSM(Metric): lower_bound: float = float("-inf"), upper_bound: float = float("inf"), ): - super().__init__(name=name, model=model) + super().__init__(name="fgsm", model=model) self.model = model self.loss = loss self.lower_bound = lower_bound @@ -51,6 +51,9 @@ class FGSM(Metric): ) return input_ + def load_metric(self): + pass + class PGD(Metric): def __init__( @@ -60,7 +63,7 @@ class PGD(Metric): lower_bound: float = float("-inf"), upper_bound: float = float("inf"), ): - super().__init__(name=name, model=model) + super().__init__(name="pgd", model=model) self.model = model self.loss = loss self.lower_bound = lower_bound @@ -105,6 +108,9 @@ class PGD(Metric): perturbed_input.x = self.bound(perturbed_input.x).detach() return perturbed_input + def load_metric(self): + pass + def _random_point( self, center: torch.Tensor, radius: float, norm: str ) -> torch.Tensor: @@ -135,6 +141,7 @@ class Attack(Metric): dropout: float = 0.5, loss: torch.nn = None, ): + super().__init__(name=name, model=model) self.name = name self.model = model @@ -148,12 +155,12 @@ class Attack(Metric): ] self.dropout = dropout if loss is None: - if cfg.model.loss_fun == "cross-entropy": + if cfg.model.loss_fun == "cross_entropy": self.loss = CrossEntropyLoss() - if cfg.model.loss_fun == "mse": + elif cfg.model.loss_fun == "mse": self.loss = MSELoss() else: - raise ValueError + raise ValueError(f"{loss} is not supported yet") else: self.loss = loss self.load_metric(name) diff --git a/explaining_framework/metric/sparsity.py b/explaining_framework/metric/sparsity.py index 067f0c6..069564a 100644 --- a/explaining_framework/metric/sparsity.py +++ b/explaining_framework/metric/sparsity.py @@ -19,7 +19,7 @@ class Sparsity(Metric): def forward(self, exp: Explanation) -> float: out = {} - for k, v in exp.to_dict(): + for k, v in exp.to_dict().items(): if "mask" in k and v.dtype == torch.bool: out[k] = torch.mean(mask.float()).item() return out diff --git a/explaining_framework/stats/graph/graph_stat.py b/explaining_framework/stats/graph/graph_stat.py index 804c5eb..41cc3ca 100644 --- a/explaining_framework/stats/graph/graph_stat.py +++ b/explaining_framework/stats/graph/graph_stat.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import copy import types from inspect import getmembers, isfunction, signature @@ -152,7 +153,7 @@ class GraphStat(object): return maps def __call__(self, data): - data_ = data.__copy__() + data_ = copy.copy(data) datahash = hash(data.__repr__) stats = {} for k, v in self.maps.items(): @@ -160,7 +161,7 @@ class GraphStat(object): _data_ = to_networkx(data) _data_ = _data_.to_undirected() elif k == "torch_geometric": - _data_ = data.__copy__() + _data_ = copy.copy(data) for name, func in v.items(): try: val = func(_data_) diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index 288246d..d2ea05e 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -122,7 +122,7 @@ class LoadModelInfo(object): if self.info is None: self.set_info() - model_name = os.path.basename(self.info["xp_dir_name"]) + model_name = os.path.basename(self.info["xp_dir_path"]) model_seed = self.info["seed"] epoch = os.path.basename(self.info["ckpt_path"]) model_signature = "-".join( diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index d80d0a5..8fb333b 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -4,12 +4,12 @@ from typing import Any from eixgnn.eixgnn import EiXGNN from scgnn.scgnn import SCGNN from torch_geometric.data import Batch, Data -from torch_geometric.loader.dataloader import DataLoader from torch_geometric.explain import Explainer from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.loader import create_dataset from torch_geometric.graphgym.model_builder import cfg, create_model from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg @@ -53,6 +53,7 @@ all__graphxai = [ "RandomExplainer", "SubgraphX", "GraphMASK", + "GNNExplainer", ] all__own = ["EIXGNN", "SCGNN"] @@ -82,6 +83,7 @@ all_robust = [ "pgd", "fgsm", ] +all_sparsity = ["l0"] class ExplainingOutline(object): @@ -108,6 +110,7 @@ class ExplainingOutline(object): self.load_explainer() self.load_metric() self.load_attack() + self.load_dataset_to_dataloader() def load_model_info(self): info = LoadModelInfo( @@ -171,7 +174,12 @@ class ExplainingOutline(object): if isinstance(self.explaining_cfg.dataset.specific_items, int): ind = self.explaining_cfg.dataset.specific_items self.dataset = self.dataset[ind : ind + 1] - self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1) + elif isinstance(self.explaining_cfg.dataset.specific_items, list): + ind = self.explaining_cfg.dataset.specific_items + self.dataset = self.dataset[ind] + + def load_dataset_to_dataloader(self): + self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) def load_explainer(self): self.load_explainer_cfg() @@ -217,18 +225,20 @@ class ExplainingOutline(object): if self.explaining_cfg is None: self.load_explaining_cfg() - name_ = self.explaining_cfg.metrics.type + name_ = self.explaining_cfg.metrics.name if name_ == "all": - all_fid_metrics = [Fidelity(name) for name in all_fidelity] + all_fid_metrics = [ + Fidelity(name=name, model=self.model) for name in all_fidelity + ] all_spa_metrics = [Sparsity(name) for name in all_sparsity] - self.metrics = all_acc_metrics + all_fid_metrics + self.metrics = all_spa_metrics + all_fid_metrics if self.explaining_cfg.dataset.name == "BASHAPES": all_acc_metrics = [Accuracy(name) for name in all_accuracy] self.metrics = self.metrics + all_acc_metrics elif name_ in all_fidelity: - self.metrics = [Fidelity(name_)] + self.metrics = [Fidelity(name=name_, model=self.model)] elif name_ in all_sparsity: self.metrics = [Sparsity(name_)] elif name_ in all_accuracy: @@ -250,11 +260,13 @@ class ExplainingOutline(object): self.load_explaining_cfg() name_ = self.explaining_cfg.attack.name if name_ == "all": - all_rob_metrics = [Attack(name) for name in all_robust] + all_rob_metrics = [ + Attack(name=name, model=self.model) for name in all_robust + ] self.attacks = all_rob_metrics elif name_ in all_robust: - self.attacks = [Attack(name_)] + self.attacks = [Attack(name=name_, model=self.model)] elif name_ is None: - slef.attacks = [] + self.attacks = [] else: raise ValueError(f"{name_} is an Attack method that is not supported yet") diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py index cc76ec7..1aa639f 100644 --- a/explaining_framework/utils/explanation/adjust.py +++ b/explaining_framework/utils/explanation/adjust.py @@ -1,7 +1,10 @@ import copy +import torch from torch import FloatTensor from torch.nn import ReLU +from torch_geometric.explain.explanation import Explanation + class Adjust(object): def __init__( @@ -20,7 +23,7 @@ class Adjust(object): self.apply_relu = False def forward(self, exp: Explanation) -> Explanation: - exp_ = exp.copy() + exp_ = copy.copy(exp) _store = exp_.to_dict() for k, v in _store.items(): if "mask" in k: @@ -61,5 +64,7 @@ class Adjust(object): return mask def absolute(self, mask: FloatTensor) -> FloatTensor: + print("######################### MASK") + print(mask) mask_ = torch.abs(mask) return mask_ diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index e468a88..6a31567 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -2,6 +2,7 @@ import copy import json import os +import torch from torch_geometric.data import Data from torch_geometric.explain.explanation import Explanation @@ -12,7 +13,7 @@ def explanation_verification(exp: Explanation) -> bool: for mask in masks: is_nan = mask.isnan().any().item() is_inf = mask.isinf().any().item() - is_const = mask.max()==mask.min() + is_const = mask.max() == mask.min() is_ok = exp.validate() if is_nan or is_inf or not is_ok or is_const: is_good = False @@ -25,8 +26,10 @@ def explanation_verification(exp: Explanation) -> bool: def save_explanation(exp: Explanation, path: str) -> None: data = copy.copy(exp).to_dict() for k, v in data.items(): + print(k, v) if isinstance(v, torch.Tensor): data[k] = v.detach().cpu().tolist() + with open(path, "w") as f: json.dump(data, f) @@ -48,8 +51,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation data = exp.to_dict() for k, v in data.items(): if "_mask" in k and isinstance(v, torch.FloatTensor): - norm =torch.norm(input=data[k], p=p, dim=None).item() - if norm.item()>0: + norm = torch.norm(input=data[k], p=p, dim=None).item() + if norm.item() > 0: data[k] = data[k] / norm return exp diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py index 83f0960..9624562 100644 --- a/explaining_framework/utils/io.py +++ b/explaining_framework/utils/io.py @@ -31,11 +31,8 @@ def is_exists(path: str) -> bool: def get_obj_config(obj): - config = {k: getattr(obj, k) for k in dir(obj)} config = { - k: v - for k, v in config.items() - if isinstance(v, (int, float, str, bool)) or v is None + k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool)) } return config diff --git a/main.py b/main.py index e52bb1f..99be02b 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,11 @@ # -*- coding: utf-8 -*- # +import copy import os import time +import torch from torch_geometric import seed_everything from torch_geometric.data.makedirs import makedirs from torch_geometric.explain import Explainer @@ -16,27 +18,39 @@ from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explanation.adjust import Adjust -from explaining_framework.utils.io import (obj_config_to_str, read_json, - write_json, write_yaml) +from explaining_framework.utils.explanation.io import ( + explanation_verification, load_explanation, save_explanation) +from explaining_framework.utils.io import (is_exists, obj_config_to_str, + read_json, write_json, write_yaml) # inference, time, force, -def get_pred(explanation, force=False): - dict_ = explanation.to_dict() - if dict_.get("pred") is None or dict_.get("pred_masked") or force: - pred = explainer.get_prediction(explanation) +def get_pred(explainer, explanation): + pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[ + 0 + ] + setattr(explanation, "pred", pred) + data = explanation.to_dict() + if not data.get("node_mask") is None or not data.get("edge_mask") is None: pred_masked = explainer.get_masked_prediction( x=explanation.x, edge_index=explanation.edge_index, - node_mask=explanation.node_mask, - edge_mask=explanation.edge_mask, - ) - explanation.__setattr__("pred", pred) - explanation.__setattr__("pred_masked", pred_masked) - return explanation - else: - return explanation + node_mask=data.get("node_mask"), + edge_mask=data.get("edge_mask"), + )[0] + setattr(explanation, "pred_exp", pred_masked) + + +def get_explanation(explainer, item): + explanation = explainer( + x=item.x, + edge_index=item.edge_index, + index=int(item.y), + target=item.y, + ) + assert explanation_verification(explanation) + return explanation if __name__ == "__main__": @@ -45,8 +59,9 @@ if __name__ == "__main__": auto_select_device() # Load components - dataset = outline.dataset.to(cfg.accelerator) + dataset = outline.dataset model = outline.model.to(cfg.accelerator) + model = model.eval() model_info = outline.model_info metrics = outline.metrics explaining_algorithm = outline.explaining_algorithm @@ -87,53 +102,57 @@ if __name__ == "__main__": return_type=explaining_cfg.model_config.return_type, ), ) + if not explaining_cfg.dataset.specific_items is None: + indexes = explaining_cfg.dataset.specific_items + else: + indexes = range(len(dataset)) # Save explaining configuration - for index, item in enumerate(dataset): + for index, item in zip(indexes, dataset): + item = item.to(cfg.accelerator) save_raw_path = os.path.join(global_path, "raw") makedirs(save_raw_path) explanation_path = os.path.join(save_raw_path, f"{index}.json") if is_exists(explanation_path): if explaining_cfg.explainer.force: - explanation = explainer( - x=item.x, - edge_index=item.edge_index, - index=item.y, - target=item.y, - ) + explanation = get_explanation(explainer, item) else: explanation = load_explanation(explanation_path) else: - explanation = explainer( - x=item.x, - edge_index=item.edge_index, - index=item.y, - target=item.y, - ) - explanation = get_pred(explanation, force=False) + explanation = get_explanation(explainer, item) + + explanation = explanation.to(cfg.accelerator) + get_pred(explainer=explainer, explanation=explanation) + save_explanation(explanation, explanation_path) for apply_relu in [True, False]: for apply_absolute in [True, False]: adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) - save_raw_path = os.path.join( + save_raw_path_ = os.path.join( global_path, f"adjust-{obj_config_to_str(adjust)}" ) - makedirs(save_raw_path) - explanation = adjust.forward(explanation) - explanation_path = os.path.join(save_raw_path, f"{index}.json") - explanation = get_pred(explanation, force=True) - save_explanation(explanation, explanation_path) + explanation__ = copy.copy(explanation).to(cfg.accelerator) + makedirs(save_raw_path_) + explanation = adjust.forward(explanation__) + explanation_path = os.path.join(save_raw_path_, f"{index}.json") + get_pred(explainer, explanation__) + save_explanation(explanation__, explanation_path) for threshold_approach in ["hard", "topk", "topk_hard"]: - for threshold_value in explaining_cfg.threshold_config.value: + if threshold_approach == "hard": + threshold_values = explaining_cfg.threshold_config.value + elif "topk" in threshold_approach: + threshold_values = [3, 5, 10, 20] + for threshold_value in threshold_values: masking_path = os.path.join( - save_raw_path, - f"threshold={threshold_approach}-value={value}", + save_raw_path_, + f"threshold={threshold_approach}-value={threshold_value}", ) + makedirs(masking_path) exp_threshold_path = os.path.join(masking_path, f"{index}.json") if is_exists(exp_threshold_path): - explanation = load_explanation(exp_threshold_path) + exp_threshold = load_explanation(exp_threshold_path) else: threshold_conf = { "threshold_type": threshold_approach, @@ -143,17 +162,21 @@ if __name__ == "__main__": threshold_conf ) - expl = copy.copy(explanation) + expl = copy.copy(explanation__).to(cfg.accelerator) exp_threshold = explainer._post_process(expl) - exp_threshold = get_pred(exp_threshold, force=True) - - save_explanation(exp_threshold, exp_threshold_path) + exp_threshold = exp_threshold.to(cfg.accelerator) + get_pred(explainer, exp_threshold) + save_explanation(exp_threshold, exp_threshold_path) for metric in metrics: metric_path = os.path.join( masking_path, f"{obj_config_to_str(metric)}" ) + makedirs(metric_path) if is_exists(os.path.join(metric_path, f"{index}.json")): continue else: out = metric.forward(exp_threshold) - write_json({f"{metric.name}": out}) + write_json( + {f"{metric.name}": out}, + os.path.join(metric_path, f"{index}.json"), + )