229 lines
6.1 KiB
Python
229 lines
6.1 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import types
|
||
|
from inspect import getmembers, isfunction, signature
|
||
|
|
||
|
import networkx as nx
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torch_geometric.data import Data
|
||
|
from torch_geometric.utils import to_networkx
|
||
|
|
||
|
# Names of the networkx functions GraphStat evaluates, in evaluation-list
# order.  GraphStat resolves each name on the networkx module, skips names
# the installed networkx does not provide, and calls each function with the
# undirected graph as its only argument (a failing call yields None).
__maps__ = [
    "adamic_adar_index",
    "approximate_current_flow_betweenness_centrality",
    "average_clustering",
    "average_degree_connectivity",
    "average_neighbor_degree",
    "average_node_connectivity",
    "average_shortest_path_length",
    "betweenness_centrality",
    "bridges",
    "closeness_centrality",
    "clustering",
    "cn_soundarajan_hopcroft",
    "common_neighbor_centrality",
    "communicability",
    "communicability_betweenness_centrality",
    "communicability_exp",
    "connected_components",
    "constraint",
    "core_number",
    "current_flow_betweenness_centrality",
    "current_flow_closeness_centrality",
    "cycle_basis",
    "degree_assortativity_coefficient",
    "degree_centrality",
    "degree_mixing_matrix",
    "degree_pearson_correlation_coefficient",
    "diameter",
    "dispersion",
    "dominating_set",
    "eccentricity",
    "effective_size",
    "eigenvector_centrality",
    "estrada_index",
    "generalized_degree",
    "global_efficiency",
    "global_reaching_centrality",
    "graph_clique_number",
    "harmonic_centrality",
    "has_bridges",
    "has_eulerian_path",
    "hits",
    "information_centrality",
    "is_at_free",
    "is_biconnected",
    "is_bipartite",
    "is_chordal",
    "is_connected",
    "is_directed_acyclic_graph",
    "is_distance_regular",
    "is_eulerian",
    "is_forest",
    "is_graphical",
    "is_multigraphical",
    "is_planar",
    "is_pseudographical",
    "is_regular",
    "is_semieulerian",
    "is_strongly_regular",
    "is_tree",
    "is_valid_degree_sequence_erdos_gallai",
    "is_valid_degree_sequence_havel_hakimi",
    "jaccard_coefficient",
    "k_components",
    "k_core",
    "katz_centrality",
    "load_centrality",
    "minimum_cycle_basis",
    "minimum_edge_cut",
    "minimum_node_cut",
    "node_clique_number",
    "node_connectivity",
    "node_degree_xy",
    "number_connected_components",
    "number_of_cliques",
    "number_of_isolates",
    "pagerank",
    "periphery",
    "preferential_attachment",
    "reciprocity",
    "resource_allocation_index",
    "rich_club_coefficient",
    "square_clustering",
    "stoer_wagner",
    "subgraph_centrality",
    "subgraph_centrality_exp",
    "transitivity",
    "triangles",
    "voterank",
    "wiener_index",
    # The remaining entries are not in alphabetical order with the rest.
    "algebraic_connectivity",
    "degree",
    "density",
    "normalized_laplacian_spectrum",
    "number_of_edges",
    "number_of_nodes",
    "number_of_selfloops",
]
|
||
|
|
||
|
|
||
|
class GraphStat(object):
    """Compute a dictionary of graph statistics for a graph ``data`` object.

    Two backends are registered in ``self.maps``:

    * ``"networkx"``: every name in the module-level ``__maps__`` list that the
      installed networkx provides, called on an undirected networkx conversion
      of ``data`` (via ``to_networkx``);
    * ``"torch_geometric"``: a few attributes read from a ``__copy__()`` of
      ``data``.

    Calling an instance returns ``{statistic_name: value, ..., "hash": int}``,
    where every value has been passed through :meth:`convert`.  A statistic
    whose computation raises is reported as ``None``.
    """

    def __init__(self):
        # backend name -> {statistic name -> callable(graph)}
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: callable}`` for each ``__maps__`` entry networkx provides.

        Names are resolved with ``getattr`` on the networkx module and kept when
        the attribute is callable; names the installed networkx lacks are
        skipped.  ``inspect.isfunction`` is deliberately not used as a filter:
        some networkx versions expose algorithms as backend-dispatch objects,
        which are callable but are not plain Python functions, and that filter
        would silently drop them.  Entries are inserted in alphabetical order.
        """
        maps = {}
        for name in sorted(__maps__):
            func = getattr(nx, name, None)
            if callable(func):
                maps[name] = func
        return maps

    def available_map_torch_geometric(self):
        """Return ``{output_key: getter}`` for attributes read from a Data object.

        Each getter returns ``getattr(data, attribute, None)``; when the value
        found is callable, ``__call__`` invokes it with no arguments.
        """
        # output key -> attribute name looked up on the data object
        attributes = {
            "num_nodes": "num_nodes",
            "num_edges": "num_edges",
            "has_self_loops": "has_self_loops",
            "has_isolated_nodes": "has_isolated_nodes",
            # The output key keeps its historical spelling so consumers of the
            # stats dict are unaffected; torch_geometric's attribute is named
            # ``num_node_features`` (the misspelled name is never defined).
            "num_nodes_features": "num_node_features",
            "y": "y",
        }
        # ``attr=attr`` binds the current value, avoiding the late-binding
        # closure trap that would make every getter read the last attribute.
        return {
            key: (lambda data, attr=attr: getattr(data, attr, None))
            for key, attr in attributes.items()
        }

    def __call__(self, data):
        """Compute every registered statistic for ``data``.

        Args:
            data: the graph.  For the ``"networkx"`` backend it is converted with
                ``to_networkx`` and made undirected; for the
                ``"torch_geometric"`` backend ``data.__copy__()`` is used.

        Returns:
            dict mapping each statistic name to its converted value (``None``
            when the computation failed), plus a ``"hash"`` entry.
        """
        # NOTE: hash() of a str is salted per interpreter process
        # (PYTHONHASHSEED), so this value is only comparable within one run,
        # and it only reflects what repr(data) prints.
        datahash = hash(repr(data))

        stats = {}
        for backend, funcs in self.maps.items():
            if backend == "networkx":
                graph = to_networkx(data).to_undirected()
            elif backend == "torch_geometric":
                graph = data.__copy__()
            for name, func in funcs.items():
                val = self._evaluate(func, graph)
                val = self._postprocess(name, val)
                stats[name] = self.convert(val)

        stats["hash"] = datahash
        return stats

    @staticmethod
    def _evaluate(func, graph):
        """Apply ``func`` to ``graph`` and normalise the raw result.

        A callable result (e.g. a bound method returned by a getter) is called
        with no arguments.  A generator result is materialised: into a dict when
        its items are key/value pairs, otherwise into a list.  Returns ``None``
        if any of this raises.
        """
        try:
            val = func(graph)
            if callable(val):
                val = val()
            if isinstance(val, types.GeneratorType):
                items = list(val)
                try:
                    val = dict(items)
                except (TypeError, ValueError):
                    # Items are not key/value pairs (e.g. a generator of sets):
                    # keep them as a list rather than discarding the result.
                    val = items
        except Exception:
            # Best effort: a statistic that cannot be computed for this graph
            # is reported as None instead of aborting the whole run.
            return None
        return val

    @staticmethod
    def _postprocess(name, val):
        """Apply statistic-specific reshaping; ``None`` passes through unchanged."""
        if val is None:
            return None
        if name == "hits":
            # networkx.hits returns a (hubs, authorities) pair; keep the first.
            return val[0]
        if name == "k_components":
            # Keep only the first component listed for each k, as a list.
            return {k: list(components[0]) for k, components in val.items()}
        return val

    def convert(self, val, K=4):
        """Make a statistic value JSON-friendly, rounding floats to ``K`` decimals.

        Handled forms:

        * ``None``, bool, int, str: returned unchanged;
        * float (including numpy float scalars, which subclass float): rounded;
        * set / frozenset: converted to a list;
        * ``torch.Tensor``: converted to a nested Python list (no rounding);
        * ``numpy.ndarray``: rounded, then converted to a nested Python list;
        * list: float items rounded, set items turned into lists, and float
          items of directly nested lists rounded (modified in place);
        * dict: float values rounded; for dict values, their float entries are
          rounded and set entries turned into lists (modified in place);
        * networkx ``DegreeView``: converted to a list through numpy;
        * networkx ``Graph``: converted with ``nx.node_link_data``.

        Anything else is returned unchanged.
        """
        # Cheap pass-through for plain scalars before any container checks.
        if val is None or isinstance(val, (bool, int, str)):
            return val
        if isinstance(val, float):
            return round(val, K)
        if isinstance(val, (set, frozenset)):
            return list(val)
        if isinstance(val, torch.Tensor):
            # detach() so tensors that require grad can be converted too.
            return val.detach().cpu().tolist()
        if isinstance(val, np.ndarray):
            return np.around(val, decimals=K).tolist()
        if isinstance(val, list):
            for ind, item in enumerate(val):
                if isinstance(item, list):
                    for subind, subitem in enumerate(item):
                        if isinstance(subitem, float):
                            item[subind] = round(subitem, K)
                elif isinstance(item, float):
                    val[ind] = round(item, K)
                elif isinstance(item, (set, frozenset)):
                    val[ind] = list(item)
            return val
        if isinstance(val, dict):
            # Reassigning existing keys while iterating is safe (size unchanged).
            for key, value in val.items():
                if isinstance(value, dict):
                    for subkey, subvalue in value.items():
                        if isinstance(subvalue, float):
                            value[subkey] = round(subvalue, K)
                        elif isinstance(subvalue, (set, frozenset)):
                            value[subkey] = list(subvalue)
                elif isinstance(value, float):
                    val[key] = round(value, K)
            return val
        if isinstance(val, nx.classes.reportviews.DegreeView):
            return np.array(val).tolist()
        if isinstance(val, nx.Graph):
            return nx.node_link_data(val)
        return val
|