Merge branch 'master' of gitlab.xlim.fr:araison/explaining_framework
This commit is contained in:
commit
b73f087a6a
43 changed files with 428 additions and 0 deletions
0
explaining_framework/__init__.py
Normal file
0
explaining_framework/__init__.py
Normal file
0
explaining_framework/explainers/__init__.py
Normal file
0
explaining_framework/explainers/__init__.py
Normal file
Binary file not shown.
144
explaining_framework/explainers/algorithms/deconvolution.py
Normal file
144
explaining_framework/explainers/algorithms/deconvolution.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||||
|
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||||
|
GuidedGradCam, InputXGradient, IntegratedGradient,
|
||||||
|
Lime, Occlusion, Saliency)
|
||||||
|
from torch import Tensor
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explanation
|
||||||
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
|
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||||
|
ModelConfig, ModelMode,
|
||||||
|
ModelTaskLevel)
|
||||||
|
from torch_geometric.nn import GCNConv, global_mean_pooling
|
||||||
|
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
|
||||||
|
to_captum_input, to_captum_model)
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class FromCaptum(ExplainerAlgorithm):
|
||||||
|
def __init__(captum_method, mask_type: str = "node", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.captum_model = captum_model
|
||||||
|
self.mask_type = mask_type
|
||||||
|
_raise_on_invalid_mask_type(mask_type)
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def supports(self) -> bool:
|
||||||
|
task_level = self.model_config.task_level
|
||||||
|
if task_level not in [ModelTaskLevel.graph]:
|
||||||
|
logging.error(f"Task level '{task_level.value}' not supported")
|
||||||
|
return False
|
||||||
|
|
||||||
|
edge_mask_type = self.explainer_config.edge_mask_type
|
||||||
|
if edge_mask_type not in [MaskType.object, None]:
|
||||||
|
logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported")
|
||||||
|
return False
|
||||||
|
|
||||||
|
node_mask_type = self.explainer_config.node_mask_type
|
||||||
|
if node_mask_type not in [MaskType.attributes]:
|
||||||
|
logging.error(f"Node mask type '{node_mask_type.value}' not " f"supported.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def attr_to_tuple(self, attr, mask_type):
|
||||||
|
_raise_on_invalid_mask_type(mask_type)
|
||||||
|
if "node" == mask_type:
|
||||||
|
node_mask = attr
|
||||||
|
edge_mask = None
|
||||||
|
|
||||||
|
if "edge" == mask_type:
|
||||||
|
node_mask = None
|
||||||
|
edge_mask = attr
|
||||||
|
|
||||||
|
if "node_and_mask" == mask_type:
|
||||||
|
node_mask = attr[0]
|
||||||
|
edge_mask = attr[1]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
target: None,
|
||||||
|
) -> Explanation:
|
||||||
|
|
||||||
|
converted_model = to_captum_model(
|
||||||
|
model,
|
||||||
|
mask_type=self.mask_type,
|
||||||
|
output_idx=target,
|
||||||
|
)
|
||||||
|
attrib = self.captum_model(converted_model)
|
||||||
|
|
||||||
|
inputs, additional_forward_args = to_captum_input(
|
||||||
|
x, edge_index, mask_type=self.mask_type
|
||||||
|
)
|
||||||
|
attr = attrib.attribute(
|
||||||
|
inputs,
|
||||||
|
target=target,
|
||||||
|
additional_forward_args=additional_forward_args,
|
||||||
|
**self.kwargs,
|
||||||
|
)
|
||||||
|
node_mask, edge_mask = self.attr_to_tuple(attr, self.mask_type)
|
||||||
|
|
||||||
|
return Explanation(
|
||||||
|
x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if "__name__" == __main__:
|
||||||
|
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
|
||||||
|
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
|
||||||
|
y = torch.tensor([1], dtype=torch.long)
|
||||||
|
|
||||||
|
data = Data(x=x, edge_index=edge_index, y=y)
|
||||||
|
batch = Batch().from_data_list([data])
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.dim_in = dim_in
|
||||||
|
self.dim_out = dim_out
|
||||||
|
self.conv = GCNConv(dim_in, dim_out)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
|
||||||
|
x, edge_index = self.conv(x, edge_index)
|
||||||
|
x = global_mean_pooling(x, edge_index, batch)
|
||||||
|
return x
|
||||||
|
|
||||||
|
model = Model(1,2)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
||||||
|
|
||||||
|
for epoch in range(1, 2):
|
||||||
|
model.train()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
log_logits = model(batch)
|
||||||
|
loss = F.mse_loss(out,torch.ones(2))
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
expla = [LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||||
|
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||||
|
GuidedGradCam, InputXGradient, IntegratedGradient,
|
||||||
|
Lime, Occlusion, Saliency]
|
||||||
|
model = model.eval()
|
||||||
|
for captum_exp in expla:
|
||||||
|
try:
|
||||||
|
exp = FromCaptum(captum_exp)
|
||||||
|
attr = exp.forward(model,batch.x,batch.edge_index)
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
0
explaining_framework/explainers/algorithms/deeplift.py
Normal file
0
explaining_framework/explainers/algorithms/deeplift.py
Normal file
0
explaining_framework/explainers/algorithms/eixgnn.py
Normal file
0
explaining_framework/explainers/algorithms/eixgnn.py
Normal file
0
explaining_framework/explainers/algorithms/graphsvx.py
Normal file
0
explaining_framework/explainers/algorithms/graphsvx.py
Normal file
0
explaining_framework/explainers/algorithms/lime.py
Normal file
0
explaining_framework/explainers/algorithms/lime.py
Normal file
0
explaining_framework/explainers/algorithms/lrp.py
Normal file
0
explaining_framework/explainers/algorithms/lrp.py
Normal file
0
explaining_framework/explainers/algorithms/occlusion.py
Normal file
0
explaining_framework/explainers/algorithms/occlusion.py
Normal file
0
explaining_framework/explainers/algorithms/random.py
Normal file
0
explaining_framework/explainers/algorithms/random.py
Normal file
0
explaining_framework/explainers/algorithms/saliency.py
Normal file
0
explaining_framework/explainers/algorithms/saliency.py
Normal file
0
explaining_framework/explainers/algorithms/subgraphx.py
Normal file
0
explaining_framework/explainers/algorithms/subgraphx.py
Normal file
0
explaining_framework/explainers/wrappers/__init__.py
Normal file
0
explaining_framework/explainers/wrappers/__init__.py
Normal file
0
explaining_framework/explainers/wrappers/captum.py
Normal file
0
explaining_framework/explainers/wrappers/captum.py
Normal file
0
explaining_framework/explainers/wrappers/dig.py
Normal file
0
explaining_framework/explainers/wrappers/dig.py
Normal file
0
explaining_framework/explainers/wrappers/graphframex.py
Normal file
0
explaining_framework/explainers/wrappers/graphframex.py
Normal file
0
explaining_framework/explainers/wrappers/graphxai.py
Normal file
0
explaining_framework/explainers/wrappers/graphxai.py
Normal file
0
explaining_framework/metric/__init__.py
Normal file
0
explaining_framework/metric/__init__.py
Normal file
24
explaining_framework/metric/dataset.py
Normal file
24
explaining_framework/metric/dataset.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from torch_geometric.data import Data, Batch, Dataset
|
||||||
|
from torch_geometric.utils import to_networkx
|
||||||
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
class DatasetStatistic(object):
|
||||||
|
def __init__(self, dataset):
|
||||||
|
if not isinstance(dataset, Dataset):
|
||||||
|
raise ValueError(f'{dataset} needs to be an PyG dataset object')
|
||||||
|
self.dataset = dataset
|
||||||
|
|
||||||
|
if len(dataset) == 0:
|
||||||
|
self.task = 'node'
|
||||||
|
else:
|
||||||
|
self.task = 'graph'
|
||||||
|
self.study = None
|
||||||
|
|
||||||
|
def save_csv(self, path):
|
||||||
|
|
||||||
|
|
28
explaining_framework/metric/test.py
Normal file
28
explaining_framework/metric/test.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from inspect import getmembers, isfunction, signature
|
||||||
|
import networkx as nx
|
||||||
|
from docstring_parser import parse
|
||||||
|
|
||||||
|
functions_list = getmembers(nx, isfunction)
|
||||||
|
rt_type =[]
|
||||||
|
for func in functions_list:
|
||||||
|
name, f = func
|
||||||
|
docstring = parse(f.__doc__)
|
||||||
|
try:
|
||||||
|
rt = docstring.returns.type_name
|
||||||
|
rt_type.append(rt)
|
||||||
|
if rt == 'int' or rt == 'float' or rt=='bool' or rt=='boolean' or rt=='dictionary':
|
||||||
|
print(f'{name} : {rt}')
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print('int', rt_type.count('int'))
|
||||||
|
print('float', rt_type.count('float'))
|
||||||
|
print('bool', rt_type.count('bool'))
|
||||||
|
print('boolean', rt_type.count('boolean'))
|
||||||
|
print('dictionary', rt_type.count('dictionary'))
|
||||||
|
print('length func',len(functions_list))
|
||||||
|
print('length rt',len(rt_type))
|
||||||
|
|
0
explaining_framework/stats/__init__.py
Normal file
0
explaining_framework/stats/__init__.py
Normal file
0
explaining_framework/stats/dataset/__init__.py
Normal file
0
explaining_framework/stats/dataset/__init__.py
Normal file
0
explaining_framework/stats/graph/__init__.py
Normal file
0
explaining_framework/stats/graph/__init__.py
Normal file
228
explaining_framework/stats/graph/graph_stat.py
Normal file
228
explaining_framework/stats/graph/graph_stat.py
Normal file
|
@ -0,0 +1,228 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import types
|
||||||
|
from inspect import getmembers, isfunction, signature
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.utils import to_networkx
|
||||||
|
|
||||||
|
__maps__ = [
|
||||||
|
"adamic_adar_index",
|
||||||
|
"approximate_current_flow_betweenness_centrality",
|
||||||
|
"average_clustering",
|
||||||
|
"average_degree_connectivity",
|
||||||
|
"average_neighbor_degree",
|
||||||
|
"average_node_connectivity",
|
||||||
|
"average_shortest_path_length",
|
||||||
|
"betweenness_centrality",
|
||||||
|
"bridges",
|
||||||
|
"closeness_centrality",
|
||||||
|
"clustering",
|
||||||
|
"cn_soundarajan_hopcroft",
|
||||||
|
"common_neighbor_centrality",
|
||||||
|
"communicability",
|
||||||
|
"communicability_betweenness_centrality",
|
||||||
|
"communicability_exp",
|
||||||
|
"connected_components",
|
||||||
|
"constraint",
|
||||||
|
"core_number",
|
||||||
|
"current_flow_betweenness_centrality",
|
||||||
|
"current_flow_closeness_centrality",
|
||||||
|
"cycle_basis",
|
||||||
|
"degree_assortativity_coefficient",
|
||||||
|
"degree_centrality",
|
||||||
|
"degree_mixing_matrix",
|
||||||
|
"degree_pearson_correlation_coefficient",
|
||||||
|
"diameter",
|
||||||
|
"dispersion",
|
||||||
|
"dominating_set",
|
||||||
|
"eccentricity",
|
||||||
|
"effective_size",
|
||||||
|
"eigenvector_centrality",
|
||||||
|
"estrada_index",
|
||||||
|
"generalized_degree",
|
||||||
|
"global_efficiency",
|
||||||
|
"global_reaching_centrality",
|
||||||
|
"graph_clique_number",
|
||||||
|
"harmonic_centrality",
|
||||||
|
"has_bridges",
|
||||||
|
"has_eulerian_path",
|
||||||
|
"hits",
|
||||||
|
"information_centrality",
|
||||||
|
"is_at_free",
|
||||||
|
"is_biconnected",
|
||||||
|
"is_bipartite",
|
||||||
|
"is_chordal",
|
||||||
|
"is_connected",
|
||||||
|
"is_directed_acyclic_graph",
|
||||||
|
"is_distance_regular",
|
||||||
|
"is_eulerian",
|
||||||
|
"is_forest",
|
||||||
|
"is_graphical",
|
||||||
|
"is_multigraphical",
|
||||||
|
"is_planar",
|
||||||
|
"is_pseudographical",
|
||||||
|
"is_regular",
|
||||||
|
"is_semieulerian",
|
||||||
|
"is_strongly_regular",
|
||||||
|
"is_tree",
|
||||||
|
"is_valid_degree_sequence_erdos_gallai",
|
||||||
|
"is_valid_degree_sequence_havel_hakimi",
|
||||||
|
"jaccard_coefficient",
|
||||||
|
"k_components",
|
||||||
|
"k_core",
|
||||||
|
"katz_centrality",
|
||||||
|
"load_centrality",
|
||||||
|
"minimum_cycle_basis",
|
||||||
|
"minimum_edge_cut",
|
||||||
|
"minimum_node_cut",
|
||||||
|
"node_clique_number",
|
||||||
|
"node_connectivity",
|
||||||
|
"node_degree_xy",
|
||||||
|
"number_connected_components",
|
||||||
|
"number_of_cliques",
|
||||||
|
"number_of_isolates",
|
||||||
|
"pagerank",
|
||||||
|
"periphery",
|
||||||
|
"preferential_attachment",
|
||||||
|
"reciprocity",
|
||||||
|
"resource_allocation_index",
|
||||||
|
"rich_club_coefficient",
|
||||||
|
"square_clustering",
|
||||||
|
"stoer_wagner",
|
||||||
|
"subgraph_centrality",
|
||||||
|
"subgraph_centrality_exp",
|
||||||
|
"transitivity",
|
||||||
|
"triangles",
|
||||||
|
"voterank",
|
||||||
|
"wiener_index",
|
||||||
|
"algebraic_connectivity",
|
||||||
|
"degree",
|
||||||
|
"density",
|
||||||
|
"normalized_laplacian_spectrum",
|
||||||
|
"number_of_edges",
|
||||||
|
"number_of_nodes",
|
||||||
|
"number_of_selfloops",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphStat(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.maps = {
|
||||||
|
"networkx": self.available_map_networkx(),
|
||||||
|
"torch_geometric": self.available_map_torch_geometric(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def available_map_networkx(self):
|
||||||
|
functions_list = getmembers(nx, isfunction)
|
||||||
|
|
||||||
|
maps = {}
|
||||||
|
for func in functions_list:
|
||||||
|
name, f = func
|
||||||
|
if name in __maps__:
|
||||||
|
maps[name] = f
|
||||||
|
return maps
|
||||||
|
|
||||||
|
def available_map_torch_geometric(self):
|
||||||
|
names = [
|
||||||
|
"num_nodes",
|
||||||
|
"num_edges",
|
||||||
|
"has_self_loops",
|
||||||
|
"has_isolated_nodes",
|
||||||
|
"num_nodes_features",
|
||||||
|
"y",
|
||||||
|
]
|
||||||
|
maps = {
|
||||||
|
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||||
|
for name in names
|
||||||
|
}
|
||||||
|
return maps
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
data_ = data.__copy__()
|
||||||
|
datahash = hash(data.__repr__)
|
||||||
|
stats = {}
|
||||||
|
for k, v in self.maps.items():
|
||||||
|
if k == "networkx":
|
||||||
|
_data_ = to_networkx(data)
|
||||||
|
_data_ = _data_.to_undirected()
|
||||||
|
elif k == "torch_geometric":
|
||||||
|
_data_ = data.__copy__()
|
||||||
|
for name, func in v.items():
|
||||||
|
try:
|
||||||
|
val = func(_data_)
|
||||||
|
except:
|
||||||
|
val = None
|
||||||
|
if callable(val):
|
||||||
|
val = val()
|
||||||
|
if isinstance(val, types.GeneratorType):
|
||||||
|
try:
|
||||||
|
val = dict(val)
|
||||||
|
except:
|
||||||
|
val = None
|
||||||
|
if name == "hits":
|
||||||
|
val = val[0]
|
||||||
|
if name == "k_components":
|
||||||
|
for key, value in val.items():
|
||||||
|
val[key] = list(val[key][0])
|
||||||
|
|
||||||
|
stats[name] = self.convert(val)
|
||||||
|
stats["hash"] = datahash
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def convert(self, val, K=4):
|
||||||
|
|
||||||
|
if type(val) == set:
|
||||||
|
val = list(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, torch.Tensor):
|
||||||
|
val = val.cpu().numpy().tolist()
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, list):
|
||||||
|
for ind in range(len(val)):
|
||||||
|
item = val[ind]
|
||||||
|
if isinstance(item, list):
|
||||||
|
for subind in range(len(item)):
|
||||||
|
subitem = item[subind]
|
||||||
|
if type(subitem) == float:
|
||||||
|
item[subind] = round(subitem, K)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if type(item) == float:
|
||||||
|
val[ind] = round(item, K)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, dict):
|
||||||
|
for k, v in val.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for k1, v1 in v.items():
|
||||||
|
if type(v1) == float:
|
||||||
|
v[k1] = round(v1, K)
|
||||||
|
if type(v1) == set:
|
||||||
|
v[k1] = list(v1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if type(v) == float:
|
||||||
|
val[k] = round(v, K)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, nx.classes.reportviews.DegreeView):
|
||||||
|
val = np.array(val).tolist()
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, nx.Graph):
|
||||||
|
val = nx.node_link_data(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, np.ndarray):
|
||||||
|
val = np.around(val, decimals=K)
|
||||||
|
return val.tolist()
|
||||||
|
|
||||||
|
else:
|
||||||
|
return val
|
0
explaining_framework/utils/__init__.py
Normal file
0
explaining_framework/utils/__init__.py
Normal file
0
explaining_framework/utils/explainer/__init__.py
Normal file
0
explaining_framework/utils/explainer/__init__.py
Normal file
3
explaining_framework/utils/explainer/base.py
Normal file
3
explaining_framework/utils/explainer/base.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
class BaseExplaining(object):
|
||||||
|
def __init__(self,model,explainer_name:wq
|
||||||
|
|
1
explaining_framework/utils/explainer/from_captum.py
Normal file
1
explaining_framework/utils/explainer/from_captum.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from torch_geometric.nn.models.captum import CaptumModel
|
0
explaining_framework/utils/graphgym/__init__.py
Normal file
0
explaining_framework/utils/graphgym/__init__.py
Normal file
0
explaining_framework/utils/visualizer/__init__.py
Normal file
0
explaining_framework/utils/visualizer/__init__.py
Normal file
Loading…
Add table
Reference in a new issue