Aborting
This commit is contained in:
commit
45d39edbd9
|
@ -2,6 +2,10 @@ import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from explaining_framework.explaining_framework.metric.accuracy import Accuracy
|
||||||
|
from explaining_framework.explaining_framework.metric.fidelity import Fidelity
|
||||||
|
from explaining_framework.explaining_framework.metric.robust import Attack
|
||||||
|
from explaining_framework.explaining_framework.metric.sparsity import Sparsity
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
from torch_geometric.explain import Explainer
|
from torch_geometric.explain import Explainer
|
||||||
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
||||||
|
@ -27,16 +31,16 @@ __all__captum = [
|
||||||
|
|
||||||
__all__graphxai = [
|
__all__graphxai = [
|
||||||
"CAM",
|
"CAM",
|
||||||
"GradCAM",
|
# "GradCAM",
|
||||||
"GNN_LRP",
|
# "GNN_LRP",
|
||||||
"GradExplainer",
|
# "GradExplainer",
|
||||||
"GuidedBackPropagation",
|
# "GuidedBackPropagation",
|
||||||
"IntegratedGradients",
|
# "IntegratedGradients",
|
||||||
"PGExplainer",
|
# "PGExplainer",
|
||||||
"PGMExplainer",
|
# "PGMExplainer",
|
||||||
"RandomExplainer",
|
# "RandomExplainer",
|
||||||
"SubgraphX",
|
# "SubgraphX",
|
||||||
"GraphMASK",
|
# "GraphMASK",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,9 +81,8 @@ for epoch in range(1, 2):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
target = torch.LongTensor([[0]])
|
target = torch.LongTensor([[0]])
|
||||||
|
|
||||||
for kind in ["node", "graph"]:
|
for kind in ["graph"]:
|
||||||
print(kind)
|
for name in __all__graphxai:
|
||||||
for name in __all__captum + __all__graphxai:
|
|
||||||
if name in __all__captum:
|
if name in __all__captum:
|
||||||
explaining_algorithm = CaptumWrapper(name)
|
explaining_algorithm = CaptumWrapper(name)
|
||||||
elif name in __all__graphxai:
|
elif name in __all__graphxai:
|
||||||
|
@ -102,6 +105,7 @@ for kind in ["node", "graph"]:
|
||||||
task_level=kind,
|
task_level=kind,
|
||||||
return_type="raw",
|
return_type="raw",
|
||||||
),
|
),
|
||||||
|
threshold_config=dict(threshold_type="hard", value=0.5),
|
||||||
)
|
)
|
||||||
explanation = explainer(
|
explanation = explainer(
|
||||||
x=batch.x,
|
x=batch.x,
|
||||||
|
@ -109,10 +113,26 @@ for kind in ["node", "graph"]:
|
||||||
index=int(target),
|
index=int(target),
|
||||||
target=batch.y,
|
target=batch.y,
|
||||||
)
|
)
|
||||||
explanation.__setattr__(
|
# explanation.__setattr__(
|
||||||
"model_prediction", explainer.get_prediction(x, edge_index)
|
# "model_prediction", explainer.get_prediction(x, edge_index)
|
||||||
|
# )
|
||||||
|
explanation_threshold = explanation._apply_mask(
|
||||||
|
node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
|
||||||
)
|
)
|
||||||
print(explanation.__dict__)
|
|
||||||
|
for f_name in [
|
||||||
|
"precision_score",
|
||||||
|
"precision_score",
|
||||||
|
"jaccard_score",
|
||||||
|
"roc_auc_score",
|
||||||
|
"f1_score",
|
||||||
|
"accuracy_score",
|
||||||
|
]:
|
||||||
|
acc = Accuracy(f_name)
|
||||||
|
gt = torch.ones_like(x) / 2
|
||||||
|
out = acc.forward(mask=explanation_threshold.node_mask, target=gt)
|
||||||
|
print(out)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print(str(e))
|
# print(str(e))
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -10,6 +10,8 @@ NUM_CLASS = 5
|
||||||
|
|
||||||
def softmax(data):
|
def softmax(data):
|
||||||
return Softmax(dim=1)(data)
|
return Softmax(dim=1)(data)
|
||||||
|
def kl(data1,data2):
|
||||||
|
return KLDivLoss(dim=1)(data1,data2)
|
||||||
|
|
||||||
|
|
||||||
class Fidelity(Metric):
|
class Fidelity(Metric):
|
||||||
|
@ -23,12 +25,12 @@ class Fidelity(Metric):
|
||||||
"infidelity_KL",
|
"infidelity_KL",
|
||||||
]
|
]
|
||||||
|
|
||||||
self.metric = self.load_metric(name)
|
|
||||||
self.exp_sub = None
|
self.exp_sub = None
|
||||||
self.exp_sub_c = None
|
self.exp_sub_c = None
|
||||||
self.s_exp_sub = None
|
self.s_exp_sub = None
|
||||||
self.s_exp_sub_c = None
|
self.s_exp_sub_c = None
|
||||||
self.s_initial_data = None
|
self.s_initial_data = None
|
||||||
|
self.metric = self.load_metric(name)
|
||||||
|
|
||||||
def _score_check(self):
|
def _score_check(self):
|
||||||
if any(
|
if any(
|
||||||
|
@ -90,6 +92,14 @@ class Fidelity(Metric):
|
||||||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
|
def _infidelity_KL(self, exp:Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
prob_initial = softmax(self.s_initial_data)
|
||||||
|
prob_exp = softmax(self.s_exp_sub)
|
||||||
|
return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def score(self, exp):
|
def score(self, exp):
|
||||||
self.exp_sub = exp.get_explanation_subgraph()
|
self.exp_sub = exp.get_explanation_subgraph()
|
||||||
self.exp_sub_c = exp.get_complement_subgraph()
|
self.exp_sub_c = exp.get_complement_subgraph()
|
||||||
|
@ -107,9 +117,12 @@ class Fidelity(Metric):
|
||||||
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||||
if name == "fidelity_minus_prob":
|
if name == "fidelity_minus_prob":
|
||||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||||
|
if name == "_infidelity_KL":
|
||||||
|
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{name} is not supported")
|
raise ValueError(f"{name} is not supported")
|
||||||
return self.metric
|
return self.metric
|
||||||
|
|
||||||
def __call__(self, exp: Explanation):
|
def __call__(self, exp: Explanation):
|
||||||
|
self.score(exp)
|
||||||
return self.metric(exp)
|
return self.metric(exp)
|
||||||
|
|
|
@ -1,4 +1,155 @@
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def compute_gradient(model,input,target, loss):
|
||||||
|
with torch.autograd.set_grad_enabled(True):
|
||||||
|
out = model(input)
|
||||||
|
err = loss(out,target)
|
||||||
|
return torch.autograd.grad(err,input)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FGSM(object):
|
||||||
|
def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
|
||||||
|
self.model = model
|
||||||
|
self.loss = loss
|
||||||
|
self.lower_bound = lower_bound
|
||||||
|
self.upper_bound = upper_bound
|
||||||
|
self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
|
||||||
|
self.zero_thresh = 10**-6
|
||||||
|
|
||||||
|
def forward(self, input, target, epsilon:float) -> Explanation:
|
||||||
|
grad = compute_gradient(model=self.model,input=input, target=target, loss=self.loss)
|
||||||
|
grad = self.bound(grad)
|
||||||
|
out = torch.where(torch.abs(grad) > self.zero_thresh,input - epsilon * torch.sign(grad),input)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PGD(object):
|
||||||
|
def __init__(self,model: torch.nn.Module,loss: torch.nn.Module,lower_bound: float = float("-inf"), upper_bound: float = float("inf")):
|
||||||
|
self.model = model
|
||||||
|
self.loss = loss
|
||||||
|
self.lower_bound = lower_bound
|
||||||
|
self.upper_bound = upper_bound
|
||||||
|
self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
|
||||||
|
self.zero_thresh = 10**-6
|
||||||
|
self.fgsm = FGSM(model=model,loss=loss,lower_bound=lower_bound,upper_bound=upper_bound)
|
||||||
|
|
||||||
|
def forward(self, input, target, epsilon:float, radius:float, step_num:int, random_start:bool = False, norm:str='inf') -> Explanation:
|
||||||
|
diff = outputs - inputs
|
||||||
|
if norm == "inf":
|
||||||
|
return inputs + torch.clamp(diff, -radius, radius)
|
||||||
|
elif norm == "2":
|
||||||
|
return inputs + torch.renorm(diff, 2, 0, radius)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Norm constraint must be 2 or inf.")
|
||||||
|
|
||||||
|
perturbed_inputs = input
|
||||||
|
if random_start:
|
||||||
|
perturbed_inputs= self.bound(self._random_point(input, radius, norm))
|
||||||
|
for _ in range(step_num):
|
||||||
|
perturbed_inputs = self.fgsm.perturb(
|
||||||
|
input=perturbed_inputs, epsilon=epsilon, target=target
|
||||||
|
)
|
||||||
|
perturbed_inputs = self.forward(input, perturbed_inputs)
|
||||||
|
perturbed_inputs = self.bound(perturbed_inputs).detach()
|
||||||
|
return perturbed_inputs
|
||||||
|
|
||||||
|
def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
|
||||||
|
r"""
|
||||||
|
A helper function that returns a uniform random point within the ball
|
||||||
|
with the given center and radius. Norm should be either L2 or Linf.
|
||||||
|
"""
|
||||||
|
if norm == "2":
|
||||||
|
u = torch.randn_like(center)
|
||||||
|
unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size())
|
||||||
|
d = torch.numel(center[0])
|
||||||
|
r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius
|
||||||
|
r = r[(...,) + (None,) * (r.dim() - 1)]
|
||||||
|
x = r * unit_u
|
||||||
|
return center + x
|
||||||
|
elif norm == "inf":
|
||||||
|
x = torch.rand_like(center) * radius * 2 - radius
|
||||||
|
return center + x
|
||||||
|
else:
|
||||||
|
raise AssertionError("Norm constraint must be L2 or Linf.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Attack(Metric):
|
class Attack(Metric):
|
||||||
|
def __init__(name: str, model: torch.nn.Module, dropout:float = 0.5):
|
||||||
|
super().__init__(name=name, model=model)
|
||||||
|
self.name = name
|
||||||
|
self.model = model
|
||||||
|
self.authorized_metric = [
|
||||||
|
"gaussian_noise",
|
||||||
|
"add_edge",
|
||||||
|
"remove_edge",
|
||||||
|
"remove_node"
|
||||||
|
"pgd",
|
||||||
|
"fgsm",
|
||||||
|
]
|
||||||
|
self.dropout = dropout
|
||||||
|
self._load_metric(name)
|
||||||
|
|
||||||
|
def _gaussian_noise(self,exp) -> Explanation:
|
||||||
|
x= torch.clone(exp.x)
|
||||||
|
x=x+torch.randn(*x.shape)
|
||||||
|
exp_ = copy.copy(exp)
|
||||||
|
exp_.x = x
|
||||||
|
return exp_
|
||||||
|
|
||||||
name in ['gaussian noise attack', 'edge perturbation attack', 'pgm', 'fgsd']:wq
|
def _add_edge(self,exp,p:float) -> Explanation:
|
||||||
|
exp_ = copy.copy(exp)
|
||||||
|
exp_.edge_index, _ = add_random_edge(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
|
||||||
|
return exp_
|
||||||
|
|
||||||
|
def _remove_edge(self,exp,p:float) -> Explanation:
|
||||||
|
exp_ = copy.copy(exp)
|
||||||
|
exp_.edge_index, _ = dropout_edge(exp_.edge_index,p=p)
|
||||||
|
return exp_
|
||||||
|
|
||||||
|
def _remove_node(self,exp,p:float) -> Explanation:
|
||||||
|
exp_ = copy.copy(exp)
|
||||||
|
exp_.edge_index, _ = dropout_node(exp_.edge_index,p=p,num_nodes=exp_.x.shape[0])
|
||||||
|
return exp_
|
||||||
|
|
||||||
|
def _load_add_edge(self):
|
||||||
|
return lambda exp : self._add_edge(exp,p=self.dropout)
|
||||||
|
|
||||||
|
def _load_remove_edge(self):
|
||||||
|
return lambda exp : self._remove_edge(exp,p=self.dropout)
|
||||||
|
|
||||||
|
def _load_remove_node(self):
|
||||||
|
return lambda exp : self._remove_node(exp,p=self.dropout)
|
||||||
|
|
||||||
|
def _load_gaussian_noise(self):
|
||||||
|
return lambda exp: self._gaussian_noise(exp)
|
||||||
|
|
||||||
|
def _load_metric(self):
|
||||||
|
if name in self.authorized_metric:
|
||||||
|
if name == "gaussian_noise":
|
||||||
|
self.metric= self._load_gaussian_noise()
|
||||||
|
if name == "add_edge":
|
||||||
|
self.metric=self._load_add_edge()
|
||||||
|
if name == "remove_edge":
|
||||||
|
self.metric= self._load_remove_edge()
|
||||||
|
if name == "remove_node":
|
||||||
|
self.metric= self._load_remove_node()
|
||||||
|
if name== "pgd":
|
||||||
|
print('set LOSS with cfg ')
|
||||||
|
pgd = PGD(model=self.model,loss=LOSS)
|
||||||
|
self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1,radius=1, step_num = 50, random_start=False, norm = 'inf')
|
||||||
|
if name== "fgsm":
|
||||||
|
print('set LOSS with cfg ')
|
||||||
|
pgd = FGSM(model=self.model,loss=LOSS)
|
||||||
|
self.metric = lambda exp:pgd.forward(input=exp,target=exp.y,epsilon=1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{name} is not supported yet')
|
||||||
|
|
||||||
|
return self.metric
|
||||||
|
|
||||||
|
def forward(self,exp) -> Explanation:
|
||||||
|
attack = self.metric(exp)
|
||||||
|
return attack
|
||||||
|
|
|
@ -5,4 +5,5 @@ class Sparsity(Metric):
|
||||||
def __init__(self,name):
|
def __init__(self,name):
|
||||||
super().__init__(name=name,model=None)
|
super().__init__(name=name,model=None)
|
||||||
|
|
||||||
def
|
def forward(self, mask):
|
||||||
|
return torch.mean(mask.float()).item()
|
||||||
|
|
Loading…
Reference in New Issue