Renaming file and adding new features

This commit is contained in:
araison 2022-12-29 22:00:39 +01:00
parent 9d4aedbca7
commit 26fa51e2de
7 changed files with 167 additions and 33 deletions

View File

@ -187,22 +187,24 @@ class CaptumWrapper(ExplainerAlgorithm):
raise ValueError(f"{self.name} is not a supported Captum method yet !")
def _parse_attr(self, attr):
for i in range(len(attr)):
attr[i] = attr[i].squeeze()
if self.mask_type == "node":
node_feat_mask = attr[0].squeeze(0)
node_mask = attr[0]
edge_mask = None
if self.mask_type == "edge":
node_feat_mask = None
node_mask = None
edge_mask = attr[0]
if self.mask_type == "node_and_edge":
node_feat_mask = attr[0].squeeze(0)
node_mask = attr[0]
edge_mask = attr[1]
else:
raise ValueError
edge_feat_mask = None
node_mask = None
node_feat_mask = None
return node_mask, edge_mask, node_feat_mask, edge_feat_mask

View File

@ -2,9 +2,12 @@ import copy
import torch
import torch.nn.functional as F
from explaining_framework.metric.base import Metric
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
from troch.nn import CrossEntropyLoss, MSELoss
from explaining_framework.metric.base import Metric
def compute_gradient(model, inp, target, loss):
@ -142,7 +145,15 @@ class Attack(Metric):
"fgsm",
]
self.dropout = dropout
self.loss = loss
if loss is None:
if cfg.model.loss_fun == "cross-entropy":
self.loss = CrossEntropyLoss()
if cfg.model.loss_fun == "mse":
self.loss = MSELoss()
else:
raise ValueError
else:
self.loss = loss
self.load_metric(name)
def _gaussian_noise(self, exp) -> Explanation:
@ -194,7 +205,6 @@ class Attack(Metric):
if name == "remove_node":
self.metric = self._load_remove_node()
if name == "pgd":
print("set LOSS with cfg ")
pgd = PGD(model=self.model, loss=self.loss)
self.metric = lambda exp: pgd.forward(
input=exp,
@ -206,7 +216,6 @@ class Attack(Metric):
norm="inf",
)
if name == "fgsm":
print("set LOSS with cfg ")
fgsm = FGSM(model=self.model, loss=self.loss)
self.metric = lambda exp: fgsm.forward(
input=exp, target=exp.y, epsilon=1

View File

@ -4,7 +4,21 @@ from explaining_framework.metric.base import Metric
class Sparsity(Metric):
def __init__(self, name):
super().__init__(name=name, model=None)
super().__init__(name=name)
self.authorized_metric = ['l0']
self.metric = self.load_metric(name)
def load_metric(self,name):
if name in self.authorized_metric:
if name == 'l0':
metric = lambda x : torch.mean(mask.float()).item()
else:
raise ValueError(f'{name} is not supported yet')
def forward(self, exp:Explanation) -> float:
out = {}
for k,v in exp.to_dict():
if 'mask' in
def forward(self, mask):
return torch.mean(mask.float()).item()

View File

@ -1,6 +1,15 @@
import copy
from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -14,13 +23,6 @@ from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
_load_ckpt)
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
all__captum = [
"LRP",
@ -54,6 +56,30 @@ all__graphxai = [
all__own = ["EIXGNN", "SCGNN"]
all_fidelity = [
"fidelity_plus",
"fidelity_minus",
"fidelity_plus_prob",
"fidelity_minus_prob",
"infidelity_KL",
]
all_accuracy = [
"precision_score",
"jaccard_score",
"roc_auc_score",
"f1_score",
"accuracy_score",
]
all_robust = [
"gaussian_noise",
"add_edge",
"remove_edge",
"remove_node",
"pgd",
"fgsm",
]
class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str):
@ -95,7 +121,7 @@ class ExplainingOutline(object):
def load_explainer_cfg(self):
if self.explaining_cfg is None:
self.explaining_cfg()
self.load_explaining_cfg()
else:
if self.explaining_cfg.explainer.cfg == "default":
if self.explaining_cfg.explainer.name == "EIXGNN":
@ -127,7 +153,7 @@ class ExplainingOutline(object):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.explaining_cfg()
self.load_explaining_cfg()
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
raise ValueError(
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
@ -167,7 +193,33 @@ class ExplainingOutline(object):
score_map_norm=self.explainer_cfg.score_map_norm,
)
self.explaining_algorithm = explaining_algorithm
print(self.explaining_algorithm.__dict__)
def load_metric(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
if self.explaining_cfg.metrics.type == "all":
if self.explaining_cfg.dataset.name == 'BASHAPES':
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
def load_attack(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
all_rob_metrics = [Attack(name) for name in all_robust]
class FileManager(object):
def __init__(self):
pass
def save(obj: Any, path: str) -> None:
pass
PATH = "config_exp.yaml"

View File

@ -3,14 +3,63 @@ import copy
from torch import FloatTensor
from torch.nn import ReLU
class Adjust(object):
def __init__(
self,
apply_relu: bool = True,
apply_normalize: bool = True,
apply_project: bool = True,
apply_absolute: bool = False,
):
self.apply_relu = apply_relu
self.apply_normalize = apply_normalize
self.apply_project = apply_project
self.apply_absolute = apply_absolute
def relu_mask(explanation: Explanation) -> Explanation:
relu = ReLU()
explanation_store = explanation._store
raw_data = copy.copy(explanation._store)
for k, v in explanation_store.items():
if "mask" in k:
explanation_store[k] = relu(v)
explanation.__setattr__("raw_explanation", raw_data)
explanation.__setattr__("raw_explanation_transform", "relu")
return explanation
if self.apply_absolute and self.apply_relu:
self.apply_relu = False
def forward(self, exp: Explanation) -> Explanation:
exp_ = exp.copy()
_store = exp_.to_dict()
for k, v in _store.items():
if "mask" in k:
if self.apply_relu:
_store[k] = self.relu(v)
elif self.apply_absolute:
_store[k] = self.absolute(v)
elif self.apply_project:
if "edge" in k:
pass
else:
_store[k] = self.project(v)
elif self.apply_normalize:
_store[k] = self.normalize(v)
else:
continue
return exp_
def relu(self, mask: FloatTensor) -> FloatTensor:
relu = ReLU()
mask_ = relu(mask)
return mask_
def normalize(self, mask: FloatTensor) -> FloatTensor:
norm = torch.norm(mask, p="inf")
if norm.item() > 0:
mask_ = mask / norm.item()
return mask_
else:
return mask
def project(self, mask: FloatTensor) -> FloatTensor:
if mask.ndim >= 2:
mask_ = torch.sum(mask, dim=1)
return mask_
else:
return mask
def absolute(self, mask: FloatTensor) -> FloatTensor:
mask_ = torch.abs(mask)
return mask_

View File

@ -12,8 +12,9 @@ def explanation_verification(exp: Explanation) -> bool:
for mask in masks:
is_nan = mask.isnan().any().item()
is_inf = mask.isinf().any().item()
is_const = mask.max()==mask.min()
is_ok = exp.validate()
if is_nan or is_inf or not is_ok:
if is_nan or is_inf or not is_ok or is_const:
is_good = False
return is_good
else:
@ -47,5 +48,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
data = exp.to_dict()
for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor):
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
norm =torch.norm(input=data[k], p=p, dim=None).item()
if norm.item()>0:
data[k] = data[k] / norm
return exp

View File

@ -1,5 +1,5 @@
import json
import os
import yaml
@ -23,3 +23,7 @@ def read_yaml(path: str) -> dict:
def write_yaml(data: dict, path: str) -> None:
with open(path, "w") as f:
data = yaml.dump(data, f)
def is_exists(path:str)-> bool:
return os.path.exists(path)