Renaming file and adding new features
This commit is contained in:
parent
9d4aedbca7
commit
26fa51e2de
@ -187,22 +187,24 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
||||
|
||||
def _parse_attr(self, attr):
|
||||
for i in range(len(attr)):
|
||||
attr[i] = attr[i].squeeze()
|
||||
if self.mask_type == "node":
|
||||
node_feat_mask = attr[0].squeeze(0)
|
||||
node_mask = attr[0]
|
||||
edge_mask = None
|
||||
|
||||
if self.mask_type == "edge":
|
||||
node_feat_mask = None
|
||||
node_mask = None
|
||||
edge_mask = attr[0]
|
||||
|
||||
if self.mask_type == "node_and_edge":
|
||||
node_feat_mask = attr[0].squeeze(0)
|
||||
node_mask = attr[0]
|
||||
edge_mask = attr[1]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
edge_feat_mask = None
|
||||
node_mask = None
|
||||
node_feat_mask = None
|
||||
|
||||
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||
|
||||
|
@ -2,9 +2,12 @@ import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from explaining_framework.metric.base import Metric
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
||||
from troch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from explaining_framework.metric.base import Metric
|
||||
|
||||
|
||||
def compute_gradient(model, inp, target, loss):
|
||||
@ -142,7 +145,15 @@ class Attack(Metric):
|
||||
"fgsm",
|
||||
]
|
||||
self.dropout = dropout
|
||||
self.loss = loss
|
||||
if loss is None:
|
||||
if cfg.model.loss_fun == "cross-entropy":
|
||||
self.loss = CrossEntropyLoss()
|
||||
if cfg.model.loss_fun == "mse":
|
||||
self.loss = MSELoss()
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
self.loss = loss
|
||||
self.load_metric(name)
|
||||
|
||||
def _gaussian_noise(self, exp) -> Explanation:
|
||||
@ -194,7 +205,6 @@ class Attack(Metric):
|
||||
if name == "remove_node":
|
||||
self.metric = self._load_remove_node()
|
||||
if name == "pgd":
|
||||
print("set LOSS with cfg ")
|
||||
pgd = PGD(model=self.model, loss=self.loss)
|
||||
self.metric = lambda exp: pgd.forward(
|
||||
input=exp,
|
||||
@ -206,7 +216,6 @@ class Attack(Metric):
|
||||
norm="inf",
|
||||
)
|
||||
if name == "fgsm":
|
||||
print("set LOSS with cfg ")
|
||||
fgsm = FGSM(model=self.model, loss=self.loss)
|
||||
self.metric = lambda exp: fgsm.forward(
|
||||
input=exp, target=exp.y, epsilon=1
|
||||
|
@ -4,7 +4,21 @@ from explaining_framework.metric.base import Metric
|
||||
|
||||
class Sparsity(Metric):
|
||||
def __init__(self, name):
|
||||
super().__init__(name=name, model=None)
|
||||
super().__init__(name=name)
|
||||
self.authorized_metric = ['l0']
|
||||
self.metric = self.load_metric(name)
|
||||
|
||||
def load_metric(self,name):
|
||||
if name in self.authorized_metric:
|
||||
if name == 'l0':
|
||||
metric = lambda x : torch.mean(mask.float()).item()
|
||||
else:
|
||||
raise ValueError(f'{name} is not supported yet')
|
||||
|
||||
|
||||
def forward(self, exp:Explanation) -> float:
|
||||
out = {}
|
||||
for k,v in exp.to_dict():
|
||||
if 'mask' in
|
||||
|
||||
def forward(self, mask):
|
||||
return torch.mean(mask.float()).item()
|
||||
|
@ -1,6 +1,15 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
@ -14,13 +23,6 @@ from explaining_framework.metric.robust import Attack
|
||||
from explaining_framework.metric.sparsity import Sparsity
|
||||
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||
_load_ckpt)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
@ -54,6 +56,30 @@ all__graphxai = [
|
||||
|
||||
all__own = ["EIXGNN", "SCGNN"]
|
||||
|
||||
all_fidelity = [
|
||||
"fidelity_plus",
|
||||
"fidelity_minus",
|
||||
"fidelity_plus_prob",
|
||||
"fidelity_minus_prob",
|
||||
"infidelity_KL",
|
||||
]
|
||||
all_accuracy = [
|
||||
"precision_score",
|
||||
"jaccard_score",
|
||||
"roc_auc_score",
|
||||
"f1_score",
|
||||
"accuracy_score",
|
||||
]
|
||||
|
||||
all_robust = [
|
||||
"gaussian_noise",
|
||||
"add_edge",
|
||||
"remove_edge",
|
||||
"remove_node",
|
||||
"pgd",
|
||||
"fgsm",
|
||||
]
|
||||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg_path: str):
|
||||
@ -95,7 +121,7 @@ class ExplainingOutline(object):
|
||||
|
||||
def load_explainer_cfg(self):
|
||||
if self.explaining_cfg is None:
|
||||
self.explaining_cfg()
|
||||
self.load_explaining_cfg()
|
||||
else:
|
||||
if self.explaining_cfg.explainer.cfg == "default":
|
||||
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||
@ -127,7 +153,7 @@ class ExplainingOutline(object):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.explaining_cfg()
|
||||
self.load_explaining_cfg()
|
||||
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
|
||||
raise ValueError(
|
||||
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||
@ -167,7 +193,33 @@ class ExplainingOutline(object):
|
||||
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||
)
|
||||
self.explaining_algorithm = explaining_algorithm
|
||||
print(self.explaining_algorithm.__dict__)
|
||||
|
||||
def load_metric(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
|
||||
if self.explaining_cfg.metrics.type == "all":
|
||||
if self.explaining_cfg.dataset.name == 'BASHAPES':
|
||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||
|
||||
def load_attack(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||
|
||||
|
||||
class FileManager(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def save(obj: Any, path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
PATH = "config_exp.yaml"
|
@ -3,14 +3,63 @@ import copy
|
||||
from torch import FloatTensor
|
||||
from torch.nn import ReLU
|
||||
|
||||
class Adjust(object):
|
||||
def __init__(
|
||||
self,
|
||||
apply_relu: bool = True,
|
||||
apply_normalize: bool = True,
|
||||
apply_project: bool = True,
|
||||
apply_absolute: bool = False,
|
||||
):
|
||||
self.apply_relu = apply_relu
|
||||
self.apply_normalize = apply_normalize
|
||||
self.apply_project = apply_project
|
||||
self.apply_absolute = apply_absolute
|
||||
|
||||
def relu_mask(explanation: Explanation) -> Explanation:
|
||||
relu = ReLU()
|
||||
explanation_store = explanation._store
|
||||
raw_data = copy.copy(explanation._store)
|
||||
for k, v in explanation_store.items():
|
||||
if "mask" in k:
|
||||
explanation_store[k] = relu(v)
|
||||
explanation.__setattr__("raw_explanation", raw_data)
|
||||
explanation.__setattr__("raw_explanation_transform", "relu")
|
||||
return explanation
|
||||
if self.apply_absolute and self.apply_relu:
|
||||
self.apply_relu = False
|
||||
|
||||
def forward(self, exp: Explanation) -> Explanation:
|
||||
exp_ = exp.copy()
|
||||
_store = exp_.to_dict()
|
||||
for k, v in _store.items():
|
||||
if "mask" in k:
|
||||
if self.apply_relu:
|
||||
_store[k] = self.relu(v)
|
||||
elif self.apply_absolute:
|
||||
_store[k] = self.absolute(v)
|
||||
elif self.apply_project:
|
||||
if "edge" in k:
|
||||
pass
|
||||
else:
|
||||
_store[k] = self.project(v)
|
||||
elif self.apply_normalize:
|
||||
_store[k] = self.normalize(v)
|
||||
else:
|
||||
continue
|
||||
|
||||
return exp_
|
||||
|
||||
def relu(self, mask: FloatTensor) -> FloatTensor:
|
||||
relu = ReLU()
|
||||
mask_ = relu(mask)
|
||||
return mask_
|
||||
|
||||
def normalize(self, mask: FloatTensor) -> FloatTensor:
|
||||
norm = torch.norm(mask, p="inf")
|
||||
if norm.item() > 0:
|
||||
mask_ = mask / norm.item()
|
||||
return mask_
|
||||
else:
|
||||
return mask
|
||||
|
||||
def project(self, mask: FloatTensor) -> FloatTensor:
|
||||
if mask.ndim >= 2:
|
||||
mask_ = torch.sum(mask, dim=1)
|
||||
return mask_
|
||||
else:
|
||||
return mask
|
||||
|
||||
def absolute(self, mask: FloatTensor) -> FloatTensor:
|
||||
mask_ = torch.abs(mask)
|
||||
return mask_
|
||||
|
@ -12,8 +12,9 @@ def explanation_verification(exp: Explanation) -> bool:
|
||||
for mask in masks:
|
||||
is_nan = mask.isnan().any().item()
|
||||
is_inf = mask.isinf().any().item()
|
||||
is_const = mask.max()==mask.min()
|
||||
is_ok = exp.validate()
|
||||
if is_nan or is_inf or not is_ok:
|
||||
if is_nan or is_inf or not is_ok or is_const:
|
||||
is_good = False
|
||||
return is_good
|
||||
else:
|
||||
@ -47,5 +48,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
|
||||
data = exp.to_dict()
|
||||
for k, v in data.items():
|
||||
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
|
||||
norm =torch.norm(input=data[k], p=p, dim=None).item()
|
||||
if norm.item()>0:
|
||||
data[k] = data[k] / norm
|
||||
|
||||
return exp
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
@ -23,3 +23,7 @@ def read_yaml(path: str) -> dict:
|
||||
def write_yaml(data: dict, path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
data = yaml.dump(data, f)
|
||||
|
||||
def is_exists(path:str)-> bool:
|
||||
return os.path.exists(path)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user