Renaming file and adding new features

This commit is contained in:
araison 2022-12-29 22:00:39 +01:00
parent 9d4aedbca7
commit 26fa51e2de
7 changed files with 167 additions and 33 deletions

View File

@ -187,22 +187,24 @@ class CaptumWrapper(ExplainerAlgorithm):
raise ValueError(f"{self.name} is not a supported Captum method yet !") raise ValueError(f"{self.name} is not a supported Captum method yet !")
def _parse_attr(self, attr): def _parse_attr(self, attr):
for i in range(len(attr)):
attr[i] = attr[i].squeeze()
if self.mask_type == "node": if self.mask_type == "node":
node_feat_mask = attr[0].squeeze(0) node_mask = attr[0]
edge_mask = None edge_mask = None
if self.mask_type == "edge": if self.mask_type == "edge":
node_feat_mask = None node_mask = None
edge_mask = attr[0] edge_mask = attr[0]
if self.mask_type == "node_and_edge": if self.mask_type == "node_and_edge":
node_feat_mask = attr[0].squeeze(0) node_mask = attr[0]
edge_mask = attr[1] edge_mask = attr[1]
else: else:
raise ValueError raise ValueError
edge_feat_mask = None edge_feat_mask = None
node_mask = None node_feat_mask = None
return node_mask, edge_mask, node_feat_mask, edge_feat_mask return node_mask, edge_mask, node_feat_mask, edge_feat_mask

View File

@ -2,9 +2,12 @@ import copy
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from explaining_framework.metric.base import Metric
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
from troch.nn import CrossEntropyLoss, MSELoss
from explaining_framework.metric.base import Metric
def compute_gradient(model, inp, target, loss): def compute_gradient(model, inp, target, loss):
@ -142,7 +145,15 @@ class Attack(Metric):
"fgsm", "fgsm",
] ]
self.dropout = dropout self.dropout = dropout
self.loss = loss if loss is None:
if cfg.model.loss_fun == "cross-entropy":
self.loss = CrossEntropyLoss()
if cfg.model.loss_fun == "mse":
self.loss = MSELoss()
else:
raise ValueError
else:
self.loss = loss
self.load_metric(name) self.load_metric(name)
def _gaussian_noise(self, exp) -> Explanation: def _gaussian_noise(self, exp) -> Explanation:
@ -194,7 +205,6 @@ class Attack(Metric):
if name == "remove_node": if name == "remove_node":
self.metric = self._load_remove_node() self.metric = self._load_remove_node()
if name == "pgd": if name == "pgd":
print("set LOSS with cfg ")
pgd = PGD(model=self.model, loss=self.loss) pgd = PGD(model=self.model, loss=self.loss)
self.metric = lambda exp: pgd.forward( self.metric = lambda exp: pgd.forward(
input=exp, input=exp,
@ -206,7 +216,6 @@ class Attack(Metric):
norm="inf", norm="inf",
) )
if name == "fgsm": if name == "fgsm":
print("set LOSS with cfg ")
fgsm = FGSM(model=self.model, loss=self.loss) fgsm = FGSM(model=self.model, loss=self.loss)
self.metric = lambda exp: fgsm.forward( self.metric = lambda exp: fgsm.forward(
input=exp, target=exp.y, epsilon=1 input=exp, target=exp.y, epsilon=1

View File

@ -4,7 +4,21 @@ from explaining_framework.metric.base import Metric
class Sparsity(Metric): class Sparsity(Metric):
def __init__(self, name): def __init__(self, name):
super().__init__(name=name, model=None) super().__init__(name=name)
self.authorized_metric = ['l0']
self.metric = self.load_metric(name)
def load_metric(self,name):
if name in self.authorized_metric:
if name == 'l0':
metric = lambda x : torch.mean(mask.float()).item()
else:
raise ValueError(f'{name} is not supported yet')
def forward(self, exp:Explanation) -> float:
out = {}
for k,v in exp.to_dict():
if 'mask' in
def forward(self, mask):
return torch.mean(mask.float()).item() return torch.mean(mask.float()).item()

View File

@ -1,6 +1,15 @@
import copy import copy
from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -14,13 +23,6 @@ from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity from explaining_framework.metric.sparsity import Sparsity
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo, from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
_load_ckpt) _load_ckpt)
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
all__captum = [ all__captum = [
"LRP", "LRP",
@ -54,6 +56,30 @@ all__graphxai = [
all__own = ["EIXGNN", "SCGNN"] all__own = ["EIXGNN", "SCGNN"]
all_fidelity = [
"fidelity_plus",
"fidelity_minus",
"fidelity_plus_prob",
"fidelity_minus_prob",
"infidelity_KL",
]
all_accuracy = [
"precision_score",
"jaccard_score",
"roc_auc_score",
"f1_score",
"accuracy_score",
]
all_robust = [
"gaussian_noise",
"add_edge",
"remove_edge",
"remove_node",
"pgd",
"fgsm",
]
class ExplainingOutline(object): class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str): def __init__(self, explaining_cfg_path: str):
@ -95,7 +121,7 @@ class ExplainingOutline(object):
def load_explainer_cfg(self): def load_explainer_cfg(self):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.explaining_cfg() self.load_explaining_cfg()
else: else:
if self.explaining_cfg.explainer.cfg == "default": if self.explaining_cfg.explainer.cfg == "default":
if self.explaining_cfg.explainer.name == "EIXGNN": if self.explaining_cfg.explainer.name == "EIXGNN":
@ -127,7 +153,7 @@ class ExplainingOutline(object):
if self.cfg is None: if self.cfg is None:
self.load_cfg() self.load_cfg()
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.explaining_cfg() self.load_explaining_cfg()
if self.explaining_cfg.dataset.name != self.cfg.dataset.name: if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
raise ValueError( raise ValueError(
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
@ -167,7 +193,33 @@ class ExplainingOutline(object):
score_map_norm=self.explainer_cfg.score_map_norm, score_map_norm=self.explainer_cfg.score_map_norm,
) )
self.explaining_algorithm = explaining_algorithm self.explaining_algorithm = explaining_algorithm
print(self.explaining_algorithm.__dict__)
def load_metric(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
if self.explaining_cfg.metrics.type == "all":
if self.explaining_cfg.dataset.name == 'BASHAPES':
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
def load_attack(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
all_rob_metrics = [Attack(name) for name in all_robust]
class FileManager(object):
def __init__(self):
pass
def save(obj: Any, path: str) -> None:
pass
PATH = "config_exp.yaml" PATH = "config_exp.yaml"

View File

@ -3,14 +3,63 @@ import copy
from torch import FloatTensor from torch import FloatTensor
from torch.nn import ReLU from torch.nn import ReLU
class Adjust(object):
def __init__(
self,
apply_relu: bool = True,
apply_normalize: bool = True,
apply_project: bool = True,
apply_absolute: bool = False,
):
self.apply_relu = apply_relu
self.apply_normalize = apply_normalize
self.apply_project = apply_project
self.apply_absolute = apply_absolute
def relu_mask(explanation: Explanation) -> Explanation: if self.apply_absolute and self.apply_relu:
relu = ReLU() self.apply_relu = False
explanation_store = explanation._store
raw_data = copy.copy(explanation._store) def forward(self, exp: Explanation) -> Explanation:
for k, v in explanation_store.items(): exp_ = exp.copy()
if "mask" in k: _store = exp_.to_dict()
explanation_store[k] = relu(v) for k, v in _store.items():
explanation.__setattr__("raw_explanation", raw_data) if "mask" in k:
explanation.__setattr__("raw_explanation_transform", "relu") if self.apply_relu:
return explanation _store[k] = self.relu(v)
elif self.apply_absolute:
_store[k] = self.absolute(v)
elif self.apply_project:
if "edge" in k:
pass
else:
_store[k] = self.project(v)
elif self.apply_normalize:
_store[k] = self.normalize(v)
else:
continue
return exp_
def relu(self, mask: FloatTensor) -> FloatTensor:
relu = ReLU()
mask_ = relu(mask)
return mask_
def normalize(self, mask: FloatTensor) -> FloatTensor:
norm = torch.norm(mask, p="inf")
if norm.item() > 0:
mask_ = mask / norm.item()
return mask_
else:
return mask
def project(self, mask: FloatTensor) -> FloatTensor:
if mask.ndim >= 2:
mask_ = torch.sum(mask, dim=1)
return mask_
else:
return mask
def absolute(self, mask: FloatTensor) -> FloatTensor:
mask_ = torch.abs(mask)
return mask_

View File

@ -12,8 +12,9 @@ def explanation_verification(exp: Explanation) -> bool:
for mask in masks: for mask in masks:
is_nan = mask.isnan().any().item() is_nan = mask.isnan().any().item()
is_inf = mask.isinf().any().item() is_inf = mask.isinf().any().item()
is_const = mask.max()==mask.min()
is_ok = exp.validate() is_ok = exp.validate()
if is_nan or is_inf or not is_ok: if is_nan or is_inf or not is_ok or is_const:
is_good = False is_good = False
return is_good return is_good
else: else:
@ -47,5 +48,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
data = exp.to_dict() data = exp.to_dict()
for k, v in data.items(): for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor): if "_mask" in k and isinstance(v, torch.FloatTensor):
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item() norm =torch.norm(input=data[k], p=p, dim=None).item()
if norm.item()>0:
data[k] = data[k] / norm
return exp return exp

View File

@ -1,5 +1,5 @@
import json import json
import os
import yaml import yaml
@ -23,3 +23,7 @@ def read_yaml(path: str) -> dict:
def write_yaml(data: dict, path: str) -> None: def write_yaml(data: dict, path: str) -> None:
with open(path, "w") as f: with open(path, "w") as f:
data = yaml.dump(data, f) data = yaml.dump(data, f)
def is_exists(path:str)-> bool:
return os.path.exists(path)