Renaming file and adding new features
This commit is contained in:
parent
9d4aedbca7
commit
26fa51e2de
|
@ -187,22 +187,24 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
||||||
|
|
||||||
def _parse_attr(self, attr):
|
def _parse_attr(self, attr):
|
||||||
|
for i in range(len(attr)):
|
||||||
|
attr[i] = attr[i].squeeze()
|
||||||
if self.mask_type == "node":
|
if self.mask_type == "node":
|
||||||
node_feat_mask = attr[0].squeeze(0)
|
node_mask = attr[0]
|
||||||
edge_mask = None
|
edge_mask = None
|
||||||
|
|
||||||
if self.mask_type == "edge":
|
if self.mask_type == "edge":
|
||||||
node_feat_mask = None
|
node_mask = None
|
||||||
edge_mask = attr[0]
|
edge_mask = attr[0]
|
||||||
|
|
||||||
if self.mask_type == "node_and_edge":
|
if self.mask_type == "node_and_edge":
|
||||||
node_feat_mask = attr[0].squeeze(0)
|
node_mask = attr[0]
|
||||||
edge_mask = attr[1]
|
edge_mask = attr[1]
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
edge_feat_mask = None
|
edge_feat_mask = None
|
||||||
node_mask = None
|
node_feat_mask = None
|
||||||
|
|
||||||
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,12 @@ import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from explaining_framework.metric.base import Metric
|
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
||||||
|
from troch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
|
from explaining_framework.metric.base import Metric
|
||||||
|
|
||||||
|
|
||||||
def compute_gradient(model, inp, target, loss):
|
def compute_gradient(model, inp, target, loss):
|
||||||
|
@ -142,7 +145,15 @@ class Attack(Metric):
|
||||||
"fgsm",
|
"fgsm",
|
||||||
]
|
]
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.loss = loss
|
if loss is None:
|
||||||
|
if cfg.model.loss_fun == "cross-entropy":
|
||||||
|
self.loss = CrossEntropyLoss()
|
||||||
|
if cfg.model.loss_fun == "mse":
|
||||||
|
self.loss = MSELoss()
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
self.loss = loss
|
||||||
self.load_metric(name)
|
self.load_metric(name)
|
||||||
|
|
||||||
def _gaussian_noise(self, exp) -> Explanation:
|
def _gaussian_noise(self, exp) -> Explanation:
|
||||||
|
@ -194,7 +205,6 @@ class Attack(Metric):
|
||||||
if name == "remove_node":
|
if name == "remove_node":
|
||||||
self.metric = self._load_remove_node()
|
self.metric = self._load_remove_node()
|
||||||
if name == "pgd":
|
if name == "pgd":
|
||||||
print("set LOSS with cfg ")
|
|
||||||
pgd = PGD(model=self.model, loss=self.loss)
|
pgd = PGD(model=self.model, loss=self.loss)
|
||||||
self.metric = lambda exp: pgd.forward(
|
self.metric = lambda exp: pgd.forward(
|
||||||
input=exp,
|
input=exp,
|
||||||
|
@ -206,7 +216,6 @@ class Attack(Metric):
|
||||||
norm="inf",
|
norm="inf",
|
||||||
)
|
)
|
||||||
if name == "fgsm":
|
if name == "fgsm":
|
||||||
print("set LOSS with cfg ")
|
|
||||||
fgsm = FGSM(model=self.model, loss=self.loss)
|
fgsm = FGSM(model=self.model, loss=self.loss)
|
||||||
self.metric = lambda exp: fgsm.forward(
|
self.metric = lambda exp: fgsm.forward(
|
||||||
input=exp, target=exp.y, epsilon=1
|
input=exp, target=exp.y, epsilon=1
|
||||||
|
|
|
@ -4,7 +4,21 @@ from explaining_framework.metric.base import Metric
|
||||||
|
|
||||||
class Sparsity(Metric):
|
class Sparsity(Metric):
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
super().__init__(name=name, model=None)
|
super().__init__(name=name)
|
||||||
|
self.authorized_metric = ['l0']
|
||||||
|
self.metric = self.load_metric(name)
|
||||||
|
|
||||||
|
def load_metric(self,name):
|
||||||
|
if name in self.authorized_metric:
|
||||||
|
if name == 'l0':
|
||||||
|
metric = lambda x : torch.mean(mask.float()).item()
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{name} is not supported yet')
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, exp:Explanation) -> float:
|
||||||
|
out = {}
|
||||||
|
for k,v in exp.to_dict():
|
||||||
|
if 'mask' in
|
||||||
|
|
||||||
def forward(self, mask):
|
|
||||||
return torch.mean(mask.float()).item()
|
return torch.mean(mask.float()).item()
|
||||||
|
|
|
@ -1,6 +1,15 @@
|
||||||
import copy
|
import copy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -14,13 +23,6 @@ from explaining_framework.metric.robust import Attack
|
||||||
from explaining_framework.metric.sparsity import Sparsity
|
from explaining_framework.metric.sparsity import Sparsity
|
||||||
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||||
_load_ckpt)
|
_load_ckpt)
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -54,6 +56,30 @@ all__graphxai = [
|
||||||
|
|
||||||
all__own = ["EIXGNN", "SCGNN"]
|
all__own = ["EIXGNN", "SCGNN"]
|
||||||
|
|
||||||
|
all_fidelity = [
|
||||||
|
"fidelity_plus",
|
||||||
|
"fidelity_minus",
|
||||||
|
"fidelity_plus_prob",
|
||||||
|
"fidelity_minus_prob",
|
||||||
|
"infidelity_KL",
|
||||||
|
]
|
||||||
|
all_accuracy = [
|
||||||
|
"precision_score",
|
||||||
|
"jaccard_score",
|
||||||
|
"roc_auc_score",
|
||||||
|
"f1_score",
|
||||||
|
"accuracy_score",
|
||||||
|
]
|
||||||
|
|
||||||
|
all_robust = [
|
||||||
|
"gaussian_noise",
|
||||||
|
"add_edge",
|
||||||
|
"remove_edge",
|
||||||
|
"remove_node",
|
||||||
|
"pgd",
|
||||||
|
"fgsm",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ExplainingOutline(object):
|
class ExplainingOutline(object):
|
||||||
def __init__(self, explaining_cfg_path: str):
|
def __init__(self, explaining_cfg_path: str):
|
||||||
|
@ -95,7 +121,7 @@ class ExplainingOutline(object):
|
||||||
|
|
||||||
def load_explainer_cfg(self):
|
def load_explainer_cfg(self):
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
else:
|
else:
|
||||||
if self.explaining_cfg.explainer.cfg == "default":
|
if self.explaining_cfg.explainer.cfg == "default":
|
||||||
if self.explaining_cfg.explainer.name == "EIXGNN":
|
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||||
|
@ -127,7 +153,7 @@ class ExplainingOutline(object):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
|
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||||
|
@ -167,7 +193,33 @@ class ExplainingOutline(object):
|
||||||
score_map_norm=self.explainer_cfg.score_map_norm,
|
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||||
)
|
)
|
||||||
self.explaining_algorithm = explaining_algorithm
|
self.explaining_algorithm = explaining_algorithm
|
||||||
print(self.explaining_algorithm.__dict__)
|
|
||||||
|
def load_metric(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
|
||||||
|
if self.explaining_cfg.metrics.type == "all":
|
||||||
|
if self.explaining_cfg.dataset.name == 'BASHAPES':
|
||||||
|
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||||
|
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||||
|
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||||
|
|
||||||
|
def load_attack(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||||
|
|
||||||
|
|
||||||
|
class FileManager(object):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save(obj: Any, path: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
PATH = "config_exp.yaml"
|
PATH = "config_exp.yaml"
|
|
@ -3,14 +3,63 @@ import copy
|
||||||
from torch import FloatTensor
|
from torch import FloatTensor
|
||||||
from torch.nn import ReLU
|
from torch.nn import ReLU
|
||||||
|
|
||||||
|
class Adjust(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
apply_relu: bool = True,
|
||||||
|
apply_normalize: bool = True,
|
||||||
|
apply_project: bool = True,
|
||||||
|
apply_absolute: bool = False,
|
||||||
|
):
|
||||||
|
self.apply_relu = apply_relu
|
||||||
|
self.apply_normalize = apply_normalize
|
||||||
|
self.apply_project = apply_project
|
||||||
|
self.apply_absolute = apply_absolute
|
||||||
|
|
||||||
def relu_mask(explanation: Explanation) -> Explanation:
|
if self.apply_absolute and self.apply_relu:
|
||||||
relu = ReLU()
|
self.apply_relu = False
|
||||||
explanation_store = explanation._store
|
|
||||||
raw_data = copy.copy(explanation._store)
|
def forward(self, exp: Explanation) -> Explanation:
|
||||||
for k, v in explanation_store.items():
|
exp_ = exp.copy()
|
||||||
if "mask" in k:
|
_store = exp_.to_dict()
|
||||||
explanation_store[k] = relu(v)
|
for k, v in _store.items():
|
||||||
explanation.__setattr__("raw_explanation", raw_data)
|
if "mask" in k:
|
||||||
explanation.__setattr__("raw_explanation_transform", "relu")
|
if self.apply_relu:
|
||||||
return explanation
|
_store[k] = self.relu(v)
|
||||||
|
elif self.apply_absolute:
|
||||||
|
_store[k] = self.absolute(v)
|
||||||
|
elif self.apply_project:
|
||||||
|
if "edge" in k:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_store[k] = self.project(v)
|
||||||
|
elif self.apply_normalize:
|
||||||
|
_store[k] = self.normalize(v)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return exp_
|
||||||
|
|
||||||
|
def relu(self, mask: FloatTensor) -> FloatTensor:
|
||||||
|
relu = ReLU()
|
||||||
|
mask_ = relu(mask)
|
||||||
|
return mask_
|
||||||
|
|
||||||
|
def normalize(self, mask: FloatTensor) -> FloatTensor:
|
||||||
|
norm = torch.norm(mask, p="inf")
|
||||||
|
if norm.item() > 0:
|
||||||
|
mask_ = mask / norm.item()
|
||||||
|
return mask_
|
||||||
|
else:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def project(self, mask: FloatTensor) -> FloatTensor:
|
||||||
|
if mask.ndim >= 2:
|
||||||
|
mask_ = torch.sum(mask, dim=1)
|
||||||
|
return mask_
|
||||||
|
else:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def absolute(self, mask: FloatTensor) -> FloatTensor:
|
||||||
|
mask_ = torch.abs(mask)
|
||||||
|
return mask_
|
||||||
|
|
|
@ -12,8 +12,9 @@ def explanation_verification(exp: Explanation) -> bool:
|
||||||
for mask in masks:
|
for mask in masks:
|
||||||
is_nan = mask.isnan().any().item()
|
is_nan = mask.isnan().any().item()
|
||||||
is_inf = mask.isinf().any().item()
|
is_inf = mask.isinf().any().item()
|
||||||
|
is_const = mask.max()==mask.min()
|
||||||
is_ok = exp.validate()
|
is_ok = exp.validate()
|
||||||
if is_nan or is_inf or not is_ok:
|
if is_nan or is_inf or not is_ok or is_const:
|
||||||
is_good = False
|
is_good = False
|
||||||
return is_good
|
return is_good
|
||||||
else:
|
else:
|
||||||
|
@ -47,5 +48,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
|
||||||
data = exp.to_dict()
|
data = exp.to_dict()
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||||
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
|
norm =torch.norm(input=data[k], p=p, dim=None).item()
|
||||||
|
if norm.item()>0:
|
||||||
|
data[k] = data[k] / norm
|
||||||
|
|
||||||
return exp
|
return exp
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,3 +23,7 @@ def read_yaml(path: str) -> dict:
|
||||||
def write_yaml(data: dict, path: str) -> None:
|
def write_yaml(data: dict, path: str) -> None:
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
data = yaml.dump(data, f)
|
data = yaml.dump(data, f)
|
||||||
|
|
||||||
|
def is_exists(path:str)-> bool:
|
||||||
|
return os.path.exists(path)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue