New fixes + reformat
This commit is contained in:
parent
09a20b891f
commit
5d2cacfa05
0
stat/graph/__init__.py
Normal file
0
stat/graph/__init__.py
Normal file
228
stat/graph/graph_stat.py
Normal file
228
stat/graph/graph_stat.py
Normal file
@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import types
|
||||
from inspect import getmembers, isfunction, signature
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.utils import to_networkx
|
||||
|
||||
__maps__ = [
|
||||
"adamic_adar_index",
|
||||
"approximate_current_flow_betweenness_centrality",
|
||||
"average_clustering",
|
||||
"average_degree_connectivity",
|
||||
"average_neighbor_degree",
|
||||
"average_node_connectivity",
|
||||
"average_shortest_path_length",
|
||||
"betweenness_centrality",
|
||||
"bridges",
|
||||
"closeness_centrality",
|
||||
"clustering",
|
||||
"cn_soundarajan_hopcroft",
|
||||
"common_neighbor_centrality",
|
||||
"communicability",
|
||||
"communicability_betweenness_centrality",
|
||||
"communicability_exp",
|
||||
"connected_components",
|
||||
"constraint",
|
||||
"core_number",
|
||||
"current_flow_betweenness_centrality",
|
||||
"current_flow_closeness_centrality",
|
||||
"cycle_basis",
|
||||
"degree_assortativity_coefficient",
|
||||
"degree_centrality",
|
||||
"degree_mixing_matrix",
|
||||
"degree_pearson_correlation_coefficient",
|
||||
"diameter",
|
||||
"dispersion",
|
||||
"dominating_set",
|
||||
"eccentricity",
|
||||
"effective_size",
|
||||
"eigenvector_centrality",
|
||||
"estrada_index",
|
||||
"generalized_degree",
|
||||
"global_efficiency",
|
||||
"global_reaching_centrality",
|
||||
"graph_clique_number",
|
||||
"harmonic_centrality",
|
||||
"has_bridges",
|
||||
"has_eulerian_path",
|
||||
"hits",
|
||||
"information_centrality",
|
||||
"is_at_free",
|
||||
"is_biconnected",
|
||||
"is_bipartite",
|
||||
"is_chordal",
|
||||
"is_connected",
|
||||
"is_directed_acyclic_graph",
|
||||
"is_distance_regular",
|
||||
"is_eulerian",
|
||||
"is_forest",
|
||||
"is_graphical",
|
||||
"is_multigraphical",
|
||||
"is_planar",
|
||||
"is_pseudographical",
|
||||
"is_regular",
|
||||
"is_semieulerian",
|
||||
"is_strongly_regular",
|
||||
"is_tree",
|
||||
"is_valid_degree_sequence_erdos_gallai",
|
||||
"is_valid_degree_sequence_havel_hakimi",
|
||||
"jaccard_coefficient",
|
||||
"k_components",
|
||||
"k_core",
|
||||
"katz_centrality",
|
||||
"load_centrality",
|
||||
"minimum_cycle_basis",
|
||||
"minimum_edge_cut",
|
||||
"minimum_node_cut",
|
||||
"node_clique_number",
|
||||
"node_connectivity",
|
||||
"node_degree_xy",
|
||||
"number_connected_components",
|
||||
"number_of_cliques",
|
||||
"number_of_isolates",
|
||||
"pagerank",
|
||||
"periphery",
|
||||
"preferential_attachment",
|
||||
"reciprocity",
|
||||
"resource_allocation_index",
|
||||
"rich_club_coefficient",
|
||||
"square_clustering",
|
||||
"stoer_wagner",
|
||||
"subgraph_centrality",
|
||||
"subgraph_centrality_exp",
|
||||
"transitivity",
|
||||
"triangles",
|
||||
"voterank",
|
||||
"wiener_index",
|
||||
"algebraic_connectivity",
|
||||
"degree",
|
||||
"density",
|
||||
"normalized_laplacian_spectrum",
|
||||
"number_of_edges",
|
||||
"number_of_nodes",
|
||||
"number_of_selfloops",
|
||||
]
|
||||
|
||||
|
||||
class GraphStat(object):
|
||||
def __init__(self):
|
||||
self.maps = {
|
||||
"networkx": self.available_map_networkx(),
|
||||
"torch_geometric": self.available_map_torch_geometric(),
|
||||
}
|
||||
|
||||
def available_map_networkx(self):
|
||||
functions_list = getmembers(nx, isfunction)
|
||||
|
||||
maps = {}
|
||||
for func in functions_list:
|
||||
name, f = func
|
||||
if name in __maps__:
|
||||
maps[name] = f
|
||||
return maps
|
||||
|
||||
def available_map_torch_geometric(self):
|
||||
names = [
|
||||
"num_nodes",
|
||||
"num_edges",
|
||||
"has_self_loops",
|
||||
"has_isolated_nodes",
|
||||
"num_nodes_features",
|
||||
"y",
|
||||
]
|
||||
maps = {
|
||||
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||
for name in names
|
||||
}
|
||||
return maps
|
||||
|
||||
def __call__(self, data):
|
||||
data_ = data.__copy__()
|
||||
datahash = hash(data.__repr__)
|
||||
stats = {}
|
||||
for k, v in self.maps.items():
|
||||
if k == "networkx":
|
||||
_data_ = to_networkx(data)
|
||||
_data_ = _data_.to_undirected()
|
||||
elif k == "torch_geometric":
|
||||
_data_ = data.__copy__()
|
||||
for name, func in v.items():
|
||||
try:
|
||||
val = func(_data_)
|
||||
except:
|
||||
val = None
|
||||
if callable(val):
|
||||
val = val()
|
||||
if isinstance(val, types.GeneratorType):
|
||||
try:
|
||||
val = dict(val)
|
||||
except:
|
||||
val = None
|
||||
if name == "hits":
|
||||
val = val[0]
|
||||
if name == "k_components":
|
||||
for key, value in val.items():
|
||||
val[key] = list(val[key][0])
|
||||
|
||||
stats[name] = self.convert(val)
|
||||
stats["hash"] = datahash
|
||||
return stats
|
||||
|
||||
def convert(self, val, K=4):
|
||||
|
||||
if type(val) == set:
|
||||
val = list(val)
|
||||
return val
|
||||
|
||||
elif isinstance(val, torch.Tensor):
|
||||
val = val.cpu().numpy().tolist()
|
||||
return val
|
||||
|
||||
elif isinstance(val, list):
|
||||
for ind in range(len(val)):
|
||||
item = val[ind]
|
||||
if isinstance(item, list):
|
||||
for subind in range(len(item)):
|
||||
subitem = item[subind]
|
||||
if type(subitem) == float:
|
||||
item[subind] = round(subitem, K)
|
||||
|
||||
else:
|
||||
if type(item) == float:
|
||||
val[ind] = round(item, K)
|
||||
return val
|
||||
|
||||
elif isinstance(val, dict):
|
||||
for k, v in val.items():
|
||||
if isinstance(v, dict):
|
||||
for k1, v1 in v.items():
|
||||
if type(v1) == float:
|
||||
v[k1] = round(v1, K)
|
||||
if type(v1) == set:
|
||||
v[k1] = list(v1)
|
||||
|
||||
else:
|
||||
if type(v) == float:
|
||||
val[k] = round(v, K)
|
||||
return val
|
||||
|
||||
elif isinstance(val, nx.classes.reportviews.DegreeView):
|
||||
val = np.array(val).tolist()
|
||||
return val
|
||||
|
||||
elif isinstance(val, nx.Graph):
|
||||
val = nx.node_link_data(val)
|
||||
return val
|
||||
|
||||
elif isinstance(val, np.ndarray):
|
||||
val = np.around(val, decimals=K)
|
||||
return val.tolist()
|
||||
|
||||
else:
|
||||
return val
|
0
tool/__init__.py
Normal file
0
tool/__init__.py
Normal file
0
tool/explainer/__init__.py
Normal file
0
tool/explainer/__init__.py
Normal file
3
tool/explainer/base.py
Normal file
3
tool/explainer/base.py
Normal file
@ -0,0 +1,3 @@
|
||||
class BaseExplaining(object):
|
||||
def __init__(self,model,explainer_name:wq
|
||||
|
1
tool/explainer/from_captum.py
Normal file
1
tool/explainer/from_captum.py
Normal file
@ -0,0 +1 @@
|
||||
from torch_geometric.nn.models.captum import CaptumModel
|
0
tool/graphgym/__init__.py
Normal file
0
tool/graphgym/__init__.py
Normal file
0
tool/visualizer/__init__.py
Normal file
0
tool/visualizer/__init__.py
Normal file
125
utils/stat.py
125
utils/stat.py
@ -1,125 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from inspect import getmembers, isfunction, signature
|
||||
|
||||
import networkx as nx
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.utils import to_networkx
|
||||
|
||||
with open("mapping_nx.txt", "r") as file:
|
||||
BLACK_LIST = [line.rstrip() for line in file]
|
||||
|
||||
|
||||
class GraphStat(object):
|
||||
def __init__(self):
|
||||
self.maps = {
|
||||
"networkx": self.available_map_networkx(),
|
||||
"torch_geometric": self.available_map_torch_geometric(),
|
||||
}
|
||||
|
||||
def available_map_networkx(self):
|
||||
functions_list = getmembers(nx.algorithms, isfunction)
|
||||
MANUALLY_ADDED = [
|
||||
"algebraic_connectivity",
|
||||
"adjacency_spectrum",
|
||||
"degree",
|
||||
"density",
|
||||
"laplacian_spectrum",
|
||||
"normalized_laplacian_spectrum",
|
||||
"number_of_selfloops",
|
||||
"number_of_edges",
|
||||
"number_of_nodes",
|
||||
]
|
||||
MANUALLY_ADDED_LIST = [
|
||||
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
|
||||
]
|
||||
functions_list = functions_list + MANUALLY_ADDED_LIST
|
||||
maps = {}
|
||||
for func in functions_list:
|
||||
name, f = func
|
||||
if (
|
||||
name in BLACK_LIST
|
||||
or name == "recursive_simple_cycles"
|
||||
or "triad" in name
|
||||
or "weisfeiler" in name
|
||||
or "dfs" in name
|
||||
or "trophic" in name
|
||||
or "recursive" in name
|
||||
or "scipy" in name
|
||||
or "numpy" in name
|
||||
or "sigma" in name
|
||||
or "omega" in name
|
||||
or "all_" in name
|
||||
):
|
||||
continue
|
||||
else:
|
||||
maps[name] = f
|
||||
return maps
|
||||
|
||||
def available_map_torch_geometric(self):
|
||||
names = [
|
||||
"num_nodes",
|
||||
"num_edges",
|
||||
"has_self_loops",
|
||||
"has_isolated_nodes",
|
||||
"num_nodes_features",
|
||||
"y",
|
||||
]
|
||||
maps = {
|
||||
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||
for name in names
|
||||
}
|
||||
return maps
|
||||
|
||||
def __call__(self, data):
|
||||
data_ = data.__copy__()
|
||||
stats = {}
|
||||
for k, v in self.maps.items():
|
||||
if k == "networkx":
|
||||
_data_ = to_networkx(data)
|
||||
_data_ = _data_.to_undirected()
|
||||
elif k == "torch_geometric":
|
||||
_data_ = data.__copy__()
|
||||
for name, f in v.items():
|
||||
if f is None:
|
||||
stats[name] = None
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
t0 = time.time()
|
||||
val = f(_data_)
|
||||
t1 = time.time()
|
||||
delta = t1 - t0
|
||||
except Exception as e:
|
||||
print(name, e)
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(str(e))
|
||||
# print(name, round(delta, 4))
|
||||
# if callable(val) and k == "torch_geometric":
|
||||
# val = val()
|
||||
# if isinstance(val, types.GeneratorType):
|
||||
# val = list(val)
|
||||
# stats[name] = val
|
||||
return stats
|
||||
|
||||
|
||||
from torch_geometric.datasets import KarateClub, Planetoid
|
||||
|
||||
d = Planetoid(root="/tmp/", name="Cora")
|
||||
# d = KarateClub()
|
||||
a = d[0]
|
||||
st = GraphStat()
|
||||
stat = st(a)
|
||||
for k, v in stat.items():
|
||||
print("---------")
|
||||
print("Name:", k)
|
||||
print("Type:", type(v))
|
||||
print("Val:", v)
|
||||
print("---------")
|
Loading…
Reference in New Issue
Block a user