239 lines
7.5 KiB
Python
239 lines
7.5 KiB
Python
import copy
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import CrossEntropyLoss, MSELoss
|
|
from torch_geometric.explain.explanation import Explanation
|
|
from torch_geometric.graphgym.config import cfg
|
|
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
|
|
|
from explaining_framework.metric.base import Metric
|
|
|
|
|
|
def compute_gradient(model, inp, target, loss):
    """Return the gradient of ``loss(model(inp), target)`` w.r.t. ``inp.x``.

    The model is evaluated as ``model(x=..., edge_index=inp.edge_index)`` with
    gradient tracking forced on, so this also works inside ``torch.no_grad()``.
    The gradient is taken with respect to a detached copy of ``inp.x``;
    ``inp`` and its tensors are left unmodified.

    Args:
        model: Callable accepting ``x`` and ``edge_index`` keyword arguments.
        inp: Object exposing ``x`` (features) and ``edge_index`` attributes.
        target: Passed as the second argument to ``loss``.
        loss: Callable ``loss(output, target)`` returning a scalar tensor.

    Returns:
        Tensor with the same shape as ``inp.x``.
    """
    with torch.autograd.set_grad_enabled(True):
        # Differentiate a fresh leaf: flipping requires_grad on ``inp.x``
        # itself raises for non-leaf tensors (e.g. cloned or previously
        # perturbed features) and would leak the flag to the caller.
        x = inp.x.detach().requires_grad_(True)
        out = model(x=x, edge_index=inp.edge_index)
        err = loss(out, target)
        return torch.autograd.grad(err, x)[0]
|
|
|
|
|
|
class FGSM(Metric):
    """Fast Gradient Sign Method (FGSM) perturbation of node features.

    One step moves every entry of ``input.x`` whose loss gradient has
    magnitude above ``zero_thresh`` by ``epsilon`` along the gradient sign.
    By default it moves up the loss, i.e. away from ``target``; with
    ``targeted=True`` it moves down the loss. The resulting features are then
    clamped to ``[lower_bound, upper_bound]``.

    Args:
        model: Network handed to ``compute_gradient``, which calls it as
            ``model(x=..., edge_index=...)``.
        loss: Callable ``loss(output, target)`` differentiated w.r.t. ``x``.
        lower_bound: Minimum value allowed in the perturbed features.
        upper_bound: Maximum value allowed in the perturbed features.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ):
        super().__init__(name="fgsm", model=model)
        self.model = model
        self.loss = loss
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # Plain Python floats (inf allowed) keep clamp device/dtype agnostic;
        # CPU ``torch.Tensor`` bounds do not match CUDA inputs.
        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
        # Gradient entries whose magnitude is not above this are left as is.
        self.zero_thresh = 10**-6

    def forward(
        self, input, target, epsilon: float, targeted: bool = False
    ) -> Explanation:
        """Perturb a clone of ``input`` with a single FGSM step.

        Args:
            input: Object with ``clone()``, ``x`` and ``edge_index``
                (an ``Explanation`` in this module).
            target: Second argument passed to ``self.loss``.
            epsilon: Step size applied to each perturbed feature entry.
            targeted: ``False`` (default) performs standard untargeted FGSM
                (step up the loss gradient); ``True`` steps down the gradient
                to push the model towards ``target``.

        Returns:
            The clone, whose ``x`` is the perturbed, bounded feature tensor
            with no autograd history.
        """
        input_ = input.clone()
        # Start from a detached leaf so the gradient belongs to this step only
        # and no non-leaf tensor ever needs its requires_grad flag changed.
        input_.x = input_.x.detach()
        grad = compute_gradient(
            model=self.model, inp=input_, target=target, loss=self.loss
        )
        x = input_.x.detach()
        direction = -1.0 if targeted else 1.0
        perturbed = torch.where(
            torch.abs(grad) > self.zero_thresh,
            x + direction * epsilon * torch.sign(grad),
            x,
        )
        # Clamp the features themselves (not the gradient) to the valid range.
        input_.x = self.bound(perturbed)
        return input_

    def load_metric(self):
        # Intentionally empty: all state is prepared in __init__.
        pass
|
|
|
|
|
|
class PGD(Metric):
    """Projected Gradient Descent (PGD) perturbation of node features.

    Runs ``step_num`` FGSM steps of size ``epsilon``. After every step:

    * the total perturbation relative to the ORIGINAL ``input.x`` is projected
      back into a ball of size ``radius``;
    * the features are clamped to ``[lower_bound, upper_bound]``.

    Args:
        model: Network attacked by the inner FGSM steps.
        loss: Callable ``loss(output, target)`` used by the inner FGSM.
        lower_bound: Minimum value allowed in the perturbed features.
        upper_bound: Maximum value allowed in the perturbed features.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss: torch.nn.Module,
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ):
        super().__init__(name="pgd", model=model)
        self.model = model
        self.loss = loss
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        # Plain Python floats keep clamp device/dtype agnostic (CUDA-safe).
        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
        # Kept for parity with FGSM; not read by this class.
        self.zero_thresh = 10**-6
        self.fgsm = FGSM(
            model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound
        )

    def forward(
        self,
        input,
        target,
        epsilon: float,
        radius: float,
        step_num: int,
        random_start: bool = False,
        norm: str = "inf",
    ) -> Explanation:
        """Run PGD starting from (or around) ``input``.

        Args:
            input: Object with ``clone()``, ``x`` and ``edge_index``; it is the
                fixed centre of the projection ball and is not overwritten.
            target: Passed to the loss at every step.
            epsilon: FGSM step size.
            radius: Projection radius.
                - ``norm="inf"``: each entry of the perturbation is clamped to
                  ``[-radius, radius]``.
                - ``norm="2"``: each row (dim 0) of the perturbation is
                  rescaled to an L2 norm of at most ``radius``.
            step_num: Number of FGSM steps.
            random_start: Start from a uniformly random point in the ball.
            norm: ``"inf"`` or ``"2"``.

        Returns:
            The perturbed explanation. This is ``input`` itself when no step is
            taken and ``random_start`` is ``False``.

        Raises:
            AssertionError: If ``norm`` is unsupported and actually used.
        """

        def _clip(inputs: Explanation, outputs: Explanation) -> Explanation:
            # Project ``outputs`` onto the ball around the unchanged ``inputs``.
            # Only ``outputs`` (a fresh FGSM clone) is written to.
            diff = outputs.x - inputs.x
            if norm == "inf":
                delta = torch.clamp(diff, -radius, radius)
            elif norm == "2":
                delta = torch.renorm(diff, 2, 0, radius)
            else:
                raise AssertionError("Norm constraint must be L2 or Linf.")
            outputs.x = inputs.x + delta
            return outputs

        perturbed_input = input
        if random_start:
            # Keep an Explanation-like object; only its features are randomised.
            perturbed_input = input.clone()
            perturbed_input.x = self.bound(self._random_point(input.x, radius, norm))
        for _ in range(step_num):
            perturbed_input = self.fgsm.forward(
                input=perturbed_input, epsilon=epsilon, target=target
            )
            perturbed_input = _clip(input, perturbed_input)
            # Detach so successive steps do not chain autograd graphs.
            perturbed_input.x = self.bound(perturbed_input.x).detach()
        return perturbed_input

    def load_metric(self):
        # Intentionally empty: all state is prepared in __init__.
        pass

    def _random_point(
        self, center: torch.Tensor, radius: float, norm: str
    ) -> torch.Tensor:
        r"""
        A helper function that returns a uniform random point within the ball
        with the given center and radius. Norm should be either L2 or Linf.

        For ``norm="2"`` the ball is taken per row (dim 0) of ``center``,
        matching the per-row projection in ``forward``.
        """
        if norm == "2":
            u = torch.randn_like(center)
            unit_u = F.normalize(u.reshape(u.size(0), -1)).reshape(u.size())
            d = torch.numel(center[0])
            r = (
                torch.rand(u.size(0), device=u.device, dtype=u.dtype) ** (1.0 / d)
            ) * radius
            # One radius per row, broadcast over the remaining dimensions.
            r = r.view(-1, *([1] * (u.dim() - 1)))
            return center + r * unit_u
        elif norm == "inf":
            x = torch.rand_like(center) * radius * 2 - radius
            return center + x
        else:
            raise AssertionError("Norm constraint must be L2 or Linf.")
|
|
|
|
|
|
class Attack(Metric):
    """Perturb an explanation's graph with a named attack.

    Supported names (``authorized_metric``):

    * ``"gaussian_noise"``: add standard-normal noise to ``x``.
    * ``"add_edge"`` / ``"remove_edge"``: add / drop edges with probability
      ``dropout``.
    * ``"remove_node"``: drop nodes with probability ``dropout``. Only
      ``edge_index`` is updated.
    * ``"fgsm"`` / ``"pgd"``: gradient attacks on ``x`` against ``exp.y``.

    Args:
        name: Attack name, one of ``authorized_metric``.
        model: Model attacked by the gradient-based methods.
        dropout: Probability ``p`` used by the edge/node add/drop attacks.
        loss: Loss for the gradient attacks. When ``None`` it is selected from
            ``cfg.model.loss_fun`` (``"cross_entropy"`` or ``"mse"``).

    Raises:
        ValueError: If ``name`` or the configured loss is not supported.
    """

    def __init__(
        self,
        name: str,
        model: torch.nn.Module,
        dropout: float = 0.5,
        loss: torch.nn.Module = None,
    ):
        super().__init__(name=name, model=model)
        self.name = name
        self.model = model
        self.authorized_metric = [
            "gaussian_noise",
            "add_edge",
            "remove_edge",
            "remove_node",
            "pgd",
            "fgsm",
        ]
        self.dropout = dropout
        if loss is None:
            if cfg.model.loss_fun == "cross_entropy":
                self.loss = CrossEntropyLoss()
            elif cfg.model.loss_fun == "mse":
                self.loss = MSELoss()
            else:
                # Report the offending config value (``loss`` is None here).
                raise ValueError(f"{cfg.model.loss_fun} is not supported yet")
        else:
            self.loss = loss
        self.load_metric(name)

    def _gaussian_noise(self, exp) -> Explanation:
        """Return a clone of ``exp`` with standard-normal noise added to ``x``."""
        exp_ = exp.clone()
        # Sample on the features' device so CUDA inputs work.
        exp_.x = exp.x + torch.randn(exp.x.shape, device=exp.x.device)
        return exp_

    def _add_edge(self, exp, p: float) -> Explanation:
        """Return a clone of ``exp`` with random edges added (ratio ``p``)."""
        exp_ = exp.clone()
        exp_.edge_index, _ = add_random_edge(
            exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
        )
        return exp_

    def _remove_edge(self, exp, p: float) -> Explanation:
        """Return a clone of ``exp`` with each edge dropped with probability ``p``."""
        exp_ = exp.clone()
        exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
        return exp_

    def _remove_node(self, exp, p: float) -> Explanation:
        """Return a clone of ``exp`` whose ``edge_index`` omits dropped nodes.

        Nodes are dropped with probability ``p``; ``x`` is left unchanged.
        """
        exp_ = exp.clone()
        exp_.edge_index, _, _ = dropout_node(
            exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
        )
        return exp_

    def _load_add_edge(self):
        return lambda exp: self._add_edge(exp, p=self.dropout)

    def _load_remove_edge(self):
        return lambda exp: self._remove_edge(exp, p=self.dropout)

    def _load_remove_node(self):
        return lambda exp: self._remove_node(exp, p=self.dropout)

    def _load_gaussian_noise(self):
        return lambda exp: self._gaussian_noise(exp)

    def _load_pgd(self):
        """Build a 50-step L-inf PGD (step 1, radius 1) targeting ``exp.y``."""
        pgd = PGD(model=self.model, loss=self.loss)
        return lambda exp: pgd.forward(
            input=exp,
            target=exp.y,
            epsilon=1,
            radius=1,
            step_num=50,
            random_start=False,
            norm="inf",
        )

    def _load_fgsm(self):
        """Build a single FGSM step (epsilon 1) targeting ``exp.y``."""
        fgsm = FGSM(model=self.model, loss=self.loss)
        return lambda exp: fgsm.forward(input=exp, target=exp.y, epsilon=1)

    def load_metric(self, name):
        """Select the perturbation for ``name``, store it in ``self.metric``
        and return it.

        Raises:
            ValueError: If ``name`` is not a supported attack.
        """
        # One explicit dispatch table avoids the fragile if-chain whose
        # trailing ``else`` only paired with the last ``if``.
        loaders = {
            "gaussian_noise": self._load_gaussian_noise,
            "add_edge": self._load_add_edge,
            "remove_edge": self._load_remove_edge,
            "remove_node": self._load_remove_node,
            "pgd": self._load_pgd,
            "fgsm": self._load_fgsm,
        }
        if name not in self.authorized_metric or name not in loaders:
            raise ValueError(f"{name} is not supported yet")
        self.metric = loaders[name]()
        return self.metric

    def forward(self, exp) -> Explanation:
        """Apply the selected perturbation to ``exp`` and return the result."""
        attack = self.metric(exp)
        return attack
|