Fixing many bugs

araison 2022-12-17 21:08:37 +01:00
parent 03b185c142
commit 55570a26e5
6 changed files with 206 additions and 112 deletions

View file

@@ -2,10 +2,10 @@ import traceback
 import torch
 import torch.nn as nn
-from explaining_framework.explaining_framework.metric.accuracy import Accuracy
-from explaining_framework.explaining_framework.metric.fidelity import Fidelity
-from explaining_framework.explaining_framework.metric.robust import Attack
-from explaining_framework.explaining_framework.metric.sparsity import Sparsity
+from explaining_framework.metric.accuracy import Accuracy
+from explaining_framework.metric.fidelity import Fidelity
+from explaining_framework.metric.robust import Attack
+from explaining_framework.metric.sparsity import Sparsity
 from torch_geometric.data import Batch, Data
 from torch_geometric.explain import Explainer
 from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
@@ -113,26 +113,35 @@ for kind in ["graph"]:
                 index=int(target),
                 target=batch.y,
             )
+            print(explanation.__dict__)
             # explanation.__setattr__(
             #     "model_prediction", explainer.get_prediction(x, edge_index)
             # )
-            explanation_threshold = explanation._apply_mask(
-                node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
+            explanation_threshold = explanation._apply_masks(
+                node_mask=torch.ones_like(explanation.node_mask).bool()
             )
+            print(explanation_threshold.__dict__)
             for f_name in [
-                "precision_score",
-                "precision_score",
-                "jaccard_score",
-                "roc_auc_score",
-                "f1_score",
-                "accuracy_score",
+                "gaussian_noise",
+                "add_edge",
+                "remove_edge",
+                "remove_node",
+                "pgd",
+                "fgsm",
             ]:
-                acc = Accuracy(f_name)
-                gt = torch.ones_like(x) / 2
-                out = acc.forward(mask=explanation_threshold.node_mask, target=gt)
+                print(f_name)
+                acc = Attack(name=f_name, model=model, loss=loss)
+                # gt = torch.ones_like(explanation_threshold.node_mask) / 2
+                # mask = explanation_threshold.node_mask.bool()
+                # target = (1 - gt).bool()
+                # target[1] = False
+                # print(mask, target)
+                out = acc.forward(explanation)
                 print(out)
         except Exception as e:
+            traceback.print_exc()
             # print(str(e))
             pass
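Note: the rewritten loop swaps the sklearn-backed Accuracy metrics for the perturbation-based Attack metrics. A minimal sketch of the new call pattern, assuming a trained `model`, a `loss` such as `torch.nn.CrossEntropyLoss()`, and an `explanation` produced by the Explainer above (all stand-ins for the script's own objects):

# Hypothetical standalone sketch of the attack loop introduced above;
# `model`, `loss`, and `explanation` are assumed to exist as in the script.
from explaining_framework.metric.robust import Attack

for f_name in ["gaussian_noise", "add_edge", "remove_edge", "remove_node"]:
    attack = Attack(name=f_name, model=model, loss=loss)
    perturbed = attack.forward(explanation)  # returns a perturbed Explanation
    print(f_name, perturbed)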

View file

@@ -1,14 +1,12 @@
 import sklearn.metrics
 import torch
-from base import Metric
+from explaining_framework.metric.base import Metric


 class Accuracy(Metric):
-    def __init__(name: str):
+    def __init__(self, name: str):
         super().__init__(name=name, model=None)
         self.authorized_metric = [
-            "precision_score",
             "precision_score",
             "jaccard_score",
             "roc_auc_score",
@@ -16,14 +14,25 @@ class Accuracy(Metric):
             "accuracy_score",
         ]
-        self.metric = self.load_metric(name)
+        self.load_metric(name)

-    def load_metric(name):
+    def load_metric(self, name):
         if name in self.authorized_metric:
-            self.metric = eval("sklearn.metric.{name}")
+            if name == "precision_score":
+                self.metric = sklearn.metrics.precision_score
+            if name == "jaccard_score":
+                self.metric = sklearn.metrics.jaccard_score
+            if name == "roc_auc_score":
+                self.metric = sklearn.metrics.roc_auc_score
+            if name == "f1_score":
+                self.metric = sklearn.metrics.f1_score
+            if name == "accuracy_score":
+                self.metric = sklearn.metrics.accuracy_score
         else:
             raise ValueError(f"{name} is not supported")

-    def forward(self, mask, target: Tensor) -> float:
-        if mask.type() == torch.bool and target.type() == torch.bool:
-            return self.metric(y_pred=mask, y_true=target)
+    def forward(self, mask: torch.Tensor, target: torch.Tensor) -> float:
+        if self.name == "roc_auc_score":
+            return self.metric(y_score=mask, y_true=target)
+        else:
+            return self.metric(y_pred=mask, y_true=target)
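The branch on `self.name` exists because sklearn's `roc_auc_score` consumes continuous scores (`y_score`) while the other four metrics consume hard labels (`y_pred`). A small self-contained illustration of that difference, with toy tensors rather than repository data:

# roc_auc_score needs soft scores; f1_score & co. need 0/1 predictions.
import sklearn.metrics
import torch

y_true = torch.tensor([0, 1, 1, 0])
scores = torch.tensor([0.1, 0.8, 0.7, 0.3])  # soft node-mask values
labels = (scores > 0.5).long()               # thresholded mask

print(sklearn.metrics.roc_auc_score(y_true, scores))  # takes y_score
print(sklearn.metrics.f1_score(y_true, labels))       # takes y_pred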

View file

@@ -1,11 +1,14 @@
-from abc import ABC
+from abc import ABC, abstractmethod
+
+import torch
+from torch_geometric.explain.explanation import Explanation


 class Metric(ABC):
     def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
         self.name = name
         self.model = model
-        if is_model_needed and model is None:
+        if self.is_model_needed() and model is None:
             raise ValueError(f"{self.name} needs model to perform measurements")
         self.authorized_metric = None
@@ -20,7 +23,7 @@ class Metric(ABC):
         pass

     @abstractmethod
-    def __call__(self, exp: Explanation, **kwargs) -> float:
+    def forward(exp: Explanation):
         pass

     def get_prediction(self, *args, **kwargs) -> torch.Tensor:
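For reference, a minimal hypothetical subclass showing the contract `Metric` now imposes. Note the abstract `forward` as committed omits `self`; this sketch assumes the conventional `def forward(self, exp)` signature, and that `is_model_needed()` is defined elsewhere in the class:

# Hypothetical metric that ignores the explanation entirely, just to show
# the subclassing contract: pass a name (and optionally a model) upward,
# then implement forward.
class ConstantMetric(Metric):
    def __init__(self):
        super().__init__(name="constant", model=None)

    def forward(self, exp: Explanation) -> float:
        return 1.0  # a real metric would score exp here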

View file

@@ -1,8 +1,8 @@
 import torch
 import torch.nn.functional as F
+from explaining_framework.metric.base import Metric
 from torch.nn import KLDivLoss, Softmax
 from torch_geometric.explain.explanation import Explanation
-from base import Metric

 print("NUM CLASSES cfg dataset")
 NUM_CLASS = 5
@@ -10,12 +10,15 @@ NUM_CLASS = 5
 def softmax(data):
     return Softmax(dim=1)(data)

-def kl(data1,data2):
-    return KLDivLoss(dim=1)(data1,data2)
+
+def kl(data1, data2):
+    kld = KLDivLoss(reduction="batchmean")
+    return kld(data1, data2)


 class Fidelity(Metric):
-    def __init__(name: str, model: torch.nn.Module, mask_type: str):
+    def __init__(self, name: str, model: torch.nn.Module):
         super().__init__(name=name, model=model)
         self.authorized_metric = [
             "fidelity_plus",
@@ -52,8 +55,8 @@ class Fidelity(Metric):
         inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
         inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
         return torch.mean(
-            (exp.y == inferred_class_initial).long()
-            - (exp.y == inferred_class_exp).long()
+            (exp.y == inferred_class_initial).float()
+            - (exp.y == inferred_class_exp).float()
         ).item()

     def _fidelity_minus(self, exp: Explanation) -> float:
@@ -61,8 +64,8 @@ class Fidelity(Metric):
         inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
         inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
         return torch.mean(
-            (exp.y == inferred_class_initial).long()
-            - (exp.y == inferred_class_exp).long()
+            (exp.y == inferred_class_initial).float()
+            - (exp.y == inferred_class_exp).float()
         ).item()

     def _fidelity_plus_prob(self, exp: Explanation) -> float:
@@ -92,22 +95,25 @@ class Fidelity(Metric):
             torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
         ).item()

-    def _infidelity_KL(self, exp:Explanation) -> float:
+    def _infidelity_KL(self, exp: Explanation) -> float:
         self._score_check()
         prob_initial = softmax(self.s_initial_data)
-        prob_exp = softmax(self.s_exp_sub)
-        return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
+        prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
+        print(prob_initial, prob_exp)
+        return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()

     def score(self, exp):
         self.exp_sub = exp.get_explanation_subgraph()
         self.exp_sub_c = exp.get_complement_subgraph()
-        self.s_exp_sub = self.get_prediction(self.exp_sub)
-        self.s_exp_sub_c = self.get_prediction(self.exp_sub_c)
-        self.s_initial_data = self.get_prediction(exp)
+        self.s_exp_sub = self.get_prediction(
+            x=self.exp_sub.x, edge_index=self.exp_sub.edge_index
+        )
+        self.s_exp_sub_c = self.get_prediction(
+            x=self.exp_sub_c.x, edge_index=self.exp_sub_c.edge_index
+        )
+        self.s_initial_data = self.get_prediction(x=exp.x, edge_index=exp.edge_index)

-    def load_metric(name):
+    def load_metric(self, name):
         if name in self.authorized_metric:
             if name == "fidelity_plus":
                 self.metric = lambda exp: self._fidelity_plus(exp)
@@ -117,12 +123,12 @@ class Fidelity(Metric):
             if name == "fidelity_plus_prob":
                 self.metric = lambda exp: self._fidelity_plus_prob(exp)
             if name == "fidelity_minus_prob":
                 self.metric = lambda exp: self._fidelity_minus_prob(exp)
-            if name == "_infidelity_KL":
+            if name == "infidelity_KL":
                 self.metric = lambda exp: self._infidelity_KL(exp)
         else:
             raise ValueError(f"{name} is not supported")
         return self.metric

-    def __call__(self, exp: Explanation):
+    def forward(self, exp: Explanation):
         self.score(exp)
         return self.metric(exp)
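Two details worth noting here: `.long()` had to become `.float()` because `torch.mean` rejects integer tensors, and fidelity+ is the mean gap between the model being right on the full input and being right on the complement subgraph. A hand-computed example with toy tensors (not repository data):

# Fidelity+ on three nodes: correct on the full graph minus correct on the
# complement subgraph, averaged.
import torch

y = torch.tensor([1, 0, 2])                # ground-truth labels
pred_full = torch.tensor([1, 0, 1])        # argmax on the original graph
pred_complement = torch.tensor([0, 0, 1])  # argmax on the complement subgraph

fid_plus = torch.mean(
    (y == pred_full).float() - (y == pred_complement).float()
).item()
print(fid_plus)  # ((1 - 0) + (1 - 1) + (0 - 0)) / 3 = 0.333...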

View file

@@ -1,62 +1,108 @@
 import copy

+import torch
+import torch.nn.functional as F
+from explaining_framework.metric.base import Metric
+from torch_geometric.explain.explanation import Explanation
+from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
+

-def compute_gradient(model,input,target, loss):
+def compute_gradient(model, inp, target, loss):
     with torch.autograd.set_grad_enabled(True):
-        out = model(input)
-        err = loss(out,target)
-        return torch.autograd.grad(err,input)
+        inp.x.requires_grad = True
+        out = model(x=inp.x, edge_index=inp.edge_index)
+        err = loss(out, target)
+        return torch.autograd.grad(err, inp.x)[0]


 class FGSM(object):
-    def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+    ):
         self.model = model
         self.loss = loss
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
-        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
+        self.bound = lambda x: torch.clamp(
+            x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
+        )
         self.zero_thresh = 10**-6

-    def forward(self, input, target, epsilon:float) -> Explanation:
-        grad = compute_gradient(model=self.model,input=input, target=target, loss=self.loss)
+    def forward(self, input, target, epsilon: float) -> Explanation:
+        input_ = copy.copy(input)
+        grad = compute_gradient(
+            model=self.model, inp=input_, target=target, loss=self.loss
+        )
         grad = self.bound(grad)
-        out = torch.where(torch.abs(grad) > self.zero_thresh,input - epsilon * torch.sign(grad),input)
-        return out
+        input_.x = torch.where(
+            torch.abs(grad) > self.zero_thresh,
+            input_.x - epsilon * torch.sign(grad),
+            input_.x,
+        )
+        return input_


 class PGD(object):
-    def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+    ):
         self.model = model
         self.loss = loss
         self.lower_bound = lower_bound
         self.upper_bound = upper_bound
-        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
+        self.bound = lambda x: torch.clamp(
+            x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
+        )
         self.zero_thresh = 10**-6
-        self.fgsm = FGSM(model=model,loss=loss,lower_bound=lower_bound,upper_bound=upper_bound)
+        self.fgsm = FGSM(
+            model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound
+        )

-    def forward(self, input, target, epsilon:float, radius:float, step_num:int, random_start:bool = False, norm:str='inf') -> Explanation:
-        diff = outputs - inputs
-        if norm == "inf":
-            return inputs + torch.clamp(diff, -radius, radius)
-        elif norm == "2":
-            return inputs + torch.renorm(diff, 2, 0, radius)
-        else:
-            raise AssertionError("Norm constraint must be 2 or inf.")
+    def forward(
+        self,
+        input,
+        target,
+        epsilon: float,
+        radius: float,
+        step_num: int,
+        random_start: bool = False,
+        norm: str = "inf",
+    ) -> Explanation:
+        def _clip(inputs: Explanation, outputs: Explanation) -> Explanation:
+            diff = outputs.x - inputs.x
+            if norm == "inf":
+                inputs.x = inputs.x + torch.clamp(diff, -radius, radius)
+                return inputs
+            elif norm == "2":
+                inputs.x = inputs.x + torch.renorm(diff, 2, 0, radius)
+                return inputs
+            else:
+                raise AssertionError("Norm constraint must be L2 or Linf.")

-        perturbed_inputs = input
+        perturbed_input = input
         if random_start:
-            perturbed_inputs= self.bound(self._random_point(input, radius, norm))
+            perturbed_input = self.bound(self._random_point(input.x, radius, norm))
         for _ in range(step_num):
-            perturbed_inputs = self.fgsm.perturb(
-                input=perturbed_inputs, epsilon=epsilon, target=target
+            perturbed_input = self.fgsm.forward(
+                input=perturbed_input, epsilon=epsilon, target=target
             )
-            perturbed_inputs = self.forward(input, perturbed_inputs)
-            perturbed_inputs = self.bound(perturbed_inputs).detach()
-        return perturbed_inputs
+            perturbed_input = _clip(input, perturbed_input)
+            perturbed_input.x = self.bound(perturbed_input.x).detach()
+        return perturbed_input

-    def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
+    def _random_point(
+        self, center: torch.Tensor, radius: float, norm: str
+    ) -> torch.Tensor:
         r"""
         A helper function that returns a uniform random point within the ball
         with the given center and radius. Norm should be either L2 or Linf.
@@ -76,9 +122,14 @@ class PGD(object):
             raise AssertionError("Norm constraint must be L2 or Linf.")


 class Attack(Metric):
-    def __init__(name: str, model: torch.nn.Module, dropout:float = 0.5):
+    def __init__(
+        self,
+        name: str,
+        model: torch.nn.Module,
+        dropout: float = 0.5,
+        loss: torch.nn = None,
+    ):
         super().__init__(name=name, model=model)
         self.name = name
         self.model = model
@@ -86,70 +137,85 @@ class Attack(Metric):
             "gaussian_noise",
             "add_edge",
             "remove_edge",
-            "remove_node"
+            "remove_node",
             "pgd",
             "fgsm",
         ]
         self.dropout = dropout
-        self._load_metric(name)
+        self.loss = loss
+        self.load_metric(name)

-    def _gaussian_noise(self,exp) -> Explanation:
-        x= torch.clone(exp.x)
-        x=x+torch.randn(*x.shape)
+    def _gaussian_noise(self, exp) -> Explanation:
+        x = torch.clone(exp.x)
+        x = x + torch.randn(*x.shape)
         exp_ = copy.copy(exp)
         exp_.x = x
         return exp_

-    def _add_edge(self,exp,p:float) -> Explanation:
+    def _add_edge(self, exp, p: float) -> Explanation:
         exp_ = copy.copy(exp)
-        exp_.edge_index, _ = add_random_edge(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
+        exp_.edge_index, _ = add_random_edge(
+            exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
+        )
         return exp_

-    def _remove_edge(self,exp,p:float) -> Explanation:
+    def _remove_edge(self, exp, p: float) -> Explanation:
         exp_ = copy.copy(exp)
-        exp_.edge_index, _ = dropout_edge(exp_.edge_index,p=p)
+        exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
         return exp_

-    def _remove_node(self,exp,p:float) -> Explanation:
+    def _remove_node(self, exp, p: float) -> Explanation:
         exp_ = copy.copy(exp)
-        exp_.edge_index, _ = dropout_node(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
+        exp_.edge_index, _, _ = dropout_node(
+            exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
+        )
         return exp_

     def _load_add_edge(self):
-        return lambda exp : self._add_edge(exp,p=self.dropout)
+        return lambda exp: self._add_edge(exp, p=self.dropout)

     def _load_remove_edge(self):
-        return lambda exp : self._remove_edge(exp,p=self.dropout)
+        return lambda exp: self._remove_edge(exp, p=self.dropout)

     def _load_remove_node(self):
-        return lambda exp : self._remove_node(exp,p=self.dropout)
+        return lambda exp: self._remove_node(exp, p=self.dropout)

     def _load_gaussian_noise(self):
         return lambda exp: self._gaussian_noise(exp)

-    def _load_metric(self):
+    def load_metric(self, name):
         if name in self.authorized_metric:
             if name == "gaussian_noise":
-                self.metric= self._load_gaussian_noise()
+                self.metric = self._load_gaussian_noise()
             if name == "add_edge":
-                self.metric=self._load_add_edge()
+                self.metric = self._load_add_edge()
             if name == "remove_edge":
-                self.metric= self._load_remove_edge()
+                self.metric = self._load_remove_edge()
             if name == "remove_node":
-                self.metric= self._load_remove_node()
-            if name== "pgd":
-                print('set LOSS with cfg ')
-                pgd = PGD(model=self.model,loss=LOSS)
-                self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1,radius=1, step_num = 50, random_start=False, norm = 'inf')
-            if name== "fgsm":
-                print('set LOSS with cfg ')
-                pgd = FGSM(model=self.model,loss=LOSS)
-                self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1)
+                self.metric = self._load_remove_node()
+            if name == "pgd":
+                print("set LOSS with cfg ")
+                pgd = PGD(model=self.model, loss=self.loss)
+                self.metric = lambda exp: pgd.forward(
+                    input=exp,
+                    target=exp.y,
+                    epsilon=1,
+                    radius=1,
+                    step_num=50,
+                    random_start=False,
+                    norm="inf",
+                )
+            if name == "fgsm":
+                print("set LOSS with cfg ")
+                fgsm = FGSM(model=self.model, loss=self.loss)
+                self.metric = lambda exp: fgsm.forward(
+                    input=exp, target=exp.y, epsilon=1
+                )
         else:
-            raise ValueError(f'{name} is not supported yet')
+            raise ValueError(f"{name} is not supported yet")
         return self.metric

-    def forward(self,exp) -> Explanation:
+    def forward(self, exp) -> Explanation:
         attack = self.metric(exp)
         return attack
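The FGSM update applied to node features above is x' = x - epsilon * sign(grad_x L) wherever the gradient is non-negligible. A self-contained sketch of that single step, with a toy linear model standing in for the GNN (all names here are illustrative):

# One FGSM step on dense features; mirrors the zero_thresh / sign logic
# of FGSM.forward without the Explanation wrapper.
import torch

model = torch.nn.Linear(4, 3)
loss_fn = torch.nn.CrossEntropyLoss()

x = torch.randn(5, 4, requires_grad=True)  # 5 nodes, 4 features
y = torch.tensor([0, 1, 2, 1, 0])

err = loss_fn(model(x), y)
grad = torch.autograd.grad(err, x)[0]

epsilon = 0.1
x_adv = torch.where(
    grad.abs() > 1e-6,          # the zero_thresh guard from FGSM.forward
    x - epsilon * grad.sign(),  # step along the negative gradient sign
    x,
)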

View file

@@ -1,9 +1,10 @@
 import torch
+from explaining_framework.metric.base import Metric


 class Sparsity(Metric):
-    def __init__(self,name):
-        super().__init__(name=name,model=None)
+    def __init__(self, name):
+        super().__init__(name=name, model=None)

     def forward(self, mask):
         return torch.mean(mask.float()).item()
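As defined, sparsity is simply the mean of the mask: for a boolean mask, the fraction of elements the explanation retains. A toy check with illustrative values:

# Half the nodes are kept, so the metric returns 0.5.
import torch

mask = torch.tensor([True, False, False, True])
print(torch.mean(mask.float()).item())  # 0.5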