Merge branch 'master' of gitlab.xlim.fr:araison/explaining_framework

This commit is contained in:
araison 2022-12-09 14:45:25 +01:00
commit b73f087a6a
43 changed files with 428 additions and 0 deletions

@@ -0,0 +1,144 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
                         FeaturePermutation, GradientShap, GuidedBackprop,
                         GuidedGradCam, InputXGradient, IntegratedGradients,
                         Lime, Occlusion, Saliency)
from torch import Tensor
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
                                              to_captum_input, to_captum_model)
class FromCaptum(ExplainerAlgorithm):
    def __init__(self, captum_method, mask_type: str = "node", **kwargs):
        super().__init__()
        _raise_on_invalid_mask_type(mask_type)
        self.captum_method = captum_method
        self.mask_type = mask_type
        self.kwargs = kwargs

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False
        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False
        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [MaskType.attributes]:
            logging.error(f"Node mask type '{node_mask_type.value}' not supported.")
            return False
        return True

    def attr_to_tuple(self, attr, mask_type):
        # Split a Captum attribution result into (node_mask, edge_mask).
        _raise_on_invalid_mask_type(mask_type)
        node_mask, edge_mask = None, None
        if mask_type == "node":
            node_mask = attr
        elif mask_type == "edge":
            edge_mask = attr
        elif mask_type == "node_and_edge":
            node_mask, edge_mask = attr[0], attr[1]
        return node_mask, edge_mask
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target: Optional[Tensor] = None,
    ) -> Explanation:
        converted_model = to_captum_model(
            model,
            mask_type=self.mask_type,
            output_idx=target,
        )
        attrib = self.captum_method(converted_model)
        inputs, additional_forward_args = to_captum_input(
            x, edge_index, mask_type=self.mask_type
        )
        attr = attrib.attribute(
            inputs,
            target=target,
            additional_forward_args=additional_forward_args,
            **self.kwargs,
        )
        node_mask, edge_mask = self.attr_to_tuple(attr, self.mask_type)
        return Explanation(
            x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
        )
if "__name__" == __main__:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
y = torch.tensor([1], dtype=torch.long)
data = Data(x=x, edge_index=edge_index, y=y)
batch = Batch().from_data_list([data])
class Model(torch.nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.conv = GCNConv(dim_in, dim_out)
def forward(self, batch):
x, edge_index, batch = batch.x, batch.edge_index, batch.batch
x, edge_index = self.conv(x, edge_index)
x = global_mean_pooling(x, edge_index, batch)
return x
model = Model(1,2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
for epoch in range(1, 2):
model.train()
optimizer.zero_grad()
log_logits = model(batch)
loss = F.mse_loss(out,torch.ones(2))
loss.backward()
optimizer.step()
expla = [LRP, DeepLift, DeepLiftShap, FeatureAblation,
FeaturePermutation, GradientShap, GuidedBackprop,
GuidedGradCam, InputXGradient, IntegratedGradient,
Lime, Occlusion, Saliency]
model = model.eval()
for captum_exp in expla:
try:
exp = FromCaptum(captum_exp)
attr = exp.forward(model,batch.x,batch.edge_index)
except Exception as e:
print(str(e))
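As a usage sketch (not part of this commit), FromCaptum could also plug into PyG's high-level Explainer interface. This assumes the torch_geometric.explain.Explainer API introduced in PyG 2.2; the model_config values below are illustrative:

# Hypothetical sketch: wiring FromCaptum into PyG's Explainer facade.
from torch_geometric.explain import Explainer

explainer = Explainer(
    model=model,
    algorithm=FromCaptum(Saliency),
    explanation_type="model",
    node_mask_type="attributes",
    model_config=dict(
        mode="multiclass_classification",  # assumed task setup
        task_level="graph",
        return_type="raw",
    ),
)
explanation = explainer(batch.x, batch.edge_index)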

@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import networkx as nx
import numpy as np
import pandas as pd
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.utils import to_networkx


class DatasetStatistic(object):
    def __init__(self, dataset):
        if not isinstance(dataset, Dataset):
            raise ValueError(f'{dataset} needs to be a PyG dataset object')
        self.dataset = dataset
        # A single-graph dataset implies a node-level task; a multi-graph
        # dataset implies a graph-level task.
        if len(dataset) == 1:
            self.task = 'node'
        else:
            self.task = 'graph'
        self.study = None

    def save_csv(self, path):
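        # The method body is truncated in this hunk; a minimal sketch,
        # assuming `self.study` is filled with one record per graph by an
        # earlier analysis step (hypothetical):
        df = pd.DataFrame(self.study)
        df.to_csv(path, index=False)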

@@ -0,0 +1,28 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from inspect import getmembers, isfunction, signature

import networkx as nx
from docstring_parser import parse

# Survey the documented return types of all public NetworkX functions.
functions_list = getmembers(nx, isfunction)
rt_type = []
for func in functions_list:
    name, f = func
    docstring = parse(f.__doc__ or '')
    try:
        rt = docstring.returns.type_name
        rt_type.append(rt)
        if rt in ('int', 'float', 'bool', 'boolean', 'dictionary'):
            print(f'{name} : {rt}')
    except AttributeError:
        # No documented return type for this function.
        continue

print('int', rt_type.count('int'))
print('float', rt_type.count('float'))
print('bool', rt_type.count('bool'))
print('boolean', rt_type.count('boolean'))
print('dictionary', rt_type.count('dictionary'))
print('length func', len(functions_list))
print('length rt', len(rt_type))

@@ -0,0 +1,228 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
__maps__ = [
"adamic_adar_index",
"approximate_current_flow_betweenness_centrality",
"average_clustering",
"average_degree_connectivity",
"average_neighbor_degree",
"average_node_connectivity",
"average_shortest_path_length",
"betweenness_centrality",
"bridges",
"closeness_centrality",
"clustering",
"cn_soundarajan_hopcroft",
"common_neighbor_centrality",
"communicability",
"communicability_betweenness_centrality",
"communicability_exp",
"connected_components",
"constraint",
"core_number",
"current_flow_betweenness_centrality",
"current_flow_closeness_centrality",
"cycle_basis",
"degree_assortativity_coefficient",
"degree_centrality",
"degree_mixing_matrix",
"degree_pearson_correlation_coefficient",
"diameter",
"dispersion",
"dominating_set",
"eccentricity",
"effective_size",
"eigenvector_centrality",
"estrada_index",
"generalized_degree",
"global_efficiency",
"global_reaching_centrality",
"graph_clique_number",
"harmonic_centrality",
"has_bridges",
"has_eulerian_path",
"hits",
"information_centrality",
"is_at_free",
"is_biconnected",
"is_bipartite",
"is_chordal",
"is_connected",
"is_directed_acyclic_graph",
"is_distance_regular",
"is_eulerian",
"is_forest",
"is_graphical",
"is_multigraphical",
"is_planar",
"is_pseudographical",
"is_regular",
"is_semieulerian",
"is_strongly_regular",
"is_tree",
"is_valid_degree_sequence_erdos_gallai",
"is_valid_degree_sequence_havel_hakimi",
"jaccard_coefficient",
"k_components",
"k_core",
"katz_centrality",
"load_centrality",
"minimum_cycle_basis",
"minimum_edge_cut",
"minimum_node_cut",
"node_clique_number",
"node_connectivity",
"node_degree_xy",
"number_connected_components",
"number_of_cliques",
"number_of_isolates",
"pagerank",
"periphery",
"preferential_attachment",
"reciprocity",
"resource_allocation_index",
"rich_club_coefficient",
"square_clustering",
"stoer_wagner",
"subgraph_centrality",
"subgraph_centrality_exp",
"transitivity",
"triangles",
"voterank",
"wiener_index",
"algebraic_connectivity",
"degree",
"density",
"normalized_laplacian_spectrum",
"number_of_edges",
"number_of_nodes",
"number_of_selfloops",
]
class GraphStat(object):
    def __init__(self):
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        # Keep only the NetworkX functions whitelisted in __maps__.
        functions_list = getmembers(nx, isfunction)
        maps = {}
        for func in functions_list:
            name, f = func
            if name in __maps__:
                maps[name] = f
        return maps

    def available_map_torch_geometric(self):
        names = [
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            "num_node_features",
            "y",
        ]
        maps = {
            name: lambda x, name=name: getattr(x, name) if hasattr(x, name) else None
            for name in names
        }
        return maps
    def __call__(self, data):
        datahash = hash(repr(data))
        stats = {}
        for k, v in self.maps.items():
            if k == "networkx":
                _data_ = to_networkx(data)
                _data_ = _data_.to_undirected()
            elif k == "torch_geometric":
                _data_ = data.__copy__()
            for name, func in v.items():
                try:
                    val = func(_data_)
                except Exception:
                    val = None
                if callable(val):
                    val = val()
                if isinstance(val, types.GeneratorType):
                    try:
                        val = dict(val)
                    except Exception:
                        val = None
                if name == "hits":
                    # hits() returns (hubs, authorities); keep the hub scores.
                    val = val[0]
                if name == "k_components":
                    for key, value in val.items():
                        val[key] = list(val[key][0])
                stats[name] = self.convert(val)
        stats["hash"] = datahash
        return stats
    def convert(self, val, K=4):
        # Normalise NetworkX/PyG outputs into JSON-serialisable values,
        # rounding floats to K decimal places.
        if isinstance(val, set):
            return list(val)
        elif isinstance(val, torch.Tensor):
            return val.cpu().numpy().tolist()
        elif isinstance(val, list):
            for ind in range(len(val)):
                item = val[ind]
                if isinstance(item, list):
                    for subind in range(len(item)):
                        subitem = item[subind]
                        if isinstance(subitem, float):
                            item[subind] = round(subitem, K)
                else:
                    if isinstance(item, float):
                        val[ind] = round(item, K)
            return val
        elif isinstance(val, dict):
            for k, v in val.items():
                if isinstance(v, dict):
                    for k1, v1 in v.items():
                        if isinstance(v1, float):
                            v[k1] = round(v1, K)
                        if isinstance(v1, set):
                            v[k1] = list(v1)
                else:
                    if isinstance(v, float):
                        val[k] = round(v, K)
            return val
        elif isinstance(val, nx.classes.reportviews.DegreeView):
            return np.array(val).tolist()
        elif isinstance(val, nx.Graph):
            return nx.node_link_data(val)
        elif isinstance(val, np.ndarray):
            return np.around(val, decimals=K).tolist()
        else:
            return val
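A brief usage sketch (not part of this commit), assuming a small hand-built Data object; the imports at the top of this file already cover everything used here:

# Hypothetical sketch: compute statistics for a toy 3-node graph.
stat = GraphStat()
toy = Data(x=torch.ones(3, 1),
           edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))
stats = stat(toy)
print(stats["number_of_nodes"], stats["density"], stats["is_connected"])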

@@ -0,0 +1,3 @@
class BaseExplaining(object):
    def __init__(self, model, explainer_name: str):
        # Store the model and the name of the explaining method to run.
        self.model = model
        self.explainer_name = explainer_name

@@ -0,0 +1 @@
from torch_geometric.nn.models.captum import CaptumModel