Fixing many bugs
This commit is contained in:
parent
03b185c142
commit
55570a26e5
|
@ -2,10 +2,10 @@ import traceback
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from explaining_framework.explaining_framework.metric.accuracy import Accuracy
|
||||
from explaining_framework.explaining_framework.metric.fidelity import Fidelity
|
||||
from explaining_framework.explaining_framework.metric.robust import Attack
|
||||
from explaining_framework.explaining_framework.metric.sparsity import Sparsity
|
||||
from explaining_framework.metric.accuracy import Accuracy
|
||||
from explaining_framework.metric.fidelity import Fidelity
|
||||
from explaining_framework.metric.robust import Attack
|
||||
from explaining_framework.metric.sparsity import Sparsity
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
||||
|
@ -113,26 +113,35 @@ for kind in ["graph"]:
|
|||
index=int(target),
|
||||
target=batch.y,
|
||||
)
|
||||
print(explanation.__dict__)
|
||||
# explanation.__setattr__(
|
||||
# "model_prediction", explainer.get_prediction(x, edge_index)
|
||||
# )
|
||||
explanation_threshold = explanation._apply_mask(
|
||||
node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
|
||||
explanation_threshold = explanation._apply_masks(
|
||||
node_mask=torch.ones_like(explanation.node_mask).bool()
|
||||
)
|
||||
|
||||
print(explanation_threshold.__dict__)
|
||||
|
||||
for f_name in [
|
||||
"precision_score",
|
||||
"precision_score",
|
||||
"jaccard_score",
|
||||
"roc_auc_score",
|
||||
"f1_score",
|
||||
"accuracy_score",
|
||||
"gaussian_noise",
|
||||
"add_edge",
|
||||
"remove_edge",
|
||||
"remove_node",
|
||||
"pgd",
|
||||
"fgsm",
|
||||
]:
|
||||
acc = Accuracy(f_name)
|
||||
gt = torch.ones_like(x) / 2
|
||||
out = acc.forward(mask=explanation_threshold.node_mask, target=gt)
|
||||
print(f_name)
|
||||
acc = Attack(name=f_name, model=model, loss=loss)
|
||||
# gt = torch.ones_like(explanation_threshold.node_mask) / 2
|
||||
# mask = explanation_threshold.node_mask.bool()
|
||||
# target = (1 - gt).bool()
|
||||
# target[1] = False
|
||||
# print(mask, target)
|
||||
out = acc.forward(explanation)
|
||||
print(out)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
# print(str(e))
|
||||
pass
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
import sklearn.metrics
|
||||
import torch
|
||||
|
||||
from base import Metric
|
||||
from explaining_framework.metric.base import Metric
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
def __init__(name: str):
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name=name, model=None)
|
||||
self.authorized_metric = [
|
||||
"precision_score",
|
||||
"precision_score",
|
||||
"jaccard_score",
|
||||
"roc_auc_score",
|
||||
|
@ -16,14 +14,25 @@ class Accuracy(Metric):
|
|||
"accuracy_score",
|
||||
]
|
||||
|
||||
self.metric = self.load_metric(name)
|
||||
self.load_metric(name)
|
||||
|
||||
def load_metric(name):
|
||||
def load_metric(self, name):
|
||||
if name in self.authorized_metric:
|
||||
self.metric = eval("sklearn.metric.{name}")
|
||||
if name == "precision_score":
|
||||
self.metric = sklearn.metrics.precision_score
|
||||
if name == "jaccard_score":
|
||||
self.metric = sklearn.metrics.jaccard_score
|
||||
if name == "roc_auc_score":
|
||||
self.metric = sklearn.metrics.roc_auc_score
|
||||
if name == "f1_score":
|
||||
self.metric = sklearn.metrics.f1_score
|
||||
if name == "accuracy_score":
|
||||
self.metric = sklearn.metrics.accuracy_score
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
|
||||
def forward(self, mask, target: Tensor) -> float:
|
||||
if mask.type() == torch.bool and target.type() == torch.bool:
|
||||
def forward(self, mask: torch.Tensor, target: torch.Tensor) -> float:
|
||||
if self.name == "roc_auc_score":
|
||||
return self.metric(y_score=mask, y_true=target)
|
||||
else:
|
||||
return self.metric(y_pred=mask, y_true=target)
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
|
||||
self.name = name
|
||||
self.model = model
|
||||
if is_model_needed and model is None:
|
||||
if self.is_model_needed() and model is None:
|
||||
raise ValueError(f"{self.name} needs model to perform measurements")
|
||||
self.authorized_metric = None
|
||||
|
||||
|
@ -20,7 +23,7 @@ class Metric(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, exp: Explanation, **kwargs) -> float:
|
||||
def forward(exp: Explanation):
|
||||
pass
|
||||
|
||||
def get_prediction(self, *args, **kwargs) -> torch.Tensor:
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from explaining_framework.metric.base import Metric
|
||||
from torch.nn import KLDivLoss, Softmax
|
||||
|
||||
from base import Metric
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
print("NUM CLASSES cfg dataset")
|
||||
NUM_CLASS = 5
|
||||
|
@ -10,12 +10,15 @@ NUM_CLASS = 5
|
|||
|
||||
def softmax(data):
|
||||
return Softmax(dim=1)(data)
|
||||
def kl(data1,data2):
|
||||
return KLDivLoss(dim=1)(data1,data2)
|
||||
|
||||
|
||||
def kl(data1, data2):
|
||||
kld = KLDivLoss(reduction="batchmean")
|
||||
return kld(data1, data2)
|
||||
|
||||
|
||||
class Fidelity(Metric):
|
||||
def __init__(name: str, model: torch.nn.Module, mask_type: str):
|
||||
def __init__(self, name: str, model: torch.nn.Module):
|
||||
super().__init__(name=name, model=model)
|
||||
self.authorized_metric = [
|
||||
"fidelity_plus",
|
||||
|
@ -52,8 +55,8 @@ class Fidelity(Metric):
|
|||
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||
inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
|
||||
return torch.mean(
|
||||
(exp.y == inferred_class_initial).long()
|
||||
- (exp.y == inferred_class_exp).long()
|
||||
(exp.y == inferred_class_initial).float()
|
||||
- (exp.y == inferred_class_exp).float()
|
||||
).item()
|
||||
|
||||
def _fidelity_minus(self, exp: Explanation) -> float:
|
||||
|
@ -61,8 +64,8 @@ class Fidelity(Metric):
|
|||
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||
inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
|
||||
return torch.mean(
|
||||
(exp.y == inferred_class_initial).long()
|
||||
- (exp.y == inferred_class_exp).long()
|
||||
(exp.y == inferred_class_initial).float()
|
||||
- (exp.y == inferred_class_exp).float()
|
||||
).item()
|
||||
|
||||
def _fidelity_plus_prob(self, exp: Explanation) -> float:
|
||||
|
@ -92,22 +95,25 @@ class Fidelity(Metric):
|
|||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||
).item()
|
||||
|
||||
def _infidelity_KL(self, exp:Explanation) -> float:
|
||||
self._score_check()
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = softmax(self.s_exp_sub)
|
||||
return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
|
||||
|
||||
|
||||
def _infidelity_KL(self, exp: Explanation) -> float:
|
||||
self._score_check()
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
|
||||
print(prob_initial, prob_exp)
|
||||
return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
self.exp_sub_c = exp.get_complement_subgraph()
|
||||
self.s_exp_sub = self.get_prediction(self.exp_sub)
|
||||
self.s_exp_sub_c = self.get_prediction(self.exp_sub_c)
|
||||
self.s_initial_data = self.get_prediction(exp)
|
||||
self.s_exp_sub = self.get_prediction(
|
||||
x=self.exp_sub.x, edge_index=self.exp_sub.edge_index
|
||||
)
|
||||
self.s_exp_sub_c = self.get_prediction(
|
||||
x=self.exp_sub_c.x, edge_index=self.exp_sub_c.edge_index
|
||||
)
|
||||
self.s_initial_data = self.get_prediction(x=exp.x, edge_index=exp.edge_index)
|
||||
|
||||
def load_metric(name):
|
||||
def load_metric(self, name):
|
||||
if name in self.authorized_metric:
|
||||
if name == "fidelity_plus":
|
||||
self.metric = lambda exp: self._fidelity_plus(exp)
|
||||
|
@ -117,12 +123,12 @@ class Fidelity(Metric):
|
|||
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||
if name == "fidelity_minus_prob":
|
||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||
if name == "_infidelity_KL":
|
||||
if name == "infidelity_KL":
|
||||
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
return self.metric
|
||||
|
||||
def __call__(self, exp: Explanation):
|
||||
def forward(self, exp: Explanation):
|
||||
self.score(exp)
|
||||
return self.metric(exp)
|
||||
|
|
|
@ -1,62 +1,108 @@
|
|||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from explaining_framework.metric.base import Metric
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
||||
|
||||
|
||||
def compute_gradient(model,input,target, loss):
|
||||
def compute_gradient(model, inp, target, loss):
|
||||
with torch.autograd.set_grad_enabled(True):
|
||||
out = model(input)
|
||||
err = loss(out,target)
|
||||
return torch.autograd.grad(err,input)
|
||||
|
||||
inp.x.requires_grad = True
|
||||
out = model(x=inp.x, edge_index=inp.edge_index)
|
||||
err = loss(out, target)
|
||||
return torch.autograd.grad(err, inp.x)[0]
|
||||
|
||||
|
||||
class FGSM(object):
|
||||
def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
loss: torch.nn.Module,
|
||||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
|
||||
|
||||
self.bound = lambda x: torch.clamp(
|
||||
x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
|
||||
)
|
||||
|
||||
self.zero_thresh = 10**-6
|
||||
|
||||
def forward(self, input, target, epsilon:float) -> Explanation:
|
||||
grad = compute_gradient(model=self.model,input=input, target=target, loss=self.loss)
|
||||
def forward(self, input, target, epsilon: float) -> Explanation:
|
||||
input_ = copy.copy(input)
|
||||
grad = compute_gradient(
|
||||
model=self.model, inp=input_, target=target, loss=self.loss
|
||||
)
|
||||
grad = self.bound(grad)
|
||||
out = torch.where(torch.abs(grad) > self.zero_thresh,input - epsilon * torch.sign(grad),input)
|
||||
return out
|
||||
input_.x = torch.where(
|
||||
torch.abs(grad) > self.zero_thresh,
|
||||
input_.x - epsilon * torch.sign(grad),
|
||||
input_.x,
|
||||
)
|
||||
return input_
|
||||
|
||||
|
||||
class PGD(object):
|
||||
def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
loss: torch.nn.Module,
|
||||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
|
||||
self.bound = lambda x: torch.clamp(
|
||||
x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
|
||||
)
|
||||
self.zero_thresh = 10**-6
|
||||
self.fgsm = FGSM(model=model,loss=loss,lower_bound=lower_bound,upper_bound=upper_bound)
|
||||
self.fgsm = FGSM(
|
||||
model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound
|
||||
)
|
||||
|
||||
def forward(self, input, target, epsilon:float, radius:float, step_num:int, random_start:bool = False, norm:str='inf') -> Explanation:
|
||||
diff = outputs - inputs
|
||||
if norm == "inf":
|
||||
return inputs + torch.clamp(diff, -radius, radius)
|
||||
elif norm == "2":
|
||||
return inputs + torch.renorm(diff, 2, 0, radius)
|
||||
else:
|
||||
raise AssertionError("Norm constraint must be 2 or inf.")
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
target,
|
||||
epsilon: float,
|
||||
radius: float,
|
||||
step_num: int,
|
||||
random_start: bool = False,
|
||||
norm: str = "inf",
|
||||
) -> Explanation:
|
||||
def _clip(inputs: Explanation, outputs: Explanation) -> Explanation:
|
||||
diff = outputs.x - inputs.x
|
||||
if norm == "inf":
|
||||
inputs.x = inputs.x + torch.clamp(diff, -radius, radius)
|
||||
return inputs
|
||||
elif norm == "2":
|
||||
inputs.x = inputs.x + torch.renorm(diff, 2, 0, radius)
|
||||
return inputs
|
||||
else:
|
||||
raise AssertionError("Norm constraint must be L2 or Linf.")
|
||||
|
||||
perturbed_inputs = input
|
||||
perturbed_input = input
|
||||
if random_start:
|
||||
perturbed_inputs= self.bound(self._random_point(input, radius, norm))
|
||||
perturbed_input = self.bound(self._random_point(input.x, radius, norm))
|
||||
for _ in range(step_num):
|
||||
perturbed_inputs = self.fgsm.perturb(
|
||||
input=perturbed_inputs, epsilon=epsilon, target=target
|
||||
perturbed_input = self.fgsm.forward(
|
||||
input=perturbed_input, epsilon=epsilon, target=target
|
||||
)
|
||||
perturbed_inputs = self.forward(input, perturbed_inputs)
|
||||
perturbed_inputs = self.bound(perturbed_inputs).detach()
|
||||
return perturbed_inputs
|
||||
perturbed_input = _clip(input, perturbed_input)
|
||||
perturbed_input.x = self.bound(perturbed_input.x).detach()
|
||||
return perturbed_input
|
||||
|
||||
def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
|
||||
def _random_point(
|
||||
self, center: torch.Tensor, radius: float, norm: str
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
A helper function that returns a uniform random point within the ball
|
||||
with the given center and radius. Norm should be either L2 or Linf.
|
||||
|
@ -76,9 +122,14 @@ class PGD(object):
|
|||
raise AssertionError("Norm constraint must be L2 or Linf.")
|
||||
|
||||
|
||||
|
||||
class Attack(Metric):
|
||||
def __init__(name: str, model: torch.nn.Module, dropout:float = 0.5):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model: torch.nn.Module,
|
||||
dropout: float = 0.5,
|
||||
loss: torch.nn = None,
|
||||
):
|
||||
super().__init__(name=name, model=model)
|
||||
self.name = name
|
||||
self.model = model
|
||||
|
@ -86,70 +137,85 @@ class Attack(Metric):
|
|||
"gaussian_noise",
|
||||
"add_edge",
|
||||
"remove_edge",
|
||||
"remove_node"
|
||||
"remove_node",
|
||||
"pgd",
|
||||
"fgsm",
|
||||
]
|
||||
self.dropout = dropout
|
||||
self._load_metric(name)
|
||||
self.loss = loss
|
||||
self.load_metric(name)
|
||||
|
||||
def _gaussian_noise(self,exp) -> Explanation:
|
||||
x= torch.clone(exp.x)
|
||||
x=x+torch.randn(*x.shape)
|
||||
def _gaussian_noise(self, exp) -> Explanation:
|
||||
x = torch.clone(exp.x)
|
||||
x = x + torch.randn(*x.shape)
|
||||
exp_ = copy.copy(exp)
|
||||
exp_.x = x
|
||||
return exp_
|
||||
|
||||
def _add_edge(self,exp,p:float) -> Explanation:
|
||||
def _add_edge(self, exp, p: float) -> Explanation:
|
||||
exp_ = copy.copy(exp)
|
||||
exp_.edge_index, _ = add_random_edge(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
|
||||
exp_.edge_index, _ = add_random_edge(
|
||||
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
||||
)
|
||||
return exp_
|
||||
|
||||
def _remove_edge(self,exp,p:float) -> Explanation:
|
||||
def _remove_edge(self, exp, p: float) -> Explanation:
|
||||
exp_ = copy.copy(exp)
|
||||
exp_.edge_index, _ = dropout_edge(exp_.edge_index,p=p)
|
||||
exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
|
||||
return exp_
|
||||
|
||||
def _remove_node(self,exp,p:float) -> Explanation:
|
||||
def _remove_node(self, exp, p: float) -> Explanation:
|
||||
exp_ = copy.copy(exp)
|
||||
exp_.edge_index, _ = dropout_node(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
|
||||
exp_.edge_index, _, _ = dropout_node(
|
||||
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
||||
)
|
||||
return exp_
|
||||
|
||||
def _load_add_edge(self):
|
||||
return lambda exp : self._add_edge(exp,p=self.dropout)
|
||||
return lambda exp: self._add_edge(exp, p=self.dropout)
|
||||
|
||||
def _load_remove_edge(self):
|
||||
return lambda exp : self._remove_edge(exp,p=self.dropout)
|
||||
return lambda exp: self._remove_edge(exp, p=self.dropout)
|
||||
|
||||
def _load_remove_node(self):
|
||||
return lambda exp : self._remove_node(exp,p=self.dropout)
|
||||
return lambda exp: self._remove_node(exp, p=self.dropout)
|
||||
|
||||
def _load_gaussian_noise(self):
|
||||
return lambda exp: self._gaussian_noise(exp)
|
||||
|
||||
def _load_metric(self):
|
||||
def load_metric(self, name):
|
||||
if name in self.authorized_metric:
|
||||
if name == "gaussian_noise":
|
||||
self.metric= self._load_gaussian_noise()
|
||||
self.metric = self._load_gaussian_noise()
|
||||
if name == "add_edge":
|
||||
self.metric=self._load_add_edge()
|
||||
self.metric = self._load_add_edge()
|
||||
if name == "remove_edge":
|
||||
self.metric= self._load_remove_edge()
|
||||
self.metric = self._load_remove_edge()
|
||||
if name == "remove_node":
|
||||
self.metric= self._load_remove_node()
|
||||
if name== "pgd":
|
||||
print('set LOSS with cfg ')
|
||||
pgd = PGD(model=self.model,loss=LOSS)
|
||||
self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1,radius=1, step_num = 50, random_start=False, norm = 'inf')
|
||||
if name== "fgsm":
|
||||
print('set LOSS with cfg ')
|
||||
pgd = FGSM(model=self.model,loss=LOSS)
|
||||
self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1)
|
||||
self.metric = self._load_remove_node()
|
||||
if name == "pgd":
|
||||
print("set LOSS with cfg ")
|
||||
pgd = PGD(model=self.model, loss=self.loss)
|
||||
self.metric = lambda exp: pgd.forward(
|
||||
input=exp,
|
||||
target=exp.y,
|
||||
epsilon=1,
|
||||
radius=1,
|
||||
step_num=50,
|
||||
random_start=False,
|
||||
norm="inf",
|
||||
)
|
||||
if name == "fgsm":
|
||||
print("set LOSS with cfg ")
|
||||
fgsm = FGSM(model=self.model, loss=self.loss)
|
||||
self.metric = lambda exp: fgsm.forward(
|
||||
input=exp, target=exp.y, epsilon=1
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'{name} is not supported yet')
|
||||
raise ValueError(f"{name} is not supported yet")
|
||||
|
||||
return self.metric
|
||||
|
||||
def forward(self,exp) -> Explanation:
|
||||
def forward(self, exp) -> Explanation:
|
||||
attack = self.metric(exp)
|
||||
return attack
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import torch
|
||||
from explaining_framework.metric.base import Metric
|
||||
|
||||
|
||||
class Sparsity(Metric):
|
||||
def __init__(self,name):
|
||||
super().__init__(name=name,model=None)
|
||||
def __init__(self, name):
|
||||
super().__init__(name=name, model=None)
|
||||
|
||||
def forward(self, mask):
|
||||
return torch.mean(mask.float()).item()
|
||||
|
|
Loading…
Reference in New Issue