explaining_framework/utils/stat.py

126 lines
3.7 KiB
Python
Raw Normal View History

2022-12-06 00:04:27 +00:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
2022-12-06 17:25:54 +00:00
import multiprocessing as mp
2022-12-06 00:04:27 +00:00
import os
import threading
2022-12-06 17:25:54 +00:00
import time
import types
2022-12-06 00:04:27 +00:00
from inspect import getmembers, isfunction, signature
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
2022-12-06 17:25:54 +00:00
# Names listed in "mapping_nx.txt" (one per line, trailing whitespace removed).
# The path is relative, so it is resolved against the current working directory.
with open("mapping_nx.txt", "r") as blacklist_file:
    BLACK_LIST = list(map(str.rstrip, blacklist_file))
2022-12-06 00:04:27 +00:00
class GraphStat(object):
    """Collect graph statistics from networkx functions and data attributes.

    On construction two ``{name: callable}`` maps are built and stored in
    ``self.maps``:

    * ``"networkx"``: functions found in ``networkx.algorithms`` plus a few
      hand-picked top-level ``networkx`` functions; they are called on the
      undirected networkx conversion of the input.
    * ``"torch_geometric"``: attribute getters called on a copy of the input
      data object.

    Calling the instance on a data object returns a ``{name: value}`` dict.
    """

    # Top-level ``networkx`` functions added on top of ``networkx.algorithms``.
    _MANUALLY_ADDED = (
        "algebraic_connectivity",
        "adjacency_spectrum",
        "degree",
        "density",
        "laplacian_spectrum",
        "normalized_laplacian_spectrum",
        "number_of_selfloops",
        "number_of_edges",
        "number_of_nodes",
    )

    # A networkx function whose name contains any of these is skipped.
    # ("recursive" already covers the "recursive_simple_cycles" special case.)
    _EXCLUDED_SUBSTRINGS = (
        "triad",
        "weisfeiler",
        "dfs",
        "trophic",
        "recursive",
        "scipy",
        "numpy",
        "sigma",
        "omega",
        "all_",
    )

    # Attributes read from the data object.
    # NOTE(review): "num_nodes_features" looks like a typo for PyG's
    # ``num_node_features`` and presumably always resolves to None. It is kept
    # because it is also the key used in the returned stats; confirm before
    # renaming.
    _TORCH_GEOMETRIC_NAMES = (
        "num_nodes",
        "num_edges",
        "has_self_loops",
        "has_isolated_nodes",
        "num_nodes_features",
        "y",
    )

    def __init__(self):
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: function}`` for the usable networkx functions.

        Candidates are every function member of ``networkx.algorithms`` plus
        the ``_MANUALLY_ADDED`` top-level functions. A candidate is dropped
        when its name is in ``BLACK_LIST`` or contains one of
        ``_EXCLUDED_SUBSTRINGS``.
        """
        candidates = getmembers(nx.algorithms, isfunction)
        candidates += [
            member
            for member in getmembers(nx, isfunction)
            if member[0] in self._MANUALLY_ADDED
        ]
        # Build the set once: membership is tested for every candidate.
        blacklist = set(BLACK_LIST)
        return {
            name: func
            for name, func in candidates
            if name not in blacklist
            and not any(part in name for part in self._EXCLUDED_SUBSTRINGS)
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for the data-object attributes.

        Each getter takes the data object and returns the attribute called
        ``name``, or ``None`` when the object has no such attribute.
        """
        # ``name=name`` binds the current name; a plain closure would make
        # every getter read the last name of the loop.
        return {
            name: (lambda x, name=name: getattr(x, name, None))
            for name in self._TORCH_GEOMETRIC_NAMES
        }

    def __call__(self, data):
        """Compute every statistic in ``self.maps`` for ``data``.

        Args:
            data: the graph data object. It must provide ``__copy__`` and be
                accepted by ``to_networkx`` for the corresponding families.

        Returns:
            dict mapping each statistic name to its value. A statistic whose
            callable is ``None`` or whose evaluation raises is stored as
            ``None``. When the same name appears in several families, the
            family iterated last wins.
        """
        stats = {}
        for library, functions in self.maps.items():
            # Convert the input once per family, not once per statistic.
            if library == "networkx":
                graph = to_networkx(data).to_undirected()
            elif library == "torch_geometric":
                graph = data.__copy__()
            else:
                # No conversion is defined for other families: pass the
                # input through unchanged.
                graph = data
            for name, func in functions.items():
                stats[name] = self._evaluate(library, name, func, graph)
        return stats

    @staticmethod
    def _evaluate(library, name, func, graph):
        """Return one statistic, or ``None`` if it is unavailable or fails."""
        if func is None:
            return None
        try:
            value = func(graph)
            # Some data-object attributes are methods (the getter returns a
            # bound method); call them to obtain the actual value.
            if callable(value) and library == "torch_geometric":
                value = value()
            # Materialise generators inside the ``try`` so that errors raised
            # while iterating are handled like any other failure.
            if isinstance(value, types.GeneratorType):
                value = list(value)
        except Exception as err:
            # Deliberately broad and best-effort: arbitrary library functions
            # are called with a single argument and many are expected to
            # fail. The error is reported and the statistic is set to None.
            print(name, err)
            with open(f"{name}.txt", "w") as report:
                report.write(str(err))
            return None
        return value
2022-12-06 00:04:27 +00:00
2022-12-06 17:25:54 +00:00
from torch_geometric.datasets import KarateClub, Planetoid
2022-12-06 00:04:27 +00:00
2022-12-06 17:25:54 +00:00
# Ad-hoc demo: compute the statistics for the first graph of the Cora
# dataset and print the name, type and value of every entry.
d = Planetoid(root="/tmp/", name="Cora")
# d = KarateClub()
a = d[0]
st = GraphStat()
stat = st(a)
for k, v in stat.items():
    report_lines = (
        "---------",
        "Name: " + str(k),
        "Type: " + str(type(v)),
        "Val: " + str(v),
        "---------",
    )
    print("\n".join(report_lines))