New fixes + reformat
This commit is contained in:
parent
09a20b891f
commit
5d2cacfa05
11 changed files with 232 additions and 125 deletions
0
stat/graph/__init__.py
Normal file
0
stat/graph/__init__.py
Normal file
228
stat/graph/graph_stat.py
Normal file
228
stat/graph/graph_stat.py
Normal file
|
@ -0,0 +1,228 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import types
|
||||||
|
from inspect import getmembers, isfunction, signature
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.utils import to_networkx
|
||||||
|
|
||||||
|
__maps__ = [
|
||||||
|
"adamic_adar_index",
|
||||||
|
"approximate_current_flow_betweenness_centrality",
|
||||||
|
"average_clustering",
|
||||||
|
"average_degree_connectivity",
|
||||||
|
"average_neighbor_degree",
|
||||||
|
"average_node_connectivity",
|
||||||
|
"average_shortest_path_length",
|
||||||
|
"betweenness_centrality",
|
||||||
|
"bridges",
|
||||||
|
"closeness_centrality",
|
||||||
|
"clustering",
|
||||||
|
"cn_soundarajan_hopcroft",
|
||||||
|
"common_neighbor_centrality",
|
||||||
|
"communicability",
|
||||||
|
"communicability_betweenness_centrality",
|
||||||
|
"communicability_exp",
|
||||||
|
"connected_components",
|
||||||
|
"constraint",
|
||||||
|
"core_number",
|
||||||
|
"current_flow_betweenness_centrality",
|
||||||
|
"current_flow_closeness_centrality",
|
||||||
|
"cycle_basis",
|
||||||
|
"degree_assortativity_coefficient",
|
||||||
|
"degree_centrality",
|
||||||
|
"degree_mixing_matrix",
|
||||||
|
"degree_pearson_correlation_coefficient",
|
||||||
|
"diameter",
|
||||||
|
"dispersion",
|
||||||
|
"dominating_set",
|
||||||
|
"eccentricity",
|
||||||
|
"effective_size",
|
||||||
|
"eigenvector_centrality",
|
||||||
|
"estrada_index",
|
||||||
|
"generalized_degree",
|
||||||
|
"global_efficiency",
|
||||||
|
"global_reaching_centrality",
|
||||||
|
"graph_clique_number",
|
||||||
|
"harmonic_centrality",
|
||||||
|
"has_bridges",
|
||||||
|
"has_eulerian_path",
|
||||||
|
"hits",
|
||||||
|
"information_centrality",
|
||||||
|
"is_at_free",
|
||||||
|
"is_biconnected",
|
||||||
|
"is_bipartite",
|
||||||
|
"is_chordal",
|
||||||
|
"is_connected",
|
||||||
|
"is_directed_acyclic_graph",
|
||||||
|
"is_distance_regular",
|
||||||
|
"is_eulerian",
|
||||||
|
"is_forest",
|
||||||
|
"is_graphical",
|
||||||
|
"is_multigraphical",
|
||||||
|
"is_planar",
|
||||||
|
"is_pseudographical",
|
||||||
|
"is_regular",
|
||||||
|
"is_semieulerian",
|
||||||
|
"is_strongly_regular",
|
||||||
|
"is_tree",
|
||||||
|
"is_valid_degree_sequence_erdos_gallai",
|
||||||
|
"is_valid_degree_sequence_havel_hakimi",
|
||||||
|
"jaccard_coefficient",
|
||||||
|
"k_components",
|
||||||
|
"k_core",
|
||||||
|
"katz_centrality",
|
||||||
|
"load_centrality",
|
||||||
|
"minimum_cycle_basis",
|
||||||
|
"minimum_edge_cut",
|
||||||
|
"minimum_node_cut",
|
||||||
|
"node_clique_number",
|
||||||
|
"node_connectivity",
|
||||||
|
"node_degree_xy",
|
||||||
|
"number_connected_components",
|
||||||
|
"number_of_cliques",
|
||||||
|
"number_of_isolates",
|
||||||
|
"pagerank",
|
||||||
|
"periphery",
|
||||||
|
"preferential_attachment",
|
||||||
|
"reciprocity",
|
||||||
|
"resource_allocation_index",
|
||||||
|
"rich_club_coefficient",
|
||||||
|
"square_clustering",
|
||||||
|
"stoer_wagner",
|
||||||
|
"subgraph_centrality",
|
||||||
|
"subgraph_centrality_exp",
|
||||||
|
"transitivity",
|
||||||
|
"triangles",
|
||||||
|
"voterank",
|
||||||
|
"wiener_index",
|
||||||
|
"algebraic_connectivity",
|
||||||
|
"degree",
|
||||||
|
"density",
|
||||||
|
"normalized_laplacian_spectrum",
|
||||||
|
"number_of_edges",
|
||||||
|
"number_of_nodes",
|
||||||
|
"number_of_selfloops",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphStat(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.maps = {
|
||||||
|
"networkx": self.available_map_networkx(),
|
||||||
|
"torch_geometric": self.available_map_torch_geometric(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def available_map_networkx(self):
|
||||||
|
functions_list = getmembers(nx, isfunction)
|
||||||
|
|
||||||
|
maps = {}
|
||||||
|
for func in functions_list:
|
||||||
|
name, f = func
|
||||||
|
if name in __maps__:
|
||||||
|
maps[name] = f
|
||||||
|
return maps
|
||||||
|
|
||||||
|
def available_map_torch_geometric(self):
|
||||||
|
names = [
|
||||||
|
"num_nodes",
|
||||||
|
"num_edges",
|
||||||
|
"has_self_loops",
|
||||||
|
"has_isolated_nodes",
|
||||||
|
"num_nodes_features",
|
||||||
|
"y",
|
||||||
|
]
|
||||||
|
maps = {
|
||||||
|
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||||
|
for name in names
|
||||||
|
}
|
||||||
|
return maps
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
data_ = data.__copy__()
|
||||||
|
datahash = hash(data.__repr__)
|
||||||
|
stats = {}
|
||||||
|
for k, v in self.maps.items():
|
||||||
|
if k == "networkx":
|
||||||
|
_data_ = to_networkx(data)
|
||||||
|
_data_ = _data_.to_undirected()
|
||||||
|
elif k == "torch_geometric":
|
||||||
|
_data_ = data.__copy__()
|
||||||
|
for name, func in v.items():
|
||||||
|
try:
|
||||||
|
val = func(_data_)
|
||||||
|
except:
|
||||||
|
val = None
|
||||||
|
if callable(val):
|
||||||
|
val = val()
|
||||||
|
if isinstance(val, types.GeneratorType):
|
||||||
|
try:
|
||||||
|
val = dict(val)
|
||||||
|
except:
|
||||||
|
val = None
|
||||||
|
if name == "hits":
|
||||||
|
val = val[0]
|
||||||
|
if name == "k_components":
|
||||||
|
for key, value in val.items():
|
||||||
|
val[key] = list(val[key][0])
|
||||||
|
|
||||||
|
stats[name] = self.convert(val)
|
||||||
|
stats["hash"] = datahash
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def convert(self, val, K=4):
|
||||||
|
|
||||||
|
if type(val) == set:
|
||||||
|
val = list(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, torch.Tensor):
|
||||||
|
val = val.cpu().numpy().tolist()
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, list):
|
||||||
|
for ind in range(len(val)):
|
||||||
|
item = val[ind]
|
||||||
|
if isinstance(item, list):
|
||||||
|
for subind in range(len(item)):
|
||||||
|
subitem = item[subind]
|
||||||
|
if type(subitem) == float:
|
||||||
|
item[subind] = round(subitem, K)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if type(item) == float:
|
||||||
|
val[ind] = round(item, K)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, dict):
|
||||||
|
for k, v in val.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
for k1, v1 in v.items():
|
||||||
|
if type(v1) == float:
|
||||||
|
v[k1] = round(v1, K)
|
||||||
|
if type(v1) == set:
|
||||||
|
v[k1] = list(v1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if type(v) == float:
|
||||||
|
val[k] = round(v, K)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, nx.classes.reportviews.DegreeView):
|
||||||
|
val = np.array(val).tolist()
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, nx.Graph):
|
||||||
|
val = nx.node_link_data(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
elif isinstance(val, np.ndarray):
|
||||||
|
val = np.around(val, decimals=K)
|
||||||
|
return val.tolist()
|
||||||
|
|
||||||
|
else:
|
||||||
|
return val
|
0
tool/__init__.py
Normal file
0
tool/__init__.py
Normal file
0
tool/explainer/__init__.py
Normal file
0
tool/explainer/__init__.py
Normal file
3
tool/explainer/base.py
Normal file
3
tool/explainer/base.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
class BaseExplaining(object):
|
||||||
|
def __init__(self,model,explainer_name:wq
|
||||||
|
|
1
tool/explainer/from_captum.py
Normal file
1
tool/explainer/from_captum.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from torch_geometric.nn.models.captum import CaptumModel
|
0
tool/graphgym/__init__.py
Normal file
0
tool/graphgym/__init__.py
Normal file
0
tool/visualizer/__init__.py
Normal file
0
tool/visualizer/__init__.py
Normal file
125
utils/stat.py
125
utils/stat.py
|
@ -1,125 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import multiprocessing as mp
|
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import types
|
|
||||||
from inspect import getmembers, isfunction, signature
|
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
from torch_geometric.data import Data
|
|
||||||
from torch_geometric.utils import to_networkx
|
|
||||||
|
|
||||||
with open("mapping_nx.txt", "r") as file:
|
|
||||||
BLACK_LIST = [line.rstrip() for line in file]
|
|
||||||
|
|
||||||
|
|
||||||
class GraphStat(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.maps = {
|
|
||||||
"networkx": self.available_map_networkx(),
|
|
||||||
"torch_geometric": self.available_map_torch_geometric(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def available_map_networkx(self):
|
|
||||||
functions_list = getmembers(nx.algorithms, isfunction)
|
|
||||||
MANUALLY_ADDED = [
|
|
||||||
"algebraic_connectivity",
|
|
||||||
"adjacency_spectrum",
|
|
||||||
"degree",
|
|
||||||
"density",
|
|
||||||
"laplacian_spectrum",
|
|
||||||
"normalized_laplacian_spectrum",
|
|
||||||
"number_of_selfloops",
|
|
||||||
"number_of_edges",
|
|
||||||
"number_of_nodes",
|
|
||||||
]
|
|
||||||
MANUALLY_ADDED_LIST = [
|
|
||||||
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
|
|
||||||
]
|
|
||||||
functions_list = functions_list + MANUALLY_ADDED_LIST
|
|
||||||
maps = {}
|
|
||||||
for func in functions_list:
|
|
||||||
name, f = func
|
|
||||||
if (
|
|
||||||
name in BLACK_LIST
|
|
||||||
or name == "recursive_simple_cycles"
|
|
||||||
or "triad" in name
|
|
||||||
or "weisfeiler" in name
|
|
||||||
or "dfs" in name
|
|
||||||
or "trophic" in name
|
|
||||||
or "recursive" in name
|
|
||||||
or "scipy" in name
|
|
||||||
or "numpy" in name
|
|
||||||
or "sigma" in name
|
|
||||||
or "omega" in name
|
|
||||||
or "all_" in name
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
maps[name] = f
|
|
||||||
return maps
|
|
||||||
|
|
||||||
def available_map_torch_geometric(self):
|
|
||||||
names = [
|
|
||||||
"num_nodes",
|
|
||||||
"num_edges",
|
|
||||||
"has_self_loops",
|
|
||||||
"has_isolated_nodes",
|
|
||||||
"num_nodes_features",
|
|
||||||
"y",
|
|
||||||
]
|
|
||||||
maps = {
|
|
||||||
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
|
||||||
for name in names
|
|
||||||
}
|
|
||||||
return maps
|
|
||||||
|
|
||||||
def __call__(self, data):
|
|
||||||
data_ = data.__copy__()
|
|
||||||
stats = {}
|
|
||||||
for k, v in self.maps.items():
|
|
||||||
if k == "networkx":
|
|
||||||
_data_ = to_networkx(data)
|
|
||||||
_data_ = _data_.to_undirected()
|
|
||||||
elif k == "torch_geometric":
|
|
||||||
_data_ = data.__copy__()
|
|
||||||
for name, f in v.items():
|
|
||||||
if f is None:
|
|
||||||
stats[name] = None
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
t0 = time.time()
|
|
||||||
val = f(_data_)
|
|
||||||
t1 = time.time()
|
|
||||||
delta = t1 - t0
|
|
||||||
except Exception as e:
|
|
||||||
print(name, e)
|
|
||||||
with open(f"{name}.txt", "w") as f:
|
|
||||||
f.write(str(e))
|
|
||||||
# print(name, round(delta, 4))
|
|
||||||
# if callable(val) and k == "torch_geometric":
|
|
||||||
# val = val()
|
|
||||||
# if isinstance(val, types.GeneratorType):
|
|
||||||
# val = list(val)
|
|
||||||
# stats[name] = val
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
from torch_geometric.datasets import KarateClub, Planetoid
|
|
||||||
|
|
||||||
d = Planetoid(root="/tmp/", name="Cora")
|
|
||||||
# d = KarateClub()
|
|
||||||
a = d[0]
|
|
||||||
st = GraphStat()
|
|
||||||
stat = st(a)
|
|
||||||
for k, v in stat.items():
|
|
||||||
print("---------")
|
|
||||||
print("Name:", k)
|
|
||||||
print("Type:", type(v))
|
|
||||||
print("Val:", v)
|
|
||||||
print("---------")
|
|
Loading…
Add table
Reference in a new issue