New fixes and new features
This commit is contained in:
parent
ea0a5dd86e
commit
45730328a7
@ -188,21 +188,21 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||
|
||||
def _parse_attr(self, attr):
|
||||
if self.mask_type == "node":
|
||||
node_mask = attr[0].squeeze(0)
|
||||
node_feat_mask = attr[0].squeeze(0)
|
||||
edge_mask = None
|
||||
|
||||
if self.mask_type == "edge":
|
||||
node_mask = None
|
||||
node_feat_mask = None
|
||||
edge_mask = attr[0]
|
||||
|
||||
if self.mask_type == "node_and_edge":
|
||||
node_mask = attr[0].squeeze(0)
|
||||
node_feat_mask = attr[0].squeeze(0)
|
||||
edge_mask = attr[1]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
edge_feat_mask = None
|
||||
node_feat_mask = None
|
||||
node_mask = None
|
||||
|
||||
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||
|
||||
@ -238,6 +238,7 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||
return Explanation(
|
||||
x=x,
|
||||
edge_index=edge_index,
|
||||
y=target,
|
||||
edge_mask=edge_mask,
|
||||
node_mask=node_mask,
|
||||
node_feat_mask=node_feat_mask,
|
||||
|
@ -14,7 +14,7 @@ from graphxai.explainers.pgm_explainer import PGMExplainer
|
||||
from graphxai.explainers.random import RandomExplainer
|
||||
from graphxai.explainers.subgraphx import SubgraphX
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, KLDivLoss, MSELoss
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.explain import Explanation
|
||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||
@ -260,6 +260,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||
return Explanation(
|
||||
x=x,
|
||||
edge_index=edge_index,
|
||||
y=target,
|
||||
edge_mask=edge_mask,
|
||||
node_mask=node_mask,
|
||||
node_feat_mask=node_feat_mask,
|
||||
|
29
explaining_framework/metric/accuracy.py
Normal file
29
explaining_framework/metric/accuracy.py
Normal file
@ -0,0 +1,29 @@
|
||||
import sklearn.metrics
|
||||
import torch
|
||||
|
||||
from base import Metric
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
def __init__(name: str):
|
||||
super().__init__(name=name, model=None)
|
||||
self.authorized_metric = [
|
||||
"precision_score",
|
||||
"precision_score",
|
||||
"jaccard_score",
|
||||
"roc_auc_score",
|
||||
"f1_score",
|
||||
"accuracy_score",
|
||||
]
|
||||
|
||||
self.metric = self.load_metric(name)
|
||||
|
||||
def load_metric(name):
|
||||
if name in self.authorized_metric:
|
||||
self.metric = eval("sklearn.metric.{name}")
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
|
||||
def forward(self, mask, target: Tensor) -> float:
|
||||
if mask.type() == torch.bool and target.type() == torch.bool:
|
||||
return self.metric(y_pred=mask, y_true=target)
|
@ -7,6 +7,7 @@ class Metric(ABC):
|
||||
self.model = model
|
||||
if is_model_needed and model is None:
|
||||
raise ValueError(f"{self.name} needs model to perform measurements")
|
||||
self.authorized_metric = None
|
||||
|
||||
def is_model_needed(self):
|
||||
if "fidelity" in self.name:
|
||||
@ -14,6 +15,31 @@ class Metric(ABC):
|
||||
else:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def load_metric(name: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, exp: Explanation, **kwargs) -> float:
|
||||
pass
|
||||
|
||||
def get_prediction(self, *args, **kwargs) -> torch.Tensor:
|
||||
r"""Returns the prediction of the model on the input graph.
|
||||
If the model mode is :obj:`"regression"`, the prediction is returned as
|
||||
a scalar value.
|
||||
If the model mode :obj:`"classification"`, the prediction is returned
|
||||
as the predicted class label.
|
||||
Args:
|
||||
*args: Arguments passed to the model.
|
||||
**kwargs (optional): Additional keyword arguments passed to the
|
||||
model.
|
||||
"""
|
||||
training = self.model.training
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
out = self.model(*args, **kwargs)
|
||||
|
||||
self.model.train(training)
|
||||
|
||||
return out
|
||||
|
66
explaining_framework/metric/fidelity.py
Normal file
66
explaining_framework/metric/fidelity.py
Normal file
@ -0,0 +1,66 @@
|
||||
import torch
|
||||
from torch.nn import KLDivLoss, Softmax
|
||||
|
||||
from base import Metric
|
||||
|
||||
|
||||
class Fidelity(Metric):
|
||||
def __init__(name: str, model: torch.nn.Module, mask_type: str):
|
||||
super().__init__(name=name, model=model)
|
||||
self.authorized_metric = [
|
||||
"fidelity_plus",
|
||||
"fidelity_minus",
|
||||
"fidelity_plus_prob",
|
||||
"fidelity_minus_prob",
|
||||
"fidelity_plus_model",
|
||||
"fidelity_minus_model",
|
||||
"fidelity_plus_prob_model",
|
||||
"fidelity_minus_prob_model",
|
||||
"infidelity_KL",
|
||||
]
|
||||
|
||||
self.metric = self.load_metric(name)
|
||||
self.exp_sub = None
|
||||
self.exp_sub_c = None
|
||||
self.s_exp_sub = None
|
||||
self.s_exp_sub_c = None
|
||||
self.s_initial_data = None
|
||||
|
||||
def _fidelity_plus(self, exp) -> float:
|
||||
if any(
|
||||
[
|
||||
attr is None
|
||||
for attr in [
|
||||
self.exp_sub,
|
||||
self.exp_sub_c,
|
||||
self.s_exp_sub,
|
||||
self.s_exp_sub_c,
|
||||
self.s_initial_data,
|
||||
]
|
||||
]
|
||||
):
|
||||
self.score(exp)
|
||||
else:
|
||||
fid = self.s_initial_data - self.s_exp_sub_c
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
self.exp_sub_c = exp.get_complement_subgraph()
|
||||
self.s_exp_sub = self.get_prediction(self.exp_sub)
|
||||
self.s_exp_sub_c = self.get_prediction(self.exp_sub_c)
|
||||
self.s_initial_data = self.get_prediction(exp)
|
||||
|
||||
def load_metric(name):
|
||||
if name in self.authorized_metric:
|
||||
if name == "fidelity_plus":
|
||||
|
||||
self.metric = eval("sklearn.metric.{name}")
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
|
||||
def compute(self, mask, target: Tensor) -> float:
|
||||
if mask.type() == torch.bool and target.type() == torch.bool:
|
||||
return self.metric(y_pred=mask, y_true=target)
|
||||
|
||||
def __call__(self, exp: Explanation):
|
||||
pass
|
4
explaining_framework/metric/robust.py
Normal file
4
explaining_framework/metric/robust.py
Normal file
@ -0,0 +1,4 @@
|
||||
class Attack(Metric):
|
||||
|
||||
|
||||
name in ['gaussian noise attack', 'edge perturbation attack', 'pgm', 'fgsd']:wq
|
@ -1,3 +0,0 @@
|
||||
class BaseExplaining(object):
|
||||
def __init__(self,model,explainer_name:wq
|
||||
|
@ -1 +0,0 @@
|
||||
from torch_geometric.nn.models.captum import CaptumModel
|
16
explaining_framework/utils/explanation_adjust.py
Normal file
16
explaining_framework/utils/explanation_adjust.py
Normal file
@ -0,0 +1,16 @@
|
||||
import copy
|
||||
|
||||
from torch import FloatTensor
|
||||
from torch.nn import ReLU
|
||||
|
||||
|
||||
def relu_mask(explanation: Explanation) -> Explanation:
|
||||
relu = ReLU()
|
||||
explanation_store = explanation._store
|
||||
raw_data = copy.copy(explanation._store)
|
||||
for k, v in explanation_store.items():
|
||||
if "mask" in k:
|
||||
explanation_store[k] = relu(v)
|
||||
explanation.__setattr__("raw_explanation", raw_data)
|
||||
explanation.__setattr__("raw_explanation_transform", "relu")
|
||||
return explanation
|
7
explaining_framework/utils/explanation_threshold.py
Normal file
7
explaining_framework/utils/explanation_threshold.py
Normal file
@ -0,0 +1,7 @@
|
||||
import copy
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
|
||||
from torch_geometric.explain.explanation import Explanation
|
125
utils/stat.py
125
utils/stat.py
@ -1,125 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from inspect import getmembers, isfunction, signature
|
||||
|
||||
import networkx as nx
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.utils import to_networkx
|
||||
|
||||
with open("mapping_nx.txt", "r") as file:
|
||||
BLACK_LIST = [line.rstrip() for line in file]
|
||||
|
||||
|
||||
class GraphStat(object):
|
||||
def __init__(self):
|
||||
self.maps = {
|
||||
"networkx": self.available_map_networkx(),
|
||||
"torch_geometric": self.available_map_torch_geometric(),
|
||||
}
|
||||
|
||||
def available_map_networkx(self):
|
||||
functions_list = getmembers(nx.algorithms, isfunction)
|
||||
MANUALLY_ADDED = [
|
||||
"algebraic_connectivity",
|
||||
"adjacency_spectrum",
|
||||
"degree",
|
||||
"density",
|
||||
"laplacian_spectrum",
|
||||
"normalized_laplacian_spectrum",
|
||||
"number_of_selfloops",
|
||||
"number_of_edges",
|
||||
"number_of_nodes",
|
||||
]
|
||||
MANUALLY_ADDED_LIST = [
|
||||
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
|
||||
]
|
||||
functions_list = functions_list + MANUALLY_ADDED_LIST
|
||||
maps = {}
|
||||
for func in functions_list:
|
||||
name, f = func
|
||||
if (
|
||||
name in BLACK_LIST
|
||||
or name == "recursive_simple_cycles"
|
||||
or "triad" in name
|
||||
or "weisfeiler" in name
|
||||
or "dfs" in name
|
||||
or "trophic" in name
|
||||
or "recursive" in name
|
||||
or "scipy" in name
|
||||
or "numpy" in name
|
||||
or "sigma" in name
|
||||
or "omega" in name
|
||||
or "all_" in name
|
||||
):
|
||||
continue
|
||||
else:
|
||||
maps[name] = f
|
||||
return maps
|
||||
|
||||
def available_map_torch_geometric(self):
|
||||
names = [
|
||||
"num_nodes",
|
||||
"num_edges",
|
||||
"has_self_loops",
|
||||
"has_isolated_nodes",
|
||||
"num_nodes_features",
|
||||
"y",
|
||||
]
|
||||
maps = {
|
||||
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||
for name in names
|
||||
}
|
||||
return maps
|
||||
|
||||
def __call__(self, data):
|
||||
data_ = data.__copy__()
|
||||
stats = {}
|
||||
for k, v in self.maps.items():
|
||||
if k == "networkx":
|
||||
_data_ = to_networkx(data)
|
||||
_data_ = _data_.to_undirected()
|
||||
elif k == "torch_geometric":
|
||||
_data_ = data.__copy__()
|
||||
for name, f in v.items():
|
||||
if f is None:
|
||||
stats[name] = None
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
t0 = time.time()
|
||||
val = f(_data_)
|
||||
t1 = time.time()
|
||||
delta = t1 - t0
|
||||
except Exception as e:
|
||||
print(name, e)
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(str(e))
|
||||
# print(name, round(delta, 4))
|
||||
# if callable(val) and k == "torch_geometric":
|
||||
# val = val()
|
||||
# if isinstance(val, types.GeneratorType):
|
||||
# val = list(val)
|
||||
# stats[name] = val
|
||||
return stats
|
||||
|
||||
|
||||
from torch_geometric.datasets import KarateClub, Planetoid
|
||||
|
||||
d = Planetoid(root="/tmp/", name="Cora")
|
||||
# d = KarateClub()
|
||||
a = d[0]
|
||||
st = GraphStat()
|
||||
stat = st(a)
|
||||
for k, v in stat.items():
|
||||
print("---------")
|
||||
print("Name:", k)
|
||||
print("Type:", type(v))
|
||||
print("Val:", v)
|
||||
print("---------")
|
Loading…
Reference in New Issue
Block a user