New fixes + reformat

This commit is contained in:
araison 2022-12-07 22:24:08 +01:00
parent 09a20b891f
commit 5d2cacfa05
11 changed files with 232 additions and 125 deletions

0
stat/graph/__init__.py Normal file
View File

228
stat/graph/graph_stat.py Normal file
View File

@ -0,0 +1,228 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
__maps__ = [
"adamic_adar_index",
"approximate_current_flow_betweenness_centrality",
"average_clustering",
"average_degree_connectivity",
"average_neighbor_degree",
"average_node_connectivity",
"average_shortest_path_length",
"betweenness_centrality",
"bridges",
"closeness_centrality",
"clustering",
"cn_soundarajan_hopcroft",
"common_neighbor_centrality",
"communicability",
"communicability_betweenness_centrality",
"communicability_exp",
"connected_components",
"constraint",
"core_number",
"current_flow_betweenness_centrality",
"current_flow_closeness_centrality",
"cycle_basis",
"degree_assortativity_coefficient",
"degree_centrality",
"degree_mixing_matrix",
"degree_pearson_correlation_coefficient",
"diameter",
"dispersion",
"dominating_set",
"eccentricity",
"effective_size",
"eigenvector_centrality",
"estrada_index",
"generalized_degree",
"global_efficiency",
"global_reaching_centrality",
"graph_clique_number",
"harmonic_centrality",
"has_bridges",
"has_eulerian_path",
"hits",
"information_centrality",
"is_at_free",
"is_biconnected",
"is_bipartite",
"is_chordal",
"is_connected",
"is_directed_acyclic_graph",
"is_distance_regular",
"is_eulerian",
"is_forest",
"is_graphical",
"is_multigraphical",
"is_planar",
"is_pseudographical",
"is_regular",
"is_semieulerian",
"is_strongly_regular",
"is_tree",
"is_valid_degree_sequence_erdos_gallai",
"is_valid_degree_sequence_havel_hakimi",
"jaccard_coefficient",
"k_components",
"k_core",
"katz_centrality",
"load_centrality",
"minimum_cycle_basis",
"minimum_edge_cut",
"minimum_node_cut",
"node_clique_number",
"node_connectivity",
"node_degree_xy",
"number_connected_components",
"number_of_cliques",
"number_of_isolates",
"pagerank",
"periphery",
"preferential_attachment",
"reciprocity",
"resource_allocation_index",
"rich_club_coefficient",
"square_clustering",
"stoer_wagner",
"subgraph_centrality",
"subgraph_centrality_exp",
"transitivity",
"triangles",
"voterank",
"wiener_index",
"algebraic_connectivity",
"degree",
"density",
"normalized_laplacian_spectrum",
"number_of_edges",
"number_of_nodes",
"number_of_selfloops",
]
class GraphStat(object):
def __init__(self):
self.maps = {
"networkx": self.available_map_networkx(),
"torch_geometric": self.available_map_torch_geometric(),
}
def available_map_networkx(self):
functions_list = getmembers(nx, isfunction)
maps = {}
for func in functions_list:
name, f = func
if name in __maps__:
maps[name] = f
return maps
def available_map_torch_geometric(self):
names = [
"num_nodes",
"num_edges",
"has_self_loops",
"has_isolated_nodes",
"num_nodes_features",
"y",
]
maps = {
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
for name in names
}
return maps
def __call__(self, data):
data_ = data.__copy__()
datahash = hash(data.__repr__)
stats = {}
for k, v in self.maps.items():
if k == "networkx":
_data_ = to_networkx(data)
_data_ = _data_.to_undirected()
elif k == "torch_geometric":
_data_ = data.__copy__()
for name, func in v.items():
try:
val = func(_data_)
except:
val = None
if callable(val):
val = val()
if isinstance(val, types.GeneratorType):
try:
val = dict(val)
except:
val = None
if name == "hits":
val = val[0]
if name == "k_components":
for key, value in val.items():
val[key] = list(val[key][0])
stats[name] = self.convert(val)
stats["hash"] = datahash
return stats
def convert(self, val, K=4):
if type(val) == set:
val = list(val)
return val
elif isinstance(val, torch.Tensor):
val = val.cpu().numpy().tolist()
return val
elif isinstance(val, list):
for ind in range(len(val)):
item = val[ind]
if isinstance(item, list):
for subind in range(len(item)):
subitem = item[subind]
if type(subitem) == float:
item[subind] = round(subitem, K)
else:
if type(item) == float:
val[ind] = round(item, K)
return val
elif isinstance(val, dict):
for k, v in val.items():
if isinstance(v, dict):
for k1, v1 in v.items():
if type(v1) == float:
v[k1] = round(v1, K)
if type(v1) == set:
v[k1] = list(v1)
else:
if type(v) == float:
val[k] = round(v, K)
return val
elif isinstance(val, nx.classes.reportviews.DegreeView):
val = np.array(val).tolist()
return val
elif isinstance(val, nx.Graph):
val = nx.node_link_data(val)
return val
elif isinstance(val, np.ndarray):
val = np.around(val, decimals=K)
return val.tolist()
else:
return val

0
tool/__init__.py Normal file
View File

View File

3
tool/explainer/base.py Normal file
View File

@ -0,0 +1,3 @@
class BaseExplaining(object):
def __init__(self,model,explainer_name:wq

View File

@ -0,0 +1 @@
from torch_geometric.nn.models.captum import CaptumModel

View File

View File

View File

@ -1,125 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import multiprocessing as mp
import os
import threading
import time
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
with open("mapping_nx.txt", "r") as file:
BLACK_LIST = [line.rstrip() for line in file]
class GraphStat(object):
def __init__(self):
self.maps = {
"networkx": self.available_map_networkx(),
"torch_geometric": self.available_map_torch_geometric(),
}
def available_map_networkx(self):
functions_list = getmembers(nx.algorithms, isfunction)
MANUALLY_ADDED = [
"algebraic_connectivity",
"adjacency_spectrum",
"degree",
"density",
"laplacian_spectrum",
"normalized_laplacian_spectrum",
"number_of_selfloops",
"number_of_edges",
"number_of_nodes",
]
MANUALLY_ADDED_LIST = [
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
]
functions_list = functions_list + MANUALLY_ADDED_LIST
maps = {}
for func in functions_list:
name, f = func
if (
name in BLACK_LIST
or name == "recursive_simple_cycles"
or "triad" in name
or "weisfeiler" in name
or "dfs" in name
or "trophic" in name
or "recursive" in name
or "scipy" in name
or "numpy" in name
or "sigma" in name
or "omega" in name
or "all_" in name
):
continue
else:
maps[name] = f
return maps
def available_map_torch_geometric(self):
names = [
"num_nodes",
"num_edges",
"has_self_loops",
"has_isolated_nodes",
"num_nodes_features",
"y",
]
maps = {
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
for name in names
}
return maps
def __call__(self, data):
data_ = data.__copy__()
stats = {}
for k, v in self.maps.items():
if k == "networkx":
_data_ = to_networkx(data)
_data_ = _data_.to_undirected()
elif k == "torch_geometric":
_data_ = data.__copy__()
for name, f in v.items():
if f is None:
stats[name] = None
continue
else:
try:
t0 = time.time()
val = f(_data_)
t1 = time.time()
delta = t1 - t0
except Exception as e:
print(name, e)
with open(f"{name}.txt", "w") as f:
f.write(str(e))
# print(name, round(delta, 4))
# if callable(val) and k == "torch_geometric":
# val = val()
# if isinstance(val, types.GeneratorType):
# val = list(val)
# stats[name] = val
return stats
from torch_geometric.datasets import KarateClub, Planetoid
d = Planetoid(root="/tmp/", name="Cora")
# d = KarateClub()
a = d[0]
st = GraphStat()
stat = st(a)
for k, v in stat.items():
print("---------")
print("Name:", k)
print("Type:", type(v))
print("Val:", v)
print("---------")