New fixes and new features

This commit is contained in:
araison 2022-12-14 18:28:12 +01:00
parent ea0a5dd86e
commit 45730328a7
16 changed files with 155 additions and 134 deletions

View File

@ -188,21 +188,21 @@ class CaptumWrapper(ExplainerAlgorithm):
def _parse_attr(self, attr):
if self.mask_type == "node":
node_mask = attr[0].squeeze(0)
node_feat_mask = attr[0].squeeze(0)
edge_mask = None
if self.mask_type == "edge":
node_mask = None
node_feat_mask = None
edge_mask = attr[0]
if self.mask_type == "node_and_edge":
node_mask = attr[0].squeeze(0)
node_feat_mask = attr[0].squeeze(0)
edge_mask = attr[1]
else:
raise ValueError
edge_feat_mask = None
node_feat_mask = None
node_mask = None
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
@ -238,6 +238,7 @@ class CaptumWrapper(ExplainerAlgorithm):
return Explanation(
x=x,
edge_index=edge_index,
y=target,
edge_mask=edge_mask,
node_mask=node_mask,
node_feat_mask=node_feat_mask,

View File

@ -14,7 +14,7 @@ from graphxai.explainers.pgm_explainer import PGMExplainer
from graphxai.explainers.random import RandomExplainer
from graphxai.explainers.subgraphx import SubgraphX
from torch import Tensor
from torch.nn import CrossEntropyLoss, KLDivLoss, MSELoss
from torch.nn import CrossEntropyLoss, MSELoss
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
@ -260,6 +260,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
return Explanation(
x=x,
edge_index=edge_index,
y=target,
edge_mask=edge_mask,
node_mask=node_mask,
node_feat_mask=node_feat_mask,

View File

@ -0,0 +1,29 @@
import sklearn.metrics
import torch
from base import Metric
class Accuracy(Metric):
def __init__(name: str):
super().__init__(name=name, model=None)
self.authorized_metric = [
"precision_score",
"precision_score",
"jaccard_score",
"roc_auc_score",
"f1_score",
"accuracy_score",
]
self.metric = self.load_metric(name)
def load_metric(name):
if name in self.authorized_metric:
self.metric = eval("sklearn.metric.{name}")
else:
raise ValueError(f"{name} is not supported")
def forward(self, mask, target: Tensor) -> float:
if mask.type() == torch.bool and target.type() == torch.bool:
return self.metric(y_pred=mask, y_true=target)

View File

@ -7,6 +7,7 @@ class Metric(ABC):
self.model = model
if is_model_needed and model is None:
raise ValueError(f"{self.name} needs model to perform measurements")
self.authorized_metric = None
def is_model_needed(self):
if "fidelity" in self.name:
@ -14,6 +15,31 @@ class Metric(ABC):
else:
return False
@abstractmethod
def load_metric(name: str):
pass
@abstractmethod
def __call__(self, exp: Explanation, **kwargs) -> float:
pass
def get_prediction(self, *args, **kwargs) -> torch.Tensor:
r"""Returns the prediction of the model on the input graph.
If the model mode is :obj:`"regression"`, the prediction is returned as
a scalar value.
If the model mode :obj:`"classification"`, the prediction is returned
as the predicted class label.
Args:
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the
model.
"""
training = self.model.training
self.model.eval()
with torch.no_grad():
out = self.model(*args, **kwargs)
self.model.train(training)
return out

View File

@ -0,0 +1,66 @@
import torch
from torch.nn import KLDivLoss, Softmax
from base import Metric
class Fidelity(Metric):
def __init__(name: str, model: torch.nn.Module, mask_type: str):
super().__init__(name=name, model=model)
self.authorized_metric = [
"fidelity_plus",
"fidelity_minus",
"fidelity_plus_prob",
"fidelity_minus_prob",
"fidelity_plus_model",
"fidelity_minus_model",
"fidelity_plus_prob_model",
"fidelity_minus_prob_model",
"infidelity_KL",
]
self.metric = self.load_metric(name)
self.exp_sub = None
self.exp_sub_c = None
self.s_exp_sub = None
self.s_exp_sub_c = None
self.s_initial_data = None
def _fidelity_plus(self, exp) -> float:
if any(
[
attr is None
for attr in [
self.exp_sub,
self.exp_sub_c,
self.s_exp_sub,
self.s_exp_sub_c,
self.s_initial_data,
]
]
):
self.score(exp)
else:
fid = self.s_initial_data - self.s_exp_sub_c
def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph()
self.exp_sub_c = exp.get_complement_subgraph()
self.s_exp_sub = self.get_prediction(self.exp_sub)
self.s_exp_sub_c = self.get_prediction(self.exp_sub_c)
self.s_initial_data = self.get_prediction(exp)
def load_metric(name):
if name in self.authorized_metric:
if name == "fidelity_plus":
self.metric = eval("sklearn.metric.{name}")
else:
raise ValueError(f"{name} is not supported")
def compute(self, mask, target: Tensor) -> float:
if mask.type() == torch.bool and target.type() == torch.bool:
return self.metric(y_pred=mask, y_true=target)
def __call__(self, exp: Explanation):
pass

View File

@ -0,0 +1,4 @@
class Attack(Metric):
name in ['gaussian noise attack', 'edge perturbation attack', 'pgm', 'fgsd']:wq

View File

@ -1,3 +0,0 @@
class BaseExplaining(object):
def __init__(self,model,explainer_name:wq

View File

@ -1 +0,0 @@
from torch_geometric.nn.models.captum import CaptumModel

View File

@ -0,0 +1,16 @@
import copy
from torch import FloatTensor
from torch.nn import ReLU
def relu_mask(explanation: Explanation) -> Explanation:
relu = ReLU()
explanation_store = explanation._store
raw_data = copy.copy(explanation._store)
for k, v in explanation_store.items():
if "mask" in k:
explanation_store[k] = relu(v)
explanation.__setattr__("raw_explanation", raw_data)
explanation.__setattr__("raw_explanation_transform", "relu")
return explanation

View File

@ -0,0 +1,7 @@
import copy
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
from torch_geometric.explain.explanation import Explanation

View File

View File

@ -1,125 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import multiprocessing as mp
import os
import threading
import time
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
with open("mapping_nx.txt", "r") as file:
BLACK_LIST = [line.rstrip() for line in file]
class GraphStat(object):
def __init__(self):
self.maps = {
"networkx": self.available_map_networkx(),
"torch_geometric": self.available_map_torch_geometric(),
}
def available_map_networkx(self):
functions_list = getmembers(nx.algorithms, isfunction)
MANUALLY_ADDED = [
"algebraic_connectivity",
"adjacency_spectrum",
"degree",
"density",
"laplacian_spectrum",
"normalized_laplacian_spectrum",
"number_of_selfloops",
"number_of_edges",
"number_of_nodes",
]
MANUALLY_ADDED_LIST = [
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
]
functions_list = functions_list + MANUALLY_ADDED_LIST
maps = {}
for func in functions_list:
name, f = func
if (
name in BLACK_LIST
or name == "recursive_simple_cycles"
or "triad" in name
or "weisfeiler" in name
or "dfs" in name
or "trophic" in name
or "recursive" in name
or "scipy" in name
or "numpy" in name
or "sigma" in name
or "omega" in name
or "all_" in name
):
continue
else:
maps[name] = f
return maps
def available_map_torch_geometric(self):
names = [
"num_nodes",
"num_edges",
"has_self_loops",
"has_isolated_nodes",
"num_nodes_features",
"y",
]
maps = {
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
for name in names
}
return maps
def __call__(self, data):
data_ = data.__copy__()
stats = {}
for k, v in self.maps.items():
if k == "networkx":
_data_ = to_networkx(data)
_data_ = _data_.to_undirected()
elif k == "torch_geometric":
_data_ = data.__copy__()
for name, f in v.items():
if f is None:
stats[name] = None
continue
else:
try:
t0 = time.time()
val = f(_data_)
t1 = time.time()
delta = t1 - t0
except Exception as e:
print(name, e)
with open(f"{name}.txt", "w") as f:
f.write(str(e))
# print(name, round(delta, 4))
# if callable(val) and k == "torch_geometric":
# val = val()
# if isinstance(val, types.GeneratorType):
# val = list(val)
# stats[name] = val
return stats
from torch_geometric.datasets import KarateClub, Planetoid
d = Planetoid(root="/tmp/", name="Cora")
# d = KarateClub()
a = d[0]
st = GraphStat()
stat = st(a)
for k, v in stat.items():
print("---------")
print("Name:", k)
print("Type:", type(v))
print("Val:", v)
print("---------")

View File