explaining_framework/stat/graph/graph_stat.py
2022-12-07 22:24:08 +01:00

229 lines
6.1 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
# Whitelist of NetworkX function names. GraphStat.available_map_networkx scans
# the callables exposed by the `networkx` module and keeps only those whose
# name appears here; GraphStat.__call__ then invokes each kept function with
# an undirected NetworkX graph as its single argument.
__maps__ = [
"adamic_adar_index",
"approximate_current_flow_betweenness_centrality",
"average_clustering",
"average_degree_connectivity",
"average_neighbor_degree",
"average_node_connectivity",
"average_shortest_path_length",
"betweenness_centrality",
"bridges",
"closeness_centrality",
"clustering",
"cn_soundarajan_hopcroft",
"common_neighbor_centrality",
"communicability",
"communicability_betweenness_centrality",
"communicability_exp",
"connected_components",
"constraint",
"core_number",
"current_flow_betweenness_centrality",
"current_flow_closeness_centrality",
"cycle_basis",
"degree_assortativity_coefficient",
"degree_centrality",
"degree_mixing_matrix",
"degree_pearson_correlation_coefficient",
"diameter",
"dispersion",
"dominating_set",
"eccentricity",
"effective_size",
"eigenvector_centrality",
"estrada_index",
"generalized_degree",
"global_efficiency",
"global_reaching_centrality",
"graph_clique_number",
"harmonic_centrality",
"has_bridges",
"has_eulerian_path",
"hits",
"information_centrality",
"is_at_free",
"is_biconnected",
"is_bipartite",
"is_chordal",
"is_connected",
"is_directed_acyclic_graph",
"is_distance_regular",
"is_eulerian",
"is_forest",
"is_graphical",
"is_multigraphical",
"is_planar",
"is_pseudographical",
"is_regular",
"is_semieulerian",
"is_strongly_regular",
"is_tree",
"is_valid_degree_sequence_erdos_gallai",
"is_valid_degree_sequence_havel_hakimi",
"jaccard_coefficient",
"k_components",
"k_core",
"katz_centrality",
"load_centrality",
"minimum_cycle_basis",
"minimum_edge_cut",
"minimum_node_cut",
"node_clique_number",
"node_connectivity",
"node_degree_xy",
"number_connected_components",
"number_of_cliques",
"number_of_isolates",
"pagerank",
"periphery",
"preferential_attachment",
"reciprocity",
"resource_allocation_index",
"rich_club_coefficient",
"square_clustering",
"stoer_wagner",
"subgraph_centrality",
"subgraph_centrality_exp",
"transitivity",
"triangles",
"voterank",
"wiener_index",
"algebraic_connectivity",
"degree",
"density",
"normalized_laplacian_spectrum",
"number_of_edges",
"number_of_nodes",
"number_of_selfloops",
]
class GraphStat(object):
    """Compute a dictionary of graph statistics for one data object.

    Two groups of statistics are evaluated:

    * ``"networkx"`` -- every whitelisted NetworkX function (see ``__maps__``)
      applied to the undirected NetworkX conversion of the data object;
    * ``"torch_geometric"`` -- a fixed set of attributes read directly from
      the data object.

    Evaluation is best-effort: a statistic that cannot be computed is
    reported as ``None`` instead of aborting the whole run.
    """

    def __init__(self):
        # Backend name -> {statistic name -> callable taking that backend's
        # view of the data}.
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: callable}`` for the whitelisted NetworkX functions.

        The top-level ``networkx`` namespace is scanned and only members whose
        name is listed in ``__maps__`` are kept. Entries follow the
        (alphabetical) order produced by ``inspect.getmembers``.
        """
        # Set membership instead of scanning the list once per member.
        wanted = set(__maps__)
        # NOTE(review): `callable` is used rather than `isfunction` so that
        # whitelisted algorithms are still picked up when the library exposes
        # them as callable wrapper objects instead of plain functions
        # (believed to be the case in recent NetworkX releases -- confirm).
        # The name whitelist keeps unrelated callables out.
        return {
            name: func
            for name, func in getmembers(nx, callable)
            if name in wanted
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for attributes read from the data object.

        Each getter takes the data object and returns the attribute value, or
        ``None`` when the object has no such attribute. Methods are returned
        unbound-called later by ``__call__``.
        """
        names = [
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            # NOTE(review): this spelling looks like a typo for
            # "num_node_features"; it is kept because it is also the key
            # under which the statistic is reported -- confirm before renaming.
            "num_nodes_features",
            "y",
        ]
        # `name=name` freezes the current name in each lambda (avoids the
        # late-binding closure pitfall). `getattr` with a default works for
        # any object, not only those that define `__getattr__`.
        return {
            name: (lambda x, name=name: getattr(x, name, None))
            for name in names
        }

    def __call__(self, data):
        """Evaluate every registered statistic on ``data``.

        Args:
            data: The data object to analyse. It must provide ``__copy__`` and
                be accepted by ``torch_geometric.utils.to_networkx``.

        Returns:
            dict: ``{statistic name: converted value or None}`` plus a
            ``"hash"`` entry holding ``hash(repr(data))``.
        """
        # Hash the textual representation, not the bound `__repr__` method
        # object. Python salts string hashes per process unless
        # PYTHONHASHSEED is fixed, so this value is only comparable within
        # one interpreter run.
        datahash = hash(repr(data))
        stats = {}
        for backend, funcs in self.maps.items():
            target = self._backend_input(backend, data)
            for name, func in funcs.items():
                stats[name] = self.convert(self._evaluate(name, func, target))
        stats["hash"] = datahash
        return stats

    def _backend_input(self, backend, data):
        """Return the object handed to the statistic functions of ``backend``."""
        if backend == "networkx":
            return to_networkx(data).to_undirected()
        # Work on a shallow copy so the getters never see the caller's object.
        return data.__copy__()

    def _evaluate(self, name, func, target):
        """Run one statistic on ``target``; return ``None`` if it fails."""
        try:
            val = func(target)
            # Attribute getters may hand back a method (e.g. has_self_loops).
            if callable(val):
                val = val()
            if isinstance(val, types.GeneratorType):
                # NOTE(review): dict() keeps a single value per distinct first
                # element and fails (-> None) for generators that do not yield
                # pairs; kept as-is to preserve the output format.
                val = dict(val)
        except Exception:
            # Best-effort: many whitelisted functions need extra arguments or
            # reject certain graphs. Unlike a bare `except`, this still lets
            # KeyboardInterrupt / SystemExit propagate.
            return None
        if val is None:
            # Nothing to post-process (previously crashed for "hits" and
            # "k_components" when the computation had failed).
            return None
        if name == "hits":
            # Keep only the first element of the returned pair.
            val = val[0]
        elif name == "k_components":
            # Keep only the first component listed for each key.
            val = {key: list(components[0]) for key, components in val.items()}
        return val

    @staticmethod
    def _round_float(value, K):
        """Round ``value`` to ``K`` decimals if it is a float, else return it."""
        return round(value, K) if isinstance(value, float) else value

    def convert(self, val, K=4):
        """Turn a statistic value into plain, rounded Python containers.

        Args:
            val: The raw statistic value.
            K: Number of decimals kept when rounding floats.

        Returns:
            * ``set`` -> ``list``;
            * ``torch.Tensor`` -> nested lists (not rounded);
            * ``numpy.ndarray`` -> nested lists rounded to ``K`` decimals;
            * ``list`` -> new list with floats rounded, one nesting level deep;
            * ``dict`` -> new dict with float values rounded; in directly
              nested dicts floats are rounded and sets become lists;
            * NetworkX ``DegreeView`` -> list built via ``numpy.array``;
            * NetworkX graph -> ``networkx.node_link_data`` output;
            * anything else (including top-level scalars) -> unchanged.

        The input is never modified; float subclasses such as
        ``numpy.float64`` are rounded like built-in floats.
        """
        # Scalars and None pass through untouched (top-level floats are
        # deliberately not rounded, matching the historical behaviour).
        if val is None or isinstance(val, (bool, int, float, str)):
            return val
        if isinstance(val, set):
            return list(val)
        if isinstance(val, torch.Tensor):
            # detach() so tensors that require grad can be converted as well.
            return val.detach().cpu().numpy().tolist()
        if isinstance(val, np.ndarray):
            return np.around(val, decimals=K).tolist()
        if isinstance(val, list):
            return [
                [self._round_float(sub, K) for sub in item]
                if isinstance(item, list)
                else self._round_float(item, K)
                for item in val
            ]
        if isinstance(val, dict):
            converted = {}
            for key, item in val.items():
                if isinstance(item, dict):
                    converted[key] = {
                        inner_key: (
                            list(inner) if isinstance(inner, set)
                            else self._round_float(inner, K)
                        )
                        for inner_key, inner in item.items()
                    }
                else:
                    converted[key] = self._round_float(item, K)
            return converted
        if isinstance(val, nx.classes.reportviews.DegreeView):
            return np.array(val).tolist()
        if isinstance(val, nx.Graph):
            return nx.node_link_data(val)
        return val