diff --git a/explaining_framework/config/explainer_config/eixgnn_config.py b/explaining_framework/config/explainer_config/eixgnn_config.py index f36d3c0..169eb7c 100644 --- a/explaining_framework/config/explainer_config/eixgnn_config.py +++ b/explaining_framework/config/explainer_config/eixgnn_config.py @@ -37,12 +37,12 @@ def set_eixgnn_cfg(eixgnn_cfg): return eixgnn_cfg eixgnn_cfg.seed = 0 - eixgnn_cfg.L = 50 - eixgnn_cfg.p = 0.5 - eixgnn_cfg.importance_sampling_strategy = "node" + eixgnn_cfg.L = 5 + eixgnn_cfg.p = 0.1 + eixgnn_cfg.importance_sampling_strategy = "neighborhood" eixgnn_cfg.domain_similarity = "relative_edge_density" eixgnn_cfg.signal_similarity = "KL" - eixgnn_cfg.shapley_value_approx = 100 + eixgnn_cfg.shapley_value_approx = 20 def assert_eixgnn_cfg(eixgnn_cfg): diff --git a/explaining_framework/config/explainer_config/scgnn_config.py b/explaining_framework/config/explainer_config/scgnn_config.py index 1766761..c665dd9 100644 --- a/explaining_framework/config/explainer_config/scgnn_config.py +++ b/explaining_framework/config/explainer_config/scgnn_config.py @@ -39,6 +39,7 @@ def set_scgnn_cfg(scgnn_cfg): scgnn_cfg.depth = "all" scgnn_cfg.interest_map_norm = True scgnn_cfg.score_map_norm = True + scgnn_cfg.target_baseline = "inference" def assert_cfg(scgnn_cfg): diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index ce95265..9641c25 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -57,9 +57,7 @@ def set_cfg(explaining_cfg): explaining_cfg.dataset.name = "Cora" - explaining_cfg.dataset.items = None - - explaining_cfg.run_topological_stat = True + explaining_cfg.dataset.item = None # ----------------------------------------------------------------------- # # Model options @@ -116,7 +114,7 @@ def set_cfg(explaining_cfg): explaining_cfg.threshold.config.type = "all" explaining_cfg.threshold.value = CN() - explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)] + explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)] explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50] # which objectives metrics to computes, either all or one in particular if implemented @@ -131,7 +129,7 @@ def set_cfg(explaining_cfg): # Whether or not recomputing metrics if they already exist explaining_cfg.adjust = CN() - explaining_cfg.adjust.strategy = "rpn" + explaining_cfg.adjust.strategy = "rpns" explaining_cfg.attack = CN() explaining_cfg.attack.name = "all" diff --git a/explaining_framework/explainers/wrappers/from_graphxai.py b/explaining_framework/explainers/wrappers/from_graphxai.py index 04ac2fa..ceb8dbd 100644 --- a/explaining_framework/explainers/wrappers/from_graphxai.py +++ b/explaining_framework/explainers/wrappers/from_graphxai.py @@ -37,10 +37,7 @@ def _load_GNN_LRP(model): def _load_GuidedBackPropagation(model, criterion): - # return lambda model: GuidedBP(model, criterion) - raise ValueError( - "GraphXAI GuidedBackPropagation is discarded since already available in Captum for Pytorch Geometric (see CaptumWrapper)" - ) + return GuidedBP(model, criterion) def _load_IntegratedGradients(model, criterion): @@ -106,8 +103,8 @@ class GraphXAIWrapper(ExplainerAlgorithm): "GradCAM", "GNN_LRP", "GradExplainer", - "GuidedBP", - "IntegratedGradExplainer", + "GuidedBackPropagation", + "IntegratedGraddients", "PGExplainer", "PGMExplainer", "RandomExplainer", @@ -234,10 +231,12 @@ class GraphXAIWrapper(ExplainerAlgorithm): index: Optional[Union[int, Tensor]] = None, **kwargs, ): + mask_type = self._get_mask_type() self.graphxai_method = self._load_graphxai_method(model) if self.model_config.task_level == ModelTaskLevel.node: + attr = self.graphxai_method.get_explanation_node( x=x, edge_index=edge_index, @@ -245,18 +244,26 @@ class GraphXAIWrapper(ExplainerAlgorithm): node_idx=index, y=target, ) + elif self.model_config.task_level == ModelTaskLevel.graph: + attr = self.graphxai_method.get_explanation_graph( x=x, edge_index=edge_index, label=target, y=target, ) + elif self.model_config.task_level == ModelTaskLevel.edge: + attr = self.graphxai_method.get_explanation_link(*args, **kwargs) + else: + raise ValueError(f"{self.model_config.task_level} is not supported yet") + node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr) + return Explanation( x=x, edge_index=edge_index, diff --git a/explaining_framework/explainers/wrappers/test.py b/explaining_framework/explainers/wrappers/test.py index 3b0d3aa..888b855 100644 --- a/explaining_framework/explainers/wrappers/test.py +++ b/explaining_framework/explainers/wrappers/test.py @@ -2,16 +2,16 @@ import traceback import torch import torch.nn as nn -from explaining_framework.metric.accuracy import Accuracy -from explaining_framework.metric.fidelity import Fidelity -from explaining_framework.metric.robust import Attack -from explaining_framework.metric.sparsity import Sparsity +from from_captum import CaptumWrapper +from from_graphxai import GraphXAIWrapper from torch_geometric.data import Batch, Data from torch_geometric.explain import Explainer from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool -from from_captum import CaptumWrapper -from from_graphxai import GraphXAIWrapper +from explaining_framework.metric.accuracy import Accuracy +from explaining_framework.metric.fidelity import Fidelity +from explaining_framework.metric.robust import Attack +from explaining_framework.metric.sparsity import Sparsity __all__captum = [ "LRP", diff --git a/explaining_framework/metric/base.py b/explaining_framework/metric/base.py index d55250c..13f41dc 100644 --- a/explaining_framework/metric/base.py +++ b/explaining_framework/metric/base.py @@ -41,6 +41,6 @@ class Metric(ABC): """ with torch.no_grad(): - out = self.model(*args, **kwargs)[0] + out = self.model(*args, **kwargs) return out diff --git a/explaining_framework/metric/fidelity.py b/explaining_framework/metric/fidelity.py index 48e8392..187f589 100644 --- a/explaining_framework/metric/fidelity.py +++ b/explaining_framework/metric/fidelity.py @@ -1,12 +1,11 @@ import torch import torch.nn.functional as F +from explaining_framework.metric.base import Metric from torch import Tensor from torch.nn import KLDivLoss, Softmax from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg -from explaining_framework.metric.base import Metric - NUM_CLASS = cfg.share.dim_out @@ -58,23 +57,30 @@ class Fidelity(Metric): self._score_check() inferred_class_initial = torch.argmax(self.s_initial_data, dim=1) inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1) - return torch.mean( - (exp.y == inferred_class_initial).float() - - (exp.y == inferred_class_exp).float() - ).item() + return ( + ( + (exp.y == inferred_class_initial).float() + - (exp.y == inferred_class_exp).float() + ) + .mean() + .item() + ) def _fidelity_minus(self, exp: Explanation) -> float: self._score_check() inferred_class_initial = torch.argmax(self.s_initial_data, dim=1) inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1) - return torch.mean( - (exp.y == inferred_class_initial).float() - - (exp.y == inferred_class_exp).float() - ).item() + return ( + ( + (exp.y == inferred_class_initial).float() + - (exp.y == inferred_class_exp).float() + ) + .mean() + .item() + ) def _fidelity_plus_prob(self, exp: Explanation) -> float: self._score_check() - # one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS) prob_initial = softmax(self.s_initial_data) prob_exp = softmax(self.s_exp_sub_c) @@ -82,9 +88,7 @@ class Fidelity(Metric): prob_initial = prob_initial[torch.arange(size), exp.y] prob_exp = prob_exp[torch.arange(size), exp.y] - return torch.mean( - torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1) - ).item() + return (prob_initial - prob_exp).mean().item() def _fidelity_minus_prob(self, exp: Explanation) -> float: self._score_check() @@ -95,9 +99,7 @@ class Fidelity(Metric): prob_initial = prob_initial[torch.arange(size), exp.y] prob_exp = prob_exp[torch.arange(size), exp.y] - return torch.mean( - torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1) - ).item() + return (prob_initial - prob_exp).mean().item() def _infidelity_KL(self, exp: Explanation) -> float: self._score_check() @@ -191,6 +193,13 @@ class Fidelity(Metric): raise ValueError(f"{name} is not supported") return self.metric + def reset_score(self): + self.exp_sub = None + self.exp_sub_c = None + self.s_exp_sub = None + self.s_exp_sub_c = None + self.s_initial_data = None + def forward(self, exp: Explanation): self.score(exp) return self.metric(exp) diff --git a/explaining_framework/metric/robust.py b/explaining_framework/metric/robust.py index abb99b4..f0795ab 100644 --- a/explaining_framework/metric/robust.py +++ b/explaining_framework/metric/robust.py @@ -3,11 +3,13 @@ import copy import torch import torch.nn.functional as F from torch.nn import CrossEntropyLoss, MSELoss +from torch_geometric.data import Batch, Data from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node from explaining_framework.metric.base import Metric +from explaining_framework.utils.io import obj_config_to_str def compute_gradient(model, inp, target, loss): @@ -18,121 +20,6 @@ def compute_gradient(model, inp, target, loss): return torch.autograd.grad(err, inp.x)[0] -class FGSM(Metric): - def __init__( - self, - model: torch.nn.Module, - loss: torch.nn.Module, - lower_bound: float = float("-inf"), - upper_bound: float = float("inf"), - ): - super().__init__(name="fgsm", model=model) - self.model = model - self.loss = loss - self.lower_bound = lower_bound - self.upper_bound = upper_bound - - self.bound = lambda x: torch.clamp( - x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound]) - ) - - self.zero_thresh = 10**-6 - - def forward(self, input, target, epsilon: float) -> Explanation: - input_ = input.clone() - grad = compute_gradient( - model=self.model, inp=input_, target=target, loss=self.loss - ) - grad = self.bound(grad) - input_.x = torch.where( - torch.abs(grad) > self.zero_thresh, - input_.x - epsilon * torch.sign(grad), - input_.x, - ) - return input_ - - def load_metric(self): - pass - - -class PGD(Metric): - def __init__( - self, - model: torch.nn.Module, - loss: torch.nn.Module, - lower_bound: float = float("-inf"), - upper_bound: float = float("inf"), - ): - super().__init__(name="pgd", model=model) - self.model = model - self.loss = loss - self.lower_bound = lower_bound - self.upper_bound = upper_bound - self.bound = lambda x: torch.clamp( - x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound]) - ) - self.zero_thresh = 10**-6 - self.fgsm = FGSM( - model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound - ) - - def forward( - self, - input, - target, - epsilon: float, - radius: float, - step_num: int, - random_start: bool = False, - norm: str = "inf", - ) -> Explanation: - def _clip(inputs: Explanation, outputs: Explanation) -> Explanation: - diff = outputs.x - inputs.x - if norm == "inf": - inputs.x = inputs.x + torch.clamp(diff, -radius, radius) - return inputs - elif norm == "2": - inputs.x = inputs.x + torch.renorm(diff, 2, 0, radius) - return inputs - else: - raise AssertionError("Norm constraint must be L2 or Linf.") - - perturbed_input = input - if random_start: - perturbed_input = self.bound(self._random_point(input.x, radius, norm)) - for _ in range(step_num): - perturbed_input = self.fgsm.forward( - input=perturbed_input, epsilon=epsilon, target=target - ) - perturbed_input = _clip(input, perturbed_input) - perturbed_input.x = self.bound(perturbed_input.x).detach() - return perturbed_input - - def load_metric(self): - pass - - def _random_point( - self, center: torch.Tensor, radius: float, norm: str - ) -> torch.Tensor: - r""" - A helper function that returns a uniform random point within the ball - with the given center and radius. Norm should be either L2 or Linf. - """ - if norm == "2": - u = torch.randn_like(center) - unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size()) - d = torch.numel(center[0]) - r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius - r = r[(...,) + (None,) * (r.dim() - 1)] - x = r * unit_u - return center + x - elif norm == "inf": - x = torch.rand_like(center) * radius * 2 - radius - return center + x - else: - raise AssertionError("Norm constraint must be L2 or Linf.") - - class Attack(Metric): def __init__( self, @@ -152,8 +39,10 @@ class Attack(Metric): "remove_node", "pgd", "fgsm", + "no_attack", ] self.dropout = dropout + self.config = None if loss is None: if cfg.model.loss_fun == "cross_entropy": self.loss = CrossEntropyLoss() @@ -166,10 +55,8 @@ class Attack(Metric): self.load_metric(name) def _gaussian_noise(self, exp) -> Explanation: - x = torch.clone(exp.x) - x = x + torch.randn(*x.shape) exp_ = exp.clone() - exp_.x = x + exp_.x = exp_.x + torch.randn(*exp_.x.shape).to(exp_.x.device) return exp_ def _add_edge(self, exp, p: float) -> Explanation: @@ -203,10 +90,15 @@ class Attack(Metric): def _load_gaussian_noise(self): return lambda exp: self._gaussian_noise(exp) + def _load_no_attack(self): + return lambda exp: exp + def load_metric(self, name): if name in self.authorized_metric: if name == "gaussian_noise": self.metric = self._load_gaussian_noise() + if name == "no_attack": + self.metric = self._load_no_attack() if name == "add_edge": self.metric = self._load_add_edge() if name == "remove_edge": @@ -214,21 +106,24 @@ class Attack(Metric): if name == "remove_node": self.metric = self._load_remove_node() if name == "pgd": - pgd = PGD(model=self.model, loss=self.loss) - self.metric = lambda exp: pgd.forward( - input=exp, - target=exp.y, + pgd = PGD( + model=self.model, + loss=self.loss, epsilon=1, radius=1, step_num=50, random_start=False, norm="inf", ) - if name == "fgsm": - fgsm = FGSM(model=self.model, loss=self.loss) - self.metric = lambda exp: fgsm.forward( - input=exp, target=exp.y, epsilon=1 + self.config = obj_config_to_str(pgd.__dict__) + self.metric = lambda exp: pgd.forward( + input=exp, + target=exp.y, ) + if name == "fgsm": + fgsm = FGSM(model=self.model, loss=self.loss, epsilon=1) + self.config = obj_config_to_str(fgsm.__dict__) + self.metric = lambda exp: fgsm.forward(input=exp, target=exp.y) else: raise ValueError(f"{name} is not supported yet") @@ -237,3 +132,120 @@ class Attack(Metric): def forward(self, exp) -> Explanation: attack = self.metric(exp) return attack + + def get_attacked_prediction(self, data: Data) -> Data: + data_ = data.clone() + data_attacked = self.forward(data_) + pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index) + pred_attacked = self.get_prediction( + x=data_attacked.x, edge_index=data_attacked.edge_index + ) + setattr(data_, "pred", pred) + setattr(data_, "pred_attacked", pred_attacked) + return data_ + + +class FGSM(Metric): + def __init__( + self, + model: torch.nn.Module, + loss: torch.nn.Module, + lower_bound: float = float("-inf"), + upper_bound: float = float("inf"), + epsilon=1, + ): + super().__init__(name="fgsm", model=model) + self.model = model + self.loss = loss + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.epsilon = epsilon + + self.bound = lambda x: torch.clamp( + x, + min=torch.Tensor([lower_bound]).to(x.device), + max=torch.Tensor([upper_bound]).to(x.device), + ).to(x.device) + + self.zero_thresh = 10**-6 + + def forward(self, input, target) -> Explanation: + input_ = input.clone() + grad = compute_gradient( + model=self.model, inp=input_, target=target, loss=self.loss + ) + grad = self.bound(grad) + input_.x = torch.where( + torch.abs(grad) > self.zero_thresh, + input_.x - self.epsilon * torch.sign(grad), + input_.x, + ) + return input_ + + def load_metric(self): + pass + + +class PGD(Metric): + def __init__( + self, + model: torch.nn.Module, + loss: torch.nn.Module, + lower_bound: float = float("-inf"), + upper_bound: float = float("inf"), + epsilon=1, + radius=1, + step_num=50, + random_start=False, + norm="inf", + ): + super().__init__(name="pgd", model=model) + self.model = model + self.loss = loss + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.bound = lambda x: torch.clamp( + x, + min=torch.Tensor([lower_bound]).to(x.device), + max=torch.Tensor([upper_bound]).to(x.device), + ).to(x.device) + + self.zero_thresh = 10**-6 + self.fgsm = FGSM( + model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound + ) + self.epsilon = epsilon + self.radius = radius + self.step_num = step_num + self.random_start = random_start + self.norm = norm + + def forward( + self, + input, + target, + ) -> Explanation: + def _clip(inputs: Explanation, outputs: Explanation) -> Explanation: + diff = outputs.x - inputs.x + if self.norm == "inf": + inputs.x = inputs.x + torch.clamp(diff, -self.radius, self.radius) + return inputs + elif self.norm == "2": + inputs.x = inputs.x + torch.renorm(diff, 2, 0, self.radius) + return inputs + else: + raise AssertionError("Norm constraint must be L2 or Linf.") + + perturbed_input = input + if self.random_start: + perturbed_input = self.bound( + self._random_point(input.x, self.radius, self.norm) + ) + for _ in range(self.step_num): + perturbed_input = self.fgsm.forward(input=perturbed_input, target=target) + perturbed_input = _clip(input, perturbed_input) + perturbed_input.x = self.bound(perturbed_input.x).detach() + return perturbed_input + + def load_metric(self): + pass diff --git a/explaining_framework/metric/sparsity.py b/explaining_framework/metric/sparsity.py index 069564a..a2d2402 100644 --- a/explaining_framework/metric/sparsity.py +++ b/explaining_framework/metric/sparsity.py @@ -20,6 +20,6 @@ class Sparsity(Metric): def forward(self, exp: Explanation) -> float: out = {} for k, v in exp.to_dict().items(): - if "mask" in k and v.dtype == torch.bool: - out[k] = torch.mean(mask.float()).item() + if "mask" in k and torch.all(torch.logical_or(v == 0, v == 1)).item(): + out[k] = torch.mean(v).item() return out diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index d2ea05e..25e00d9 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -124,13 +124,12 @@ class LoadModelInfo(object): model_name = os.path.basename(self.info["xp_dir_path"]) model_seed = self.info["seed"] - epoch = os.path.basename(self.info["ckpt_path"]) model_signature = "-".join( [ f"{name}={val}" for name, val in zip(["name", "seed"], [model_name, model_seed]) ] - + [epoch] + + [self.which] ) return model_signature diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 99707ac..5f47eea 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -14,6 +14,7 @@ from torch_geometric.graphgym.loader import create_dataset from torch_geometric.graphgym.model_builder import cfg, create_model from torch_geometric.graphgym.utils.device import auto_select_device from torch_geometric.loader.dataloader import DataLoader +from yacs.config import CfgNode as CN from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg @@ -22,6 +23,7 @@ from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper from explaining_framework.explainers.wrappers.from_graphxai import \ GraphXAIWrapper +from explaining_framework.explainers.wrappers.from_pyg import PYGWrapper from explaining_framework.metric.accuracy import Accuracy from explaining_framework.metric.base import Metric from explaining_framework.metric.fidelity import Fidelity @@ -47,7 +49,7 @@ all__captum = [ "GuidedBackprop", "GuidedGradCam", "InputXGradient", - "IntegratedGradients", + # "IntegratedGradients", "Lime", "Occlusion", "Saliency", @@ -67,6 +69,10 @@ all__graphxai = [ "GraphMASK", "GNNExplainer", ] +all__pyg = [ + # "PGExplainer", + # "GNNExplainer", +] all__own = ["EIXGNN", "SCGNN"] @@ -94,10 +100,11 @@ all_robust = [ "remove_node", "pgd", "fgsm", + "no_attack", ] all_sparsity = ["l0"] -adjust_pattern = "ranp" +adjust_pattern = "ranps" all_adjusts_filters = [ "".join(filters) for i in range(len(adjust_pattern) + 1) @@ -168,9 +175,9 @@ class ExplainingOutline(object): def load_indexes(self): - items = self.explaining_cfg.dataset.items - if isinstance(items, (list, int)): - indexes = items + item = self.explaining_cfg.dataset.item + if isinstance(item, (list, int)): + indexes = item else: indexes = list(range(len(self.dataset))) self.indexes = iter(indexes) @@ -223,7 +230,7 @@ class ExplainingOutline(object): elif self.explaining_cfg.explainer.name == "SCGNN": self.explainer_cfg = copy.copy(scgnn_cfg) else: - self.explainer_cfg = None + self.explainer_cfg = CN() else: if self.explaining_cfg.explainer.name == "EIXGNN": eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) @@ -241,6 +248,7 @@ class ExplainingOutline(object): if self.model is None: raise ValueError("Model ckpt has not been loaded, ckpt file not found") self.model = self.model.eval() + self.model.explain = True def load_dataset(self): if self.cfg is None: @@ -252,19 +260,26 @@ class ExplainingOutline(object): f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." ) self.dataset = create_dataset() - items = self.explaining_cfg.dataset.items - print(items) - print(type(items)) - if isinstance(items, int): - self.dataset = self.dataset[items : items + 1] - elif isinstance(items, list): - self.dataset = self.dataset[items] + item = self.explaining_cfg.dataset.item + if isinstance(item, int): + self.dataset = self.dataset[item : item + 1] + elif isinstance(item, list): + self.dataset = self.dataset[item] def load_dataset_to_dataloader(self, to_iter=True): self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) if to_iter: self.dataset = iter(self.dataset) + def reload_dataset(self): + self.load_dataset() + self.load_indexes() + + def reload_dataloader(self): + self.load_dataset() + self.load_dataset_to_dataloader() + self.load_indexes() + def load_explaining_algorithm(self): self.load_explainer_cfg() if self.model is None: @@ -273,14 +288,16 @@ class ExplainingOutline(object): self.load_dataset() name = self.explaining_cfg.explainer.name - if name in all__captum: - explaining_algorithm = CaptumWrapper(name) - elif name in all__graphxai: + if name in all__graphxai: explaining_algorithm = GraphXAIWrapper( name, in_channels=self.dataset.num_classes, criterion=self.cfg.model.loss_fun, ) + elif name in all__captum: + explaining_algorithm = CaptumWrapper(name) + elif name in all__pyg: + explaining_algorithm = PYGWrapper(name) elif name in all__own: if name == "EIXGNN": explaining_algorithm = EiXGNN( @@ -296,6 +313,7 @@ class ExplainingOutline(object): depth=self.explainer_cfg.depth, interest_map_norm=self.explainer_cfg.interest_map_norm, score_map_norm=self.explainer_cfg.score_map_norm, + target_baseline=self.explainer_cfg.target_baseline, ) elif name is None: explaining_algorithm = None @@ -539,6 +557,7 @@ class ExplainingOutline(object): explanation = _get_explanation(self.explainer, item) else: explanation = _load_explanation(path) + explanation = explanation.to(self.cfg.accelerator) else: explanation = _get_explanation(self.explainer, item) get_pred(self.explainer, explanation) @@ -590,3 +609,14 @@ class ExplainingOutline(object): if item.num_nodes <= 500: stat = self.graphstat(item) write_json(stat, path) + + def get_attack(self, attack: Attack, item: Data, path: str): + if is_exists(path): + if self.explaining_cfg.explainer.force: + data_attack = attack.get_attacked_prediction(item) + else: + data_attack = _load_explanation(path) + else: + data_attack = attack.get_attacked_prediction(item) + _save_explanation(data_attack, path) + return data_attack diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py index 0cdbe98..5b12336 100644 --- a/explaining_framework/utils/explanation/adjust.py +++ b/explaining_framework/utils/explanation/adjust.py @@ -9,37 +9,46 @@ from torch_geometric.explain.explanation import Explanation class Adjust(object): def __init__( self, - strategy: str = "rpn", + strategy: str = "rpns", ): self.strategy = strategy def forward(self, exp: Explanation) -> Explanation: exp_ = exp.clone() - _store = exp_.to_dict() - for k, v in _store.items(): + for k, v in exp_.items(): if "mask" in k: for f_ in self.strategy: if f_ == "r": - _store[k] = self.relu(v) + exp_.__setattr__(k, self.relu(v)) if f_ == "a": - _store[k] = self.absolute(v) + exp_.__setattr__(k, self.absolute(v)) if f_ == "p": if "edge" in k: pass else: - _store[k] = self.project(v) + exp_.__setattr__(k, self.project(v)) if f_ == "n": - _store[k] = self.normalize(v) + exp_.__setattr__(k, self.normalize(v)) + if f_ == "s": + exp_.__setattr__(k, self.squeeze_(v)) + else: continue return exp_ def relu(self, mask: FloatTensor) -> FloatTensor: - relu = ReLU() + relu = ReLU(inplace=True) mask_ = relu(mask) return mask_ + def squeeze_(self, mask: FloatTensor) -> FloatTensor: + if mask.max() == mask.min(): + return mask + else: + mask_ = (mask - mask.min()).div(mask.max() - mask.min()) + return mask_ + def normalize(self, mask: FloatTensor) -> FloatTensor: norm = torch.norm(mask, p=float("inf")) if norm.item() > 0: diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py index 9624562..146b3dc 100644 --- a/explaining_framework/utils/io.py +++ b/explaining_framework/utils/io.py @@ -26,22 +26,45 @@ def write_yaml(data: dict, path: str) -> None: data = yaml.dump(data, f) +def dump_cfg(cfg, path): + r""" + Dumps the config to the output directory specified in + :obj:`cfg.out_dir` + Args: + cfg (CfgNode): Configuration node + """ + with open(path, "w") as f: + cfg.dump(stream=f) + + def is_exists(path: str) -> bool: return os.path.exists(path) -def get_obj_config(obj): - config = { - k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool)) - } +def get_dict_config(d: dict): + config = {} + for k, v in d.items(): + if isinstance(v, (int, float, str, bool)): + config[k] = val_check(v) return config +def val_check(v): + if v == float("-inf"): + return "minus_inf" + else: + return v + + def save_obj_config(obj, path) -> None: config = get_obj_config(obj) write_json(config, path) def obj_config_to_str(obj) -> str: - config = get_obj_config(obj) - return "-".join([f"{k}={v}" for k, v in config.items()]) + if isinstance(obj, dict): + config = get_dict_config(obj) + return "-".join([f"{k}={v}" for k, v in config.items()]) + else: + config = get_dict_config(obj.__dict__) + return "-".join([f"{k}={v}" for k, v in config.items()]) diff --git a/main.py b/main.py index 8e25d9b..77f6896 100644 --- a/main.py +++ b/main.py @@ -18,8 +18,9 @@ from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explanation.adjust import Adjust -from explaining_framework.utils.io import (is_exists, obj_config_to_str, - read_json, write_json, write_yaml) +from explaining_framework.utils.io import (dump_cfg, is_exists, + obj_config_to_str, read_json, + write_json) # inference, time, force, @@ -27,65 +28,100 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str, if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file) - print(outline.explaining_cfg) - - out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature) + out_dir = os.path.join( + outline.explaining_cfg.out_dir, + outline.cfg.dataset.name, + outline.model_signature, + ) makedirs(out_dir) - write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml")) + dump_cfg(outline.cfg, os.path.join(out_dir, "config.yaml")) write_json(outline.model_info, os.path.join(out_dir, "info.json")) explainer_path = os.path.join( out_dir, - outline.explaining_cfg.explainer.name - + "_" - + obj_config_to_str(outline.explaining_algorithm), + outline.explaining_cfg.explainer.name, + obj_config_to_str(outline.explaining_algorithm), ) - makedirs(explainer_path) - write_yaml( - outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest) + dump_cfg( + outline.explainer_cfg, + os.path.join(explainer_path, "explainer_cfg.yaml"), ) - write_yaml( - outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml") + dump_cfg( + outline.explaining_cfg, + os.path.join(explainer_path, explaining_cfg.cfg_dest), ) - specific_explainer_path = os.path.join( - explainer_path, obj_config_to_str(outline.explaining_algorithm) - ) - makedirs(specific_explainer_path) - - raw_path = os.path.join(specific_explainer_path, "raw") - makedirs(raw_path) - item, index = outline.get_item() while not (item is None or index is None): - explanation_path = os.path.join(raw_path, f"{index}.json") - raw_exp = outline.get_explanation(item=item, path=explanation_path) - for adjust in outline.adjusts: - adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}") - makedirs(adjust_path) - exp_adjust_path = os.path.join(adjust_path, f"{index}.json") - exp_adjust = outline.get_adjust( - adjust=adjust, item=raw_exp, path=exp_adjust_path + for attack in outline.attacks: + attack_path = os.path.join( + out_dir, attack.__class__.__name__, obj_config_to_str(attack) ) - for threshold_conf in outline.thresholds_configs: - outline.set_explainer_threshold_config(threshold_conf) - masking_path = os.path.join( - adjust_path, - "-".join([f"{k}={v}" for k, v in threshold_conf.items()]), + makedirs(attack_path) + data_attack_path = os.path.join(attack_path, f"{index}.json") + data_attack = outline.get_attack( + attack=attack, item=item, path=data_attack_path + ) + item, index = outline.get_item() + + outline.reload_dataloader() + makedirs(explainer_path) + + item, index = outline.get_item() + while not (item is None or index is None): + for attack in outline.attacks: + attack_path_ = os.path.join( + explainer_path, attack.__class__.__name__, obj_config_to_str(attack) + ) + makedirs(attack_path_) + data_attack_path_ = os.path.join(attack_path_, f"{index}.json") + attack_data = outline.get_attack( + attack=attack, item=item, path=data_attack_path_ + ) + exp = outline.get_explanation(item=attack_data, path=data_attack_path_) + for adjust in outline.adjusts: + adjust_path = os.path.join( + attack_path_, adjust.__class__.__name__, obj_config_to_str(adjust) ) - makedirs(masking_path) - exp_masked_path = os.path.join(masking_path, f"{index}.json") - exp_masked = outline.get_threshold( - item=exp_adjust, path=exp_masked_path + makedirs(adjust_path) + exp_adjust_path = os.path.join(adjust_path, f"{index}.json") + exp_adjust = outline.get_adjust( + adjust=adjust, item=exp, path=exp_adjust_path ) - for metric in outline.metrics: - metric_path = os.path.join( - masking_path, f"{obj_config_to_str(metric)}" + for threshold_conf in outline.thresholds_configs: + outline.set_explainer_threshold_config(threshold_conf) + masking_path = os.path.join( + adjust_path, + "ThresholdConfig", + obj_config_to_str(threshold_conf), ) - makedirs(metric_path) - metric_path = os.path.join(metric_path, f"{index}.json") - out_metric = outline.get_metric( - metric=metric, item=exp_masked, path=metric_path + makedirs(masking_path) + exp_masked_path = os.path.join(masking_path, f"{index}.json") + exp_masked = outline.get_threshold( + item=exp_adjust, path=exp_masked_path ) + for metric in outline.metrics: + metric_path = os.path.join( + masking_path, + metric.__class__.__name__, + obj_config_to_str(metric), + ) + makedirs(metric_path) + metric_path = os.path.join(metric_path, f"{index}.json") + out_metric = outline.get_metric( + metric=metric, item=exp_masked, path=metric_path + ) + print("#################################") + print("Attack", attack.name) + print( + "ThresholdConfig", + "-".join([f"{k}={v}" for k, v in threshold_conf.items()]), + ) + print("Metric", metric.name) + print("Val", out_metric) + print("Index", index) + print("#################################") + + item, index = outline.get_item()