Aborting
commit 45d39edbd9

4 changed files with 204 additions and 19 deletions
@@ -2,6 +2,10 @@ import traceback
 import torch
 import torch.nn as nn
+from explaining_framework.explaining_framework.metric.accuracy import Accuracy
+from explaining_framework.explaining_framework.metric.fidelity import Fidelity
+from explaining_framework.explaining_framework.metric.robust import Attack
+from explaining_framework.explaining_framework.metric.sparsity import Sparsity
 from torch_geometric.data import Batch, Data
 from torch_geometric.explain import Explainer
 from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
@@ -27,16 +31,16 @@ __all__captum = [

 __all__graphxai = [
     "CAM",
-    "GradCAM",
-    "GNN_LRP",
-    "GradExplainer",
-    "GuidedBackPropagation",
-    "IntegratedGradients",
-    "PGExplainer",
-    "PGMExplainer",
-    "RandomExplainer",
-    "SubgraphX",
-    "GraphMASK",
+    #  "GradCAM",
+    #  "GNN_LRP",
+    #  "GradExplainer",
+    #  "GuidedBackPropagation",
+    #  "IntegratedGradients",
+    #  "PGExplainer",
+    #  "PGMExplainer",
+    #  "RandomExplainer",
+    #  "SubgraphX",
+    #  "GraphMASK",
 ]

@@ -77,9 +81,8 @@ for epoch in range(1, 2):
     optimizer.step()
 target = torch.LongTensor([[0]])

-for kind in ["node", "graph"]:
-    print(kind)
-    for name in __all__captum + __all__graphxai:
+for kind in ["graph"]:
+    for name in __all__graphxai:
         if name in __all__captum:
             explaining_algorithm = CaptumWrapper(name)
         elif name in __all__graphxai:
@@ -102,6 +105,7 @@ for kind in ["node", "graph"]:
                     task_level=kind,
                     return_type="raw",
                 ),
+                threshold_config=dict(threshold_type="hard", value=0.5),
             )
             explanation = explainer(
                 x=batch.x,
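Note: the added threshold_config makes PyG's Explainer binarize the returned masks; with threshold_type="hard", entries above value are kept and the rest zeroed. Roughly, as a minimal sketch independent of the commit:

    import torch

    soft_mask = torch.tensor([0.1, 0.7, 0.5, 0.9])
    # hard threshold at 0.5: keep entries above the value, zero the rest
    hard_mask = (soft_mask > 0.5).to(soft_mask.dtype)  # tensor([0., 1., 0., 1.])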
@@ -109,10 +113,26 @@ for kind in ["node", "graph"]:
                 index=int(target),
                 target=batch.y,
             )
-            explanation.__setattr__(
-                "model_prediction", explainer.get_prediction(x, edge_index)
-            )
+            # explanation.__setattr__(
+            # "model_prediction", explainer.get_prediction(x, edge_index)
+            # )
+            explanation_threshold = explanation._apply_mask(
+                node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
+            )
+            print(explanation.__dict__)
+
+            for f_name in [
+                "precision_score",
+                "jaccard_score",
+                "roc_auc_score",
+                "f1_score",
+                "accuracy_score",
+            ]:
+                acc = Accuracy(f_name)
+                gt = torch.ones_like(x) / 2
+                out = acc.forward(mask=explanation_threshold.node_mask, target=gt)
+                print(out)

         except Exception as e:
             # print(str(e))
             pass
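The loop above feeds a thresholded node mask and a dummy ground truth into Accuracy. Assuming Accuracy(f_name) dispatches to the sklearn scorer of the same name (the commit does not show its implementation), the comparison reduces to something like this sketch, with illustrative masks:

    import torch
    from sklearn.metrics import f1_score

    mask = torch.tensor([0.0, 1.0, 1.0, 0.0])  # thresholded explanation mask
    gt = torch.tensor([0.0, 1.0, 0.0, 0.0])    # ground-truth mask
    print(f1_score(gt.int().numpy(), mask.int().numpy()))  # 0.666...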
@@ -10,6 +10,8 @@ NUM_CLASS = 5

 def softmax(data):
     return Softmax(dim=1)(data)
+def kl(data1, data2):
+    # KLDivLoss has no `dim` argument and expects log-probabilities as input
+    return KLDivLoss(reduction="none")(torch.log(data1), data2).sum(dim=1)


 class Fidelity(Metric):
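A note on the kl fix above: torch.nn.KLDivLoss computes target * (log target - input) and expects its input already in log space; it never took a dim argument. A quick self-contained check of the convention:

    import torch
    from torch.nn import KLDivLoss

    p = torch.tensor([[0.25, 0.75]])
    q = torch.tensor([[0.50, 0.50]])
    kl_val = KLDivLoss(reduction="none")(torch.log(p), q).sum(dim=1)  # KL(q || p)
    manual = (q * (q / p).log()).sum(dim=1)
    assert torch.allclose(kl_val, manual)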
@@ -23,12 +25,12 @@ class Fidelity(Metric):
+            "infidelity_KL",
         ]

-        self.metric = self.load_metric(name)
         self.exp_sub = None
         self.exp_sub_c = None
         self.s_exp_sub = None
         self.s_exp_sub_c = None
         self.s_initial_data = None
+        self.metric = self.load_metric(name)

     def _score_check(self):
         if any(
@@ -90,6 +92,14 @@ class Fidelity(Metric):
             torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
         ).item()

+    def _infidelity_KL(self, exp: Explanation) -> float:
+        self._score_check()
+        prob_initial = softmax(self.s_initial_data)
+        prob_exp = softmax(self.s_exp_sub)
+        return torch.mean(1 - torch.exp(-kl(prob_exp, prob_initial))).item()
+
     def score(self, exp):
         self.exp_sub = exp.get_explanation_subgraph()
         self.exp_sub_c = exp.get_complement_subgraph()
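The added _infidelity_KL maps a non-negative KL divergence into [0, 1): identical distributions score 0 and increasingly divergent ones approach 1. Worked values of the 1 - exp(-KL) transform:

    import torch

    kl_values = torch.tensor([0.0, 0.1, 1.0, 5.0])
    print(1 - torch.exp(-kl_values))  # tensor([0.0000, 0.0952, 0.6321, 0.9933])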
@@ -107,9 +117,12 @@ class Fidelity(Metric):
                 self.metric = lambda exp: self._fidelity_plus_prob(exp)
             if name == "fidelity_minus_prob":
                 self.metric = lambda exp: self._fidelity_minus_prob(exp)
+            # matches the authorized_metric entry ("infidelity_KL", no underscore)
+            if name == "infidelity_KL":
+                self.metric = lambda exp: self._infidelity_KL(exp)
         else:
             raise ValueError(f"{name} is not supported")
         return self.metric

     def __call__(self, exp: Explanation):
         self.score(exp)
         return self.metric(exp)
@@ -1,4 +1,155 @@
 import copy
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch_geometric.explain import Explanation
+from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
+
+# NOTE: the imports above are added editorially; the original hunk relies on
+# them implicitly. The Metric base class comes from this package's metric
+# module, as in fidelity.py and sparsity.py.
+
+
+def compute_gradient(model, input, target, loss):
+    # `input` must have requires_grad=True; torch.autograd.grad returns a
+    # tuple, so unpack the gradient with respect to the single input
+    with torch.autograd.set_grad_enabled(True):
+        out = model(input)
+        err = loss(out, target)
+        return torch.autograd.grad(err, input)[0]
+
+
+class FGSM(object):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+    ):
+        self.model = model
+        self.loss = loss
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
+        self.zero_thresh = 10**-6
+
+    def forward(self, input, target, epsilon: float) -> Tensor:
+        grad = compute_gradient(model=self.model, input=input, target=target, loss=self.loss)
+        out = torch.where(
+            torch.abs(grad) > self.zero_thresh,
+            input - epsilon * torch.sign(grad),
+            input,
+        )
+        # clamp the perturbed input (not the gradient) into the valid range
+        return self.bound(out)
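A minimal usage sketch for the FGSM class above; the toy model, shapes, and loss are illustrative, not from the commit. The input must require gradients, since compute_gradient differentiates the loss with respect to it:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    x = torch.randn(8, 4, requires_grad=True)
    y = torch.randint(0, 2, (8,))

    fgsm = FGSM(model=model, loss=nn.CrossEntropyLoss())
    x_adv = fgsm.forward(input=x, target=y, epsilon=0.1)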
+class PGD(object):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+    ):
+        self.model = model
+        self.loss = loss
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bound = lambda x: torch.clamp(x, min=lower_bound, max=upper_bound)
+        self.zero_thresh = 10**-6
+        self.fgsm = FGSM(model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound)
+
+    def forward(
+        self,
+        input,
+        target,
+        epsilon: float,
+        radius: float,
+        step_num: int,
+        random_start: bool = False,
+        norm: str = "inf",
+    ) -> Tensor:
+        # the projection originally sat unreachably above the loop; it belongs
+        # in a helper that clips each iterate back into the norm ball
+        def _clip(inputs, outputs):
+            diff = outputs - inputs
+            if norm == "inf":
+                return inputs + torch.clamp(diff, -radius, radius)
+            elif norm == "2":
+                return inputs + torch.renorm(diff, 2, 0, radius)
+            else:
+                raise AssertionError("Norm constraint must be 2 or inf.")
+
+        perturbed_inputs = input
+        if random_start:
+            perturbed_inputs = self.bound(self._random_point(input, radius, norm))
+        for _ in range(step_num):
+            perturbed_inputs = self.fgsm.forward(
+                input=perturbed_inputs, epsilon=epsilon, target=target
+            )
+            perturbed_inputs = _clip(input, perturbed_inputs)
+            perturbed_inputs = self.bound(perturbed_inputs).detach()
+        return perturbed_inputs
+
+    def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
+        r"""
+        A helper function that returns a uniform random point within the ball
+        with the given center and radius. Norm should be either L2 or Linf.
+        """
+        if norm == "2":
+            u = torch.randn_like(center)
+            unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size())
+            d = torch.numel(center[0])
+            r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius
+            # broadcast r over the non-batch dimensions of center
+            r = r[(...,) + (None,) * (center.dim() - 1)]
+            x = r * unit_u
+            return center + x
+        elif norm == "inf":
+            x = torch.rand_like(center) * radius * 2 - radius
+            return center + x
+        else:
+            raise AssertionError("Norm constraint must be L2 or Linf.")
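In _random_point, scaling the radius by U**(1/d) is what makes the L2 samples uniform over the ball's volume rather than clustered near the center. A quick empirical check, assuming a PGD instance built with an illustrative toy model:

    import torch
    import torch.nn as nn

    pgd = PGD(model=nn.Linear(3, 2), loss=nn.CrossEntropyLoss())
    pts = pgd._random_point(torch.zeros(1000, 3), radius=2.0, norm="2")
    assert (pts.norm(dim=1) <= 2.0 + 1e-6).all()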
+class Attack(Metric):
+    def __init__(self, name: str, model: torch.nn.Module, dropout: float = 0.5):
+        super().__init__(name=name, model=model)
+        self.name = name
+        self.model = model
+        self.authorized_metric = [
+            "gaussian_noise",
+            "add_edge",
+            "remove_edge",
+            "remove_node",
+            "pgd",
+            "fgsm",
+        ]
+        self.dropout = dropout
+        self._load_metric(name)
+
+    def _gaussian_noise(self, exp) -> Explanation:
+        x = torch.clone(exp.x)
+        x = x + torch.randn(*x.shape)
+        exp_ = copy.copy(exp)
+        exp_.x = x
+        return exp_
+
+    def _add_edge(self, exp, p: float) -> Explanation:
+        exp_ = copy.copy(exp)
+        exp_.edge_index, _ = add_random_edge(exp_.edge_index, p=p, num_nodes=exp_.x.shape[0])
+        return exp_
+
+    def _remove_edge(self, exp, p: float) -> Explanation:
+        exp_ = copy.copy(exp)
+        exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
+        return exp_
+
+    def _remove_node(self, exp, p: float) -> Explanation:
+        exp_ = copy.copy(exp)
+        # dropout_node returns (edge_index, edge_mask, node_mask)
+        exp_.edge_index, _, _ = dropout_node(exp_.edge_index, p=p, num_nodes=exp_.x.shape[0])
+        return exp_
+
+    def _load_add_edge(self):
+        return lambda exp: self._add_edge(exp, p=self.dropout)
+
+    def _load_remove_edge(self):
+        return lambda exp: self._remove_edge(exp, p=self.dropout)
+
+    def _load_remove_node(self):
+        return lambda exp: self._remove_node(exp, p=self.dropout)
+
+    def _load_gaussian_noise(self):
+        return lambda exp: self._gaussian_noise(exp)
+
+    def _load_metric(self, name):
+        if name in self.authorized_metric:
+            if name == "gaussian_noise":
+                self.metric = self._load_gaussian_noise()
+            if name == "add_edge":
+                self.metric = self._load_add_edge()
+            if name == "remove_edge":
+                self.metric = self._load_remove_edge()
+            if name == "remove_node":
+                self.metric = self._load_remove_node()
+            if name == "pgd":
+                print("set LOSS with cfg ")
+                # LOSS is a placeholder to be wired in from the config
+                pgd = PGD(model=self.model, loss=LOSS)
+                self.metric = lambda exp: pgd.forward(
+                    input=exp, target=exp.y, epsilon=1, radius=1,
+                    step_num=50, random_start=False, norm="inf",
+                )
+            if name == "fgsm":
+                print("set LOSS with cfg ")
+                # LOSS is a placeholder to be wired in from the config
+                fgsm = FGSM(model=self.model, loss=LOSS)
+                self.metric = lambda exp: fgsm.forward(input=exp, target=exp.y, epsilon=1)
+        else:
+            raise ValueError(f"{name} is not supported yet")
+
+        return self.metric
+
+    def forward(self, exp) -> Explanation:
+        attack = self.metric(exp)
+        return attack
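A hypothetical usage sketch for the Attack wrapper; the Explanation fields and model are illustrative, and it assumes the package's Metric base accepts name and model, as the other metrics here do:

    import torch
    import torch.nn as nn
    from torch_geometric.explain import Explanation

    exp = Explanation(
        x=torch.randn(4, 3),
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
    )
    attack = Attack(name="remove_edge", model=nn.Identity(), dropout=0.3)
    print(attack.forward(exp).edge_index)  # roughly 30% of edges dropped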
@@ -5,4 +5,5 @@ class Sparsity(Metric):
     def __init__(self, name):
         super().__init__(name=name, model=None)

-    def 
+    def forward(self, mask):
+        return torch.mean(mask.float()).item()
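The completed forward reports sparsity as the mean of the mask, i.e. the fraction of elements the explanation keeps (lower means sparser). For instance:

    import torch

    mask = torch.tensor([1.0, 0.0, 0.0, 1.0])
    print(torch.mean(mask.float()).item())  # 0.5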