New fixes and new features
This commit is contained in:
parent
ea0a5dd86e
commit
45730328a7
16 changed files with 155 additions and 134 deletions

@@ -188,21 +188,21 @@ class CaptumWrapper(ExplainerAlgorithm):

     def _parse_attr(self, attr):
         if self.mask_type == "node":
-            node_mask = attr[0].squeeze(0)
+            node_feat_mask = attr[0].squeeze(0)
             edge_mask = None

         if self.mask_type == "edge":
-            node_mask = None
+            node_feat_mask = None
             edge_mask = attr[0]

         if self.mask_type == "node_and_edge":
-            node_mask = attr[0].squeeze(0)
+            node_feat_mask = attr[0].squeeze(0)
             edge_mask = attr[1]
         else:
             raise ValueError

         edge_feat_mask = None
-        node_feat_mask = None
+        node_mask = None

         return node_mask, edge_mask, node_feat_mask, edge_feat_mask

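Note on the hunk above: the three mask_type branches are independent `if` statements, so when mask_type is "node" or "edge" the trailing `else` (which belongs only to the last `if`) still fires and raises ValueError. A minimal sketch of the presumably intended control flow (illustrative only, not part of this commit):

    # Hypothetical restructuring: chaining the branches with elif keeps the
    # final else as a true fallback for unknown mask types.
    def _parse_attr(self, attr):
        if self.mask_type == "node":
            node_feat_mask = attr[0].squeeze(0)
            edge_mask = None
        elif self.mask_type == "edge":
            node_feat_mask = None
            edge_mask = attr[0]
        elif self.mask_type == "node_and_edge":
            node_feat_mask = attr[0].squeeze(0)
            edge_mask = attr[1]
        else:
            raise ValueError(f"Invalid mask type: {self.mask_type}")
        # node_mask and edge_feat_mask are not produced by this parser.
        return None, edge_mask, node_feat_mask, None
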
@@ -238,6 +238,7 @@ class CaptumWrapper(ExplainerAlgorithm):
         return Explanation(
             x=x,
             edge_index=edge_index,
+            y=target,
             edge_mask=edge_mask,
             node_mask=node_mask,
             node_feat_mask=node_feat_mask,

@@ -14,7 +14,7 @@ from graphxai.explainers.pgm_explainer import PGMExplainer
 from graphxai.explainers.random import RandomExplainer
 from graphxai.explainers.subgraphx import SubgraphX
 from torch import Tensor
-from torch.nn import CrossEntropyLoss, KLDivLoss, MSELoss
+from torch.nn import CrossEntropyLoss, MSELoss
 from torch_geometric.data import Data
 from torch_geometric.explain import Explanation
 from torch_geometric.explain.algorithm.base import ExplainerAlgorithm

@@ -260,6 +260,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
         return Explanation(
             x=x,
             edge_index=edge_index,
+            y=target,
             edge_mask=edge_mask,
             node_mask=node_mask,
             node_feat_mask=node_feat_mask,

explaining_framework/metric/accuracy.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import sklearn.metrics
+import torch
+
+from base import Metric
+
+
+class Accuracy(Metric):
+    def __init__(name: str):
+        super().__init__(name=name, model=None)
+        self.authorized_metric = [
+            "precision_score",
+            "precision_score",
+            "jaccard_score",
+            "roc_auc_score",
+            "f1_score",
+            "accuracy_score",
+        ]
+
+        self.metric = self.load_metric(name)
+
+    def load_metric(name):
+        if name in self.authorized_metric:
+            self.metric = eval("sklearn.metric.{name}")
+        else:
+            raise ValueError(f"{name} is not supported")
+
+    def forward(self, mask, target: Tensor) -> float:
+        if mask.type() == torch.bool and target.type() == torch.bool:
+            return self.metric(y_pred=mask, y_true=target)
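The new Accuracy class cannot run as committed: `__init__` and `load_metric` omit `self`, `eval("sklearn.metric.{name}")` is not an f-string and the module is `sklearn.metrics`, `Tensor` is never imported, and `Tensor.type()` returns a string rather than a dtype. A hedged, self-contained sketch of the same idea (names follow the diff, but this is not the committed code):

    import sklearn.metrics
    import torch
    from torch import Tensor

    class Accuracy:
        # Sketch only; the commit derives this from its own Metric base class.
        AUTHORIZED = ["precision_score", "jaccard_score", "roc_auc_score",
                      "f1_score", "accuracy_score"]

        def __init__(self, name: str):
            self.name = name
            self.metric = self.load_metric(name)

        def load_metric(self, name: str):
            if name not in self.AUTHORIZED:
                raise ValueError(f"{name} is not supported")
            # Resolve the sklearn scorer by attribute lookup instead of eval().
            return getattr(sklearn.metrics, name)

        def forward(self, mask: Tensor, target: Tensor) -> float:
            if mask.dtype == torch.bool and target.dtype == torch.bool:
                return float(self.metric(y_true=target.cpu().numpy(),
                                          y_pred=mask.cpu().numpy()))
            raise TypeError("mask and target are expected to be boolean tensors")
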

@@ -7,6 +7,7 @@ class Metric(ABC):
         self.model = model
         if is_model_needed and model is None:
             raise ValueError(f"{self.name} needs model to perform measurements")
+        self.authorized_metric = None

     def is_model_needed(self):
         if "fidelity" in self.name:
@@ -14,6 +15,31 @@ class Metric(ABC):
         else:
             return False

+    @abstractmethod
+    def load_metric(name: str):
+        pass
+
     @abstractmethod
     def __call__(self, exp: Explanation, **kwargs) -> float:
         pass
+
+    def get_prediction(self, *args, **kwargs) -> torch.Tensor:
+        r"""Returns the prediction of the model on the input graph.
+        If the model mode is :obj:`"regression"`, the prediction is returned as
+        a scalar value.
+        If the model mode :obj:`"classification"`, the prediction is returned
+        as the predicted class label.
+        Args:
+            *args: Arguments passed to the model.
+            **kwargs (optional): Additional keyword arguments passed to the
+                model.
+        """
+        training = self.model.training
+        self.model.eval()
+
+        with torch.no_grad():
+            out = self.model(*args, **kwargs)
+
+        self.model.train(training)
+
+        return out

explaining_framework/metric/fidelity.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import torch
+from torch.nn import KLDivLoss, Softmax
+
+from base import Metric
+
+
+class Fidelity(Metric):
+    def __init__(name: str, model: torch.nn.Module, mask_type: str):
+        super().__init__(name=name, model=model)
+        self.authorized_metric = [
+            "fidelity_plus",
+            "fidelity_minus",
+            "fidelity_plus_prob",
+            "fidelity_minus_prob",
+            "fidelity_plus_model",
+            "fidelity_minus_model",
+            "fidelity_plus_prob_model",
+            "fidelity_minus_prob_model",
+            "infidelity_KL",
+        ]
+
+        self.metric = self.load_metric(name)
+        self.exp_sub = None
+        self.exp_sub_c = None
+        self.s_exp_sub = None
+        self.s_exp_sub_c = None
+        self.s_initial_data = None
+
+    def _fidelity_plus(self, exp) -> float:
+        if any(
+            [
+                attr is None
+                for attr in [
+                    self.exp_sub,
+                    self.exp_sub_c,
+                    self.s_exp_sub,
+                    self.s_exp_sub_c,
+                    self.s_initial_data,
+                ]
+            ]
+        ):
+            self.score(exp)
+        else:
+            fid = self.s_initial_data - self.s_exp_sub_c
+
+    def score(self, exp):
+        self.exp_sub = exp.get_explanation_subgraph()
+        self.exp_sub_c = exp.get_complement_subgraph()
+        self.s_exp_sub = self.get_prediction(self.exp_sub)
+        self.s_exp_sub_c = self.get_prediction(self.exp_sub_c)
+        self.s_initial_data = self.get_prediction(exp)
+
+    def load_metric(name):
+        if name in self.authorized_metric:
+            if name == "fidelity_plus":
+
+                self.metric = eval("sklearn.metric.{name}")
+        else:
+            raise ValueError(f"{name} is not supported")
+
+    def compute(self, mask, target: Tensor) -> float:
+        if mask.type() == torch.bool and target.type() == torch.bool:
+            return self.metric(y_pred=mask, y_true=target)
+
+    def __call__(self, exp: Explanation):
+        pass
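In fidelity.py, `_fidelity_plus` computes `fid` but never returns it, and `load_metric` again resolves the metric through `eval("sklearn.metric.{name}")`, which has no sklearn counterpart. A hedged sketch of the fidelity+ idea the class appears to target (how much the prediction degrades once the explanation is removed), assuming `exp` exposes the subgraph helpers used in the diff and `predict` maps a graph to class probabilities:

    import torch

    def fidelity_plus(exp, predict) -> float:
        # predict: assumed callable mapping a Data-like object to class probabilities.
        exp_sub_c = exp.get_complement_subgraph()  # input with the explanation removed
        s_initial = predict(exp)                   # prediction on the original input
        s_sub_c = predict(exp_sub_c)               # prediction without the explanation
        # Fidelity+ is large when removing the explanation changes the prediction a lot.
        return float((s_initial - s_sub_c).abs().mean())
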

explaining_framework/metric/robust.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+class Attack(Metric):
+
+
+name in ['gaussian noise attack', 'edge perturbation attack', 'pgm', 'fgsd']:wq
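robust.py is committed as a stub: the trailing `:wq` is a leftover editor command, and the last line only lists the intended attack names. A purely illustrative skeleton of how that list could follow the same Metric pattern as Accuracy and Fidelity (an assumption, not the committed code):

    from base import Metric

    class Attack(Metric):
        def __init__(self, name: str, model=None):
            super().__init__(name=name, model=model)
            self.authorized_metric = [
                "gaussian noise attack",
                "edge perturbation attack",
                "pgm",
                "fgsd",
            ]
            if name not in self.authorized_metric:
                raise ValueError(f"{name} is not supported")
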

@@ -1,3 +0,0 @@
-class BaseExplaining(object):
-    def __init__(self,model,explainer_name:wq
-

@@ -1 +0,0 @@
-from torch_geometric.nn.models.captum import CaptumModel

explaining_framework/utils/explanation_adjust.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import copy
+
+from torch import FloatTensor
+from torch.nn import ReLU
+
+
+def relu_mask(explanation: Explanation) -> Explanation:
+    relu = ReLU()
+    explanation_store = explanation._store
+    raw_data = copy.copy(explanation._store)
+    for k, v in explanation_store.items():
+        if "mask" in k:
+            explanation_store[k] = relu(v)
+    explanation.__setattr__("raw_explanation", raw_data)
+    explanation.__setattr__("raw_explanation_transform", "relu")
+    return explanation
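explanation_adjust.py annotates `relu_mask` with `Explanation` but never imports it; the helper clamps negative mask entries to zero and stashes the untouched masks under `raw_explanation`. A small hedged usage sketch, assuming the `torch_geometric.explain.Explanation` API and the `relu_mask` defined above:

    import torch
    from torch_geometric.explain import Explanation

    exp = Explanation(
        x=torch.randn(4, 3),
        edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
        node_mask=torch.tensor([0.4, -0.2, 0.7, -0.1]),
    )
    exp = relu_mask(exp)          # negative mask entries are clamped to 0.0
    print(exp.node_mask)          # tensor([0.4000, 0.0000, 0.7000, 0.0000])
    print(exp.raw_explanation)    # shallow copy of the pre-ReLU store, kept for reference
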

explaining_framework/utils/explanation_threshold.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+import copy
+from typing import Dict, List, Optional, Union
+
+import torch
+from torch import Tensor
+from torch_geometric.explain.config import ThresholdConfig, ThresholdType
+from torch_geometric.explain.explanation import Explanation

utils/stat.py (deleted, 125 lines)
@@ -1,125 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
-import logging
-import multiprocessing as mp
-import os
-import threading
-import time
-import types
-from inspect import getmembers, isfunction, signature
-
-import networkx as nx
-from torch_geometric.data import Data
-from torch_geometric.utils import to_networkx
-
-with open("mapping_nx.txt", "r") as file:
-    BLACK_LIST = [line.rstrip() for line in file]
-
-
-class GraphStat(object):
-    def __init__(self):
-        self.maps = {
-            "networkx": self.available_map_networkx(),
-            "torch_geometric": self.available_map_torch_geometric(),
-        }
-
-    def available_map_networkx(self):
-        functions_list = getmembers(nx.algorithms, isfunction)
-        MANUALLY_ADDED = [
-            "algebraic_connectivity",
-            "adjacency_spectrum",
-            "degree",
-            "density",
-            "laplacian_spectrum",
-            "normalized_laplacian_spectrum",
-            "number_of_selfloops",
-            "number_of_edges",
-            "number_of_nodes",
-        ]
-        MANUALLY_ADDED_LIST = [
-            item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
-        ]
-        functions_list = functions_list + MANUALLY_ADDED_LIST
-        maps = {}
-        for func in functions_list:
-            name, f = func
-            if (
-                name in BLACK_LIST
-                or name == "recursive_simple_cycles"
-                or "triad" in name
-                or "weisfeiler" in name
-                or "dfs" in name
-                or "trophic" in name
-                or "recursive" in name
-                or "scipy" in name
-                or "numpy" in name
-                or "sigma" in name
-                or "omega" in name
-                or "all_" in name
-            ):
-                continue
-            else:
-                maps[name] = f
-        return maps
-
-    def available_map_torch_geometric(self):
-        names = [
-            "num_nodes",
-            "num_edges",
-            "has_self_loops",
-            "has_isolated_nodes",
-            "num_nodes_features",
-            "y",
-        ]
-        maps = {
-            name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
-            for name in names
-        }
-        return maps
-
-    def __call__(self, data):
-        data_ = data.__copy__()
-        stats = {}
-        for k, v in self.maps.items():
-            if k == "networkx":
-                _data_ = to_networkx(data)
-                _data_ = _data_.to_undirected()
-            elif k == "torch_geometric":
-                _data_ = data.__copy__()
-            for name, f in v.items():
-                if f is None:
-                    stats[name] = None
-                    continue
-                else:
-                    try:
-                        t0 = time.time()
-                        val = f(_data_)
-                        t1 = time.time()
-                        delta = t1 - t0
-                    except Exception as e:
-                        print(name, e)
-                        with open(f"{name}.txt", "w") as f:
-                            f.write(str(e))
-                    # print(name, round(delta, 4))
-                    # if callable(val) and k == "torch_geometric":
-                    #     val = val()
-                    # if isinstance(val, types.GeneratorType):
-                    #     val = list(val)
-                    # stats[name] = val
-        return stats
-
-
-from torch_geometric.datasets import KarateClub, Planetoid
-
-d = Planetoid(root="/tmp/", name="Cora")
-# d = KarateClub()
-a = d[0]
-st = GraphStat()
-stat = st(a)
-for k, v in stat.items():
-    print("---------")
-    print("Name:", k)
-    print("Type:", type(v))
-    print("Val:", v)
-    print("---------")