Merge branch 'master' of gitlab.xlim.fr:araison/explaining_framework
This commit is contained in:
commit
b73f087a6a
Binary file not shown.
|
@ -0,0 +1,144 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||
GuidedGradCam, InputXGradient, IntegratedGradient,
|
||||
Lime, Occlusion, Saliency)
|
||||
from torch import Tensor
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explanation
|
||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||
ModelConfig, ModelMode,
|
||||
ModelTaskLevel)
|
||||
from torch_geometric.nn import GCNConv, global_mean_pooling
|
||||
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
|
||||
to_captum_input, to_captum_model)
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FromCaptum(ExplainerAlgorithm):
    """Expose a Captum attribution method as a PyG explainer algorithm.

    Args:
        captum_method: Captum attribution class (for example ``Saliency``).
            It is instantiated in :meth:`forward` with the Captum-wrapped
            model as its only argument.
        mask_type: Which inputs are attributed; checked with
            ``_raise_on_invalid_mask_type``.
        **kwargs: Extra keyword arguments forwarded verbatim to the
            attribution method's ``attribute`` call.
    """

    def __init__(self, captum_method, mask_type: str = "node", **kwargs):
        super().__init__()
        _raise_on_invalid_mask_type(mask_type)
        self.captum_method = captum_method
        # Alias kept because this attribute name was the one read by the
        # original ``forward`` implementation.
        self.captum_model = captum_method
        self.mask_type = mask_type
        self.kwargs = kwargs

    def supports(self) -> bool:
        """Return whether the configured task level and mask types are handled.

        Only graph-level tasks, ``MaskType.object`` (or no) edge masks and
        ``MaskType.attributes`` node masks are accepted; anything else is
        logged as an error and rejected.
        """
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [MaskType.attributes]:
            # ``node_mask_type`` may be ``None`` here, which has no ``.value``.
            shown = getattr(node_mask_type, "value", node_mask_type)
            logging.error(f"Node mask type '{shown}' not supported.")
            return False

        return True

    def attr_to_tuple(self, attr, mask_type):
        """Split Captum attributions into a ``(node_mask, edge_mask)`` pair.

        Args:
            attr: Value returned by the attribution method's ``attribute``
                call: either a single tensor or a tuple of tensors mirroring
                the structure of the inputs.
            mask_type: ``"node"``, ``"edge"`` or ``"node_and_edge"``.

        Returns:
            Tuple ``(node_mask, edge_mask)``; the entry that does not apply to
            ``mask_type`` is ``None``.

        Raises:
            ValueError: If ``attr`` does not hold enough entries for
                ``mask_type`` (invalid mask types are rejected by
                ``_raise_on_invalid_mask_type``).
        """
        _raise_on_invalid_mask_type(mask_type)
        # Normalise to a tuple so single-tensor and tuple results are handled
        # the same way.
        # NOTE(review): the attributions presumably keep the leading batch
        # dimension added by ``to_captum_input`` -- confirm whether it should
        # be squeezed before building the Explanation.
        parts = tuple(attr) if isinstance(attr, (tuple, list)) else (attr,)
        expected = 2 if mask_type == "node_and_edge" else 1
        if len(parts) < expected:
            raise ValueError(
                f"Expected {expected} attribution(s) for mask type "
                f"'{mask_type}', got {len(parts)}"
            )

        if mask_type == "node":
            return parts[0], None
        if mask_type == "edge":
            return None, parts[0]
        return parts[0], parts[1]

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target: Optional[Union[int, Tensor]] = None,
    ) -> Explanation:
        """Run the Captum method on ``model`` and wrap the result.

        Args:
            model: Model to explain; it is wrapped with ``to_captum_model``.
            x: Node feature tensor.
            edge_index: Edge index tensor.
            target: Passed both as ``output_idx`` to ``to_captum_model`` and
                as ``target`` to the attribution call.

        Returns:
            An ``Explanation`` carrying ``x``, ``edge_index`` and the computed
            node and/or edge mask.
        """
        converted_model = to_captum_model(
            model,
            mask_type=self.mask_type,
            output_idx=target,
        )
        attrib = self.captum_method(converted_model)

        inputs, additional_forward_args = to_captum_input(
            x, edge_index, mask_type=self.mask_type
        )
        attr = attrib.attribute(
            inputs,
            target=target,
            additional_forward_args=additional_forward_args,
            **self.kwargs,
        )
        node_mask, edge_mask = self.attr_to_tuple(attr, self.mask_type)

        return Explanation(
            x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: train a tiny graph-level model for one step, then try every
    # Captum method through FromCaptum and print whatever error it raises.

    # ``torch_geometric.nn`` exports the mean readout as ``global_mean_pool``;
    # it is imported here so the demo does not rely on another module-level
    # name for it.
    from torch_geometric.nn import global_mean_pool

    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
    y = torch.tensor([1], dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y)
    batch = Batch.from_data_list([data])

    class Model(torch.nn.Module):
        """One GCN layer followed by a mean readout over each graph."""

        def __init__(self, dim_in, dim_out):
            super().__init__()
            self.dim_in = dim_in
            self.dim_out = dim_out
            self.conv = GCNConv(dim_in, dim_out)

        def forward(self, x, edge_index, batch=None):
            # Takes ``(x, edge_index, ...)`` because the Captum wrapper built
            # by ``to_captum_model`` calls the model with those positional
            # arguments rather than with a ``Batch`` object.
            x = self.conv(x, edge_index)
            return global_mean_pool(x, batch)

    model = Model(1, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    for epoch in range(1, 2):
        model.train()
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        # Same all-ones regression target as before, shaped like the output.
        loss = F.mse_loss(out, torch.ones_like(out))
        loss.backward()
        optimizer.step()

    # NOTE(review): Captum names the class ``IntegratedGradients``; the name
    # used below follows this file's module-level import and has to change
    # together with it.
    expla = [
        LRP,
        DeepLift,
        DeepLiftShap,
        FeatureAblation,
        FeaturePermutation,
        GradientShap,
        GuidedBackprop,
        GuidedGradCam,
        InputXGradient,
        IntegratedGradient,
        Lime,
        Occlusion,
        Saliency,
    ]
    model = model.eval()
    for captum_exp in expla:
        try:
            exp = FromCaptum(captum_exp)
            # target=0 is the first graph / first output unit of this toy
            # setup -- TODO confirm this is the output meant to be explained.
            attr = exp.forward(model, batch.x, batch.edge_index, target=0)
        except Exception as e:
            # Best-effort sweep: report the failure and move to the next method.
            print(str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from torch_geometric.data import Data, Batch, Dataset
|
||||
from torch_geometric.utils import to_networkx
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
class DatasetStatistic(object):
    """Hold a PyG dataset together with the task level inferred from it.

    Args:
        dataset: A ``torch_geometric.data.Dataset`` instance.

    Raises:
        ValueError: If ``dataset`` is not a PyG ``Dataset``.
    """

    def __init__(self, dataset):
        if not isinstance(dataset, Dataset):
            raise ValueError(f'{dataset} needs to be an PyG dataset object')
        self.dataset = dataset

        # NOTE(review): an empty dataset is labelled as a node-level task and
        # everything else as graph-level; a single-graph dataset
        # (``len(dataset) == 1``) may be what was meant -- confirm.
        if len(dataset) == 0:
            self.task = 'node'
        else:
            self.task = 'graph'
        self.study = None

    def save_csv(self, path):
        """Write the collected statistics to ``path`` as CSV (not implemented).

        Raises:
            NotImplementedError: Always; the export was never written.
        """
        raise NotImplementedError("DatasetStatistic.save_csv is not implemented yet")
|
||||
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from inspect import getmembers, isfunction, signature
|
||||
import networkx as nx
|
||||
from docstring_parser import parse
|
||||
|
||||
# Survey of the return types documented by networkx's top-level functions:
# print every function whose docstring declares a simple return type, then
# print how often each of those types occurs.
_REPORTED_TYPES = ('int', 'float', 'bool', 'boolean', 'dictionary')

functions_list = getmembers(nx, isfunction)
rt_type = []
for name, f in functions_list:
    docstring = parse(f.__doc__)
    try:
        # ``returns`` is None when the docstring has no "Returns" section.
        rt = docstring.returns.type_name
    except AttributeError:
        continue
    rt_type.append(rt)
    if rt in _REPORTED_TYPES:
        print(f'{name} : {rt}')

for label in _REPORTED_TYPES:
    print(label, rt_type.count(label))
print('length func', len(functions_list))
print('length rt', len(rt_type))
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import types
|
||||
from inspect import getmembers, isfunction, signature
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.utils import to_networkx
|
||||
|
||||
# Names of the top-level networkx functions that GraphStat looks up with
# ``getmembers(nx, isfunction)``; functions not listed here are ignored.
__maps__ = [
    "adamic_adar_index",
    "approximate_current_flow_betweenness_centrality",
    "average_clustering",
    "average_degree_connectivity",
    "average_neighbor_degree",
    "average_node_connectivity",
    "average_shortest_path_length",
    "betweenness_centrality",
    "bridges",
    "closeness_centrality",
    "clustering",
    "cn_soundarajan_hopcroft",
    "common_neighbor_centrality",
    "communicability",
    "communicability_betweenness_centrality",
    "communicability_exp",
    "connected_components",
    "constraint",
    "core_number",
    "current_flow_betweenness_centrality",
    "current_flow_closeness_centrality",
    "cycle_basis",
    "degree_assortativity_coefficient",
    "degree_centrality",
    "degree_mixing_matrix",
    "degree_pearson_correlation_coefficient",
    "diameter",
    "dispersion",
    "dominating_set",
    "eccentricity",
    "effective_size",
    "eigenvector_centrality",
    "estrada_index",
    "generalized_degree",
    "global_efficiency",
    "global_reaching_centrality",
    "graph_clique_number",
    "harmonic_centrality",
    "has_bridges",
    "has_eulerian_path",
    "hits",
    "information_centrality",
    "is_at_free",
    "is_biconnected",
    "is_bipartite",
    "is_chordal",
    "is_connected",
    "is_directed_acyclic_graph",
    "is_distance_regular",
    "is_eulerian",
    "is_forest",
    "is_graphical",
    "is_multigraphical",
    "is_planar",
    "is_pseudographical",
    "is_regular",
    "is_semieulerian",
    "is_strongly_regular",
    "is_tree",
    "is_valid_degree_sequence_erdos_gallai",
    "is_valid_degree_sequence_havel_hakimi",
    "jaccard_coefficient",
    "k_components",
    "k_core",
    "katz_centrality",
    "load_centrality",
    "minimum_cycle_basis",
    "minimum_edge_cut",
    "minimum_node_cut",
    "node_clique_number",
    "node_connectivity",
    "node_degree_xy",
    "number_connected_components",
    "number_of_cliques",
    "number_of_isolates",
    "pagerank",
    "periphery",
    "preferential_attachment",
    "reciprocity",
    "resource_allocation_index",
    "rich_club_coefficient",
    "square_clustering",
    "stoer_wagner",
    "subgraph_centrality",
    "subgraph_centrality_exp",
    "transitivity",
    "triangles",
    "voterank",
    "wiener_index",
    "algebraic_connectivity",
    "degree",
    "density",
    "normalized_laplacian_spectrum",
    "number_of_edges",
    "number_of_nodes",
    "number_of_selfloops",
]
|
||||
|
||||
|
||||
class GraphStat(object):
    """Compute a flat dictionary of statistics for one PyG graph.

    Two families of statistics are evaluated: the networkx functions listed
    in ``__maps__`` (on an undirected networkx copy of the graph) and a few
    attributes read directly from the PyG data object. Statistics that fail
    to compute are reported as ``None``.
    """

    def __init__(self):
        # backend name -> {statistic name -> callable taking the graph}
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: function}`` for the networkx functions in ``__maps__``."""
        wanted = set(__maps__)
        return {
            name: f for name, f in getmembers(nx, isfunction) if name in wanted
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for attributes read from the data object.

        Each getter returns the attribute value, or ``None`` when the object
        has no such attribute.
        """
        # NOTE(review): "num_nodes_features" is probably meant to be PyG's
        # ``num_node_features``; as spelled it presumably always yields None.
        # Left unchanged because the name is also the key of the result.
        names = [
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            "num_nodes_features",
            "y",
        ]
        # ``name=name`` binds the current name at definition time.
        return {name: (lambda x, name=name: getattr(x, name, None)) for name in names}

    def __call__(self, data):
        """Evaluate every registered statistic on ``data``.

        Args:
            data: PyG data object; it is converted with ``to_networkx`` for
                the networkx statistics and shallow-copied for the others.

        Returns:
            Dict mapping each statistic name to its converted value (see
            :meth:`convert`), plus a ``"hash"`` entry holding
            ``hash(repr(data))``.
        """
        # Hash the textual representation, not the bound ``__repr__`` method.
        # NOTE: Python randomises string hashes per process, so this value is
        # only comparable within a single run.
        datahash = hash(repr(data))
        stats = {}
        for backend, funcs in self.maps.items():
            if backend == "networkx":
                graph = to_networkx(data).to_undirected()
            elif backend == "torch_geometric":
                graph = data.__copy__()
            else:
                # No known way to build an input for this backend.
                continue
            for name, func in funcs.items():
                stats[name] = self.convert(self._evaluate(name, func, graph))
        stats["hash"] = datahash
        return stats

    @staticmethod
    def _evaluate(name, func, graph):
        """Compute one statistic, returning ``None`` when it cannot be computed."""
        try:
            val = func(graph)
            if callable(val):
                # Some entries resolve to a method or callable view.
                val = val()
            if isinstance(val, types.GeneratorType):
                val = dict(val)
        except Exception:
            # Best effort: many statistics are undefined for some graphs.
            return None

        if val is None:
            return None
        if name == "hits":
            # Keep only the first element of the returned tuple.
            return val[0]
        if name == "k_components":
            # Keep only the first component reported for each k.
            return {key: list(comps[0]) for key, comps in val.items()}
        return val

    def convert(self, val, K=4):
        """Turn ``val`` into plain Python containers with rounded floats.

        Args:
            val: Value returned by a statistic.
            K: Number of decimals kept when rounding floats.

        Returns:
            Sets become lists, tensors and arrays become (nested) lists,
            networkx graphs become node-link dictionaries, and floats found in
            lists or dicts (at most one nesting level deep) are rounded to
            ``K`` decimals. Lists and dicts are modified in place and
            returned. Any other value is returned unchanged.
        """
        if isinstance(val, set):
            return list(val)

        if isinstance(val, torch.Tensor):
            return val.detach().cpu().tolist()

        if isinstance(val, list):
            for ind, item in enumerate(val):
                if isinstance(item, list):
                    for subind, subitem in enumerate(item):
                        if isinstance(subitem, float):
                            item[subind] = round(subitem, K)
                elif isinstance(item, float):
                    val[ind] = round(item, K)
            return val

        if isinstance(val, dict):
            # Only existing keys are reassigned, so iterating is safe.
            for k, v in val.items():
                if isinstance(v, dict):
                    for k1, v1 in v.items():
                        if isinstance(v1, float):
                            v[k1] = round(v1, K)
                        elif isinstance(v1, set):
                            v[k1] = list(v1)
                elif isinstance(v, float):
                    val[k] = round(v, K)
            return val

        if isinstance(val, np.ndarray):
            return np.around(val, decimals=K).tolist()

        if isinstance(val, nx.classes.reportviews.DegreeView):
            return np.array(val).tolist()

        if isinstance(val, nx.Graph):
            return nx.node_link_data(val)

        return val
|
|
@ -0,0 +1,3 @@
|
|||
class BaseExplaining(object):
|
||||
def __init__(self,model,explainer_name:wq
|
||||
|
|
@ -0,0 +1 @@
|
|||
from torch_geometric.nn.models.captum import CaptumModel
|
Loading…
Reference in New Issue