#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Compute a battery of graph statistics for a PyTorch Geometric graph.

Statistics come from two sources:

* the NetworkX algorithms named in ``__maps__``, each applied to the graph
  after conversion to an undirected NetworkX graph, and
* a few attributes/methods read directly from the ``Data`` object.

A statistic whose computation raises is reported as ``None``.
"""
import types
from inspect import getmembers, isfunction, signature

import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# Names of top-level NetworkX functions to evaluate on every graph. Each is
# called with the graph as its only argument; names that need extra
# arguments (or are absent from the installed NetworkX) end up as None or
# are skipped.
__maps__ = [
    "adamic_adar_index",
    "approximate_current_flow_betweenness_centrality",
    "average_clustering",
    "average_degree_connectivity",
    "average_neighbor_degree",
    "average_node_connectivity",
    "average_shortest_path_length",
    "betweenness_centrality",
    "bridges",
    "closeness_centrality",
    "clustering",
    "cn_soundarajan_hopcroft",
    "common_neighbor_centrality",
    "communicability",
    "communicability_betweenness_centrality",
    "communicability_exp",
    "connected_components",
    "constraint",
    "core_number",
    "current_flow_betweenness_centrality",
    "current_flow_closeness_centrality",
    "cycle_basis",
    "degree_assortativity_coefficient",
    "degree_centrality",
    "degree_mixing_matrix",
    "degree_pearson_correlation_coefficient",
    "diameter",
    "dispersion",
    "dominating_set",
    "eccentricity",
    "effective_size",
    "eigenvector_centrality",
    "estrada_index",
    "generalized_degree",
    "global_efficiency",
    "global_reaching_centrality",
    "graph_clique_number",
    "harmonic_centrality",
    "has_bridges",
    "has_eulerian_path",
    "hits",
    "information_centrality",
    "is_at_free",
    "is_biconnected",
    "is_bipartite",
    "is_chordal",
    "is_connected",
    "is_directed_acyclic_graph",
    "is_distance_regular",
    "is_eulerian",
    "is_forest",
    "is_graphical",
    "is_multigraphical",
    "is_planar",
    "is_pseudographical",
    "is_regular",
    "is_semieulerian",
    "is_strongly_regular",
    "is_tree",
    "is_valid_degree_sequence_erdos_gallai",
    "is_valid_degree_sequence_havel_hakimi",
    "jaccard_coefficient",
    "k_components",
    "k_core",
    "katz_centrality",
    "load_centrality",
    "minimum_cycle_basis",
    "minimum_edge_cut",
    "minimum_node_cut",
    "node_clique_number",
    "node_connectivity",
    "node_degree_xy",
    "number_connected_components",
    "number_of_cliques",
    "number_of_isolates",
    "pagerank",
    "periphery",
    "preferential_attachment",
    "reciprocity",
    "resource_allocation_index",
    "rich_club_coefficient",
    "square_clustering",
    "stoer_wagner",
    "subgraph_centrality",
    "subgraph_centrality_exp",
    "transitivity",
    "triangles",
    "voterank",
    "wiener_index",
    "algebraic_connectivity",
    "degree",
    "density",
    "normalized_laplacian_spectrum",
    "number_of_edges",
    "number_of_nodes",
    "number_of_selfloops",
]


class GraphStat(object):
    """Callable that computes graph statistics for a ``torch_geometric`` ``Data``.

    ``self.maps`` maps a backend name (``"networkx"`` or ``"torch_geometric"``)
    to a ``{stat_name: callable}`` dict. Calling the instance on a ``Data``
    object returns a flat ``{stat_name: value}`` dict plus a ``"hash"`` entry.
    """

    def __init__(self):
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: callable}`` for each name in ``__maps__`` found in NetworkX.

        Names are looked up directly on the ``networkx`` module and kept when
        the attribute is callable. An ``inspect.isfunction`` filter is avoided
        on purpose: NetworkX 3.x wraps many algorithms in backend-dispatch
        objects that are callable but may not be plain Python functions, and
        such a filter would drop them silently. Names missing from the
        installed NetworkX are skipped. Keys are in sorted order.
        """
        maps = {}
        for name in sorted(set(__maps__)):
            func = getattr(nx, name, None)
            if callable(func):
                maps[name] = func
        return maps

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for attributes read straight off the graph.

        Each getter returns ``getattr(data, name)``, or ``None`` when the
        attribute is absent. Methods (e.g. ``has_self_loops``) are returned
        as bound methods, uncalled; ``__call__`` invokes them.
        """
        names = [
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            # NOTE(review): PyG's ``Data`` exposes ``num_node_features``; this
            # spelling probably always yields None. Left unchanged because it
            # is also the key in the returned stats — confirm before renaming.
            "num_nodes_features",
            "y",
        ]
        # ``name=name`` binds the current name now (avoids late-binding closures).
        return {
            name: (lambda x, name=name: getattr(x, name, None))
            for name in names
        }

    @staticmethod
    def _evaluate(name, func, graph):
        """Run one statistic on ``graph`` and normalise its raw result.

        Callable results are called with no arguments, generators are
        materialised into a dict, and ``hits`` / ``k_components`` results are
        reduced to a simpler form. Any exception during these steps yields
        ``None`` instead of aborting the whole run.
        """
        try:
            val = func(graph)
            if callable(val):
                val = val()
            if isinstance(val, types.GeneratorType):
                val = dict(val)
            if name == "hits":
                # NetworkX's hits returns (hubs, authorities); keep hub scores.
                val = val[0]
            elif name == "k_components":
                # Keep only the first component (as a list) for each level k.
                val = {k: list(comps[0]) for k, comps in val.items()}
        except Exception:
            return None
        return val

    def __call__(self, data):
        """Compute every registered statistic for ``data``.

        Args:
            data: a ``torch_geometric.data.Data`` graph.

        Returns:
            dict mapping each statistic name to its value after
            :meth:`convert` (``None`` when the computation raised), plus the
            key ``"hash"`` holding ``hash(repr(data))``.
        """
        stats = {}
        for backend, funcs in self.maps.items():
            if backend == "networkx":
                # NetworkX algorithms run on an undirected conversion of the graph.
                graph = to_networkx(data).to_undirected()
            elif backend == "torch_geometric":
                graph = data.__copy__()
            else:
                # Any extra backend added to self.maps receives the graph as-is.
                graph = data
            for name, func in funcs.items():
                stats[name] = self.convert(self._evaluate(name, func, graph))
        # Python salts str hashes per process unless PYTHONHASHSEED is fixed,
        # so this value is only comparable within a single interpreter run.
        stats["hash"] = hash(repr(data))
        return stats

    def convert(self, val, K=4):
        """Turn a raw statistic into plain Python data, rounding floats.

        The input is not modified; containers are rebuilt.

        Args:
            val: value produced by a statistic function.
            K: number of decimal places kept when rounding floats.

        Returns:
            * set -> list of its elements.
            * torch.Tensor -> nested Python list (values not rounded).
            * list -> new list; float items, and float items of nested
              lists, rounded to ``K`` places.
            * dict -> new dict; float values rounded. For nested-dict values,
              floats are rounded and sets become lists.
            * NetworkX ``DegreeView`` -> ``np.array(val).tolist()``.
            * ``nx.Graph`` -> ``nx.node_link_data(val)``.
            * np.ndarray -> rounded with ``np.around`` then ``tolist()``.
            * anything else (including top-level scalars) is returned unchanged.
        """

        def _round(x):
            return round(x, K) if isinstance(x, float) else x

        if isinstance(val, set):
            return list(val)
        if isinstance(val, torch.Tensor):
            # detach() so tensors that require grad can still go to numpy.
            return val.detach().cpu().numpy().tolist()
        if isinstance(val, list):
            return [
                [_round(sub) for sub in item] if isinstance(item, list) else _round(item)
                for item in val
            ]
        if isinstance(val, dict):
            converted = {}
            for key, item in val.items():
                if isinstance(item, dict):
                    converted[key] = {
                        k1: list(v1) if isinstance(v1, set) else _round(v1)
                        for k1, v1 in item.items()
                    }
                else:
                    converted[key] = _round(item)
            return converted
        if isinstance(val, nx.classes.reportviews.DegreeView):
            return np.array(val).tolist()
        if isinstance(val, nx.Graph):
            return nx.node_link_data(val)
        if isinstance(val, np.ndarray):
            return np.around(val, decimals=K).tolist()
        return val