explaining_framework/utils/stat.py

126 lines
3.7 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import multiprocessing as mp
import os
import threading
import time
import types
from inspect import getmembers, isfunction, signature
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
# Names of networkx functions that GraphStat must never compute, one per line
# (trailing whitespace/newlines stripped).
# NOTE(review): the relative path is resolved against the current working
# directory, not this module's directory -- confirm that is intended.
with open("mapping_nx.txt", "r") as file:
    BLACK_LIST = list(map(str.rstrip, file))
class GraphStat(object):
    """Compute a catalogue of graph statistics for a graph ``data`` object.

    ``self.maps`` holds ``{backend: {stat_name: callable}}`` for two backends:

    * ``"networkx"``: functions taken from ``networkx.algorithms`` (plus a few
      hand-picked top-level ``networkx`` functions), each called on an
      undirected ``networkx`` conversion of the input;
    * ``"torch_geometric"``: attribute lookups on (a copy of) the input object
      itself; attributes that turn out to be methods are called.

    Calling the instance on a graph returns ``{stat_name: value}``.
    """

    # Top-level networkx functions added on top of ``networkx.algorithms``.
    _NX_MANUALLY_ADDED = (
        "algebraic_connectivity",
        "adjacency_spectrum",
        "degree",
        "density",
        "laplacian_spectrum",
        "normalized_laplacian_spectrum",
        "number_of_selfloops",
        "number_of_edges",
        "number_of_nodes",
    )

    # A networkx function is skipped when its name contains any of these
    # fragments. The reason is not recorded in the code -- presumably the
    # functions are too slow, need extra arguments, or return unusable values.
    _NX_EXCLUDED_SUBSTRINGS = (
        "triad",
        "weisfeiler",
        "dfs",
        "trophic",
        "recursive",
        "scipy",
        "numpy",
        "sigma",
        "omega",
        "all_",
    )

    # Attributes looked up on the input object for the torch_geometric backend.
    _TG_ATTRIBUTES = (
        "num_nodes",
        "num_edges",
        "has_self_loops",
        "has_isolated_nodes",
        # Previously spelled "num_nodes_features" -- presumably a typo for
        # PyG's ``Data.num_node_features``, which made this stat always None.
        "num_node_features",
        "y",
    )

    def __init__(self):
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: function}`` for the networkx statistics to compute.

        Candidates are all functions in ``networkx.algorithms`` followed by the
        names in ``_NX_MANUALLY_ADDED`` found among top-level ``networkx``
        functions (a later duplicate name overrides an earlier one). A
        candidate is dropped if it is listed in ``BLACK_LIST`` or its name
        contains any fragment from ``_NX_EXCLUDED_SUBSTRINGS``.
        """
        blacklist = set(BLACK_LIST)  # O(1) membership instead of list scans
        candidates = getmembers(nx.algorithms, isfunction) + [
            (name, func)
            for name, func in getmembers(nx, isfunction)
            if name in self._NX_MANUALLY_ADDED
        ]
        return {
            name: func
            for name, func in candidates
            if name not in blacklist
            and not any(part in name for part in self._NX_EXCLUDED_SUBSTRINGS)
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for the torch_geometric statistics.

        Each getter takes the data object and returns ``getattr(data, name)``,
        or ``None`` when the attribute does not exist.
        """
        # ``name=name`` binds the current name; a bare closure would make
        # every getter look up the last name in the tuple.
        return {
            name: (lambda data, name=name: getattr(data, name, None))
            for name in self._TG_ATTRIBUTES
        }

    def _prepare_input(self, backend, data):
        """Convert *data* into the object that *backend*'s functions expect.

        Raises:
            ValueError: if *backend* is not a known backend name.
        """
        if backend == "networkx":
            return to_networkx(data).to_undirected()
        if backend == "torch_geometric":
            return data.__copy__()
        raise ValueError(f"Unknown statistics backend: {backend!r}")

    def __call__(self, data):
        """Compute every registered statistic on *data*.

        Args:
            data: the graph to analyse. It must provide ``__copy__``, and for
                the networkx backend it must be accepted by ``to_networkx``
                (presumably a torch_geometric ``Data`` -- confirm with callers).

        Returns:
            dict mapping each statistic name to its value. A statistic whose
            registered function is ``None`` maps to ``None``. A statistic whose
            computation raises is logged as a warning and omitted. Generator
            results are materialised into lists.

        Raises:
            ValueError: if ``self.maps`` contains an unknown backend name.
        """
        logger = logging.getLogger(__name__)
        stats = {}
        for backend, functions in self.maps.items():
            graph = self._prepare_input(backend, data)
            for name, func in functions.items():
                if func is None:
                    stats[name] = None
                    continue
                try:
                    value = func(graph)
                    # Some torch_geometric stats (e.g. ``has_self_loops``) are
                    # methods; call them to obtain the actual value.
                    if backend == "torch_geometric" and callable(value):
                        value = value()
                    # Materialise lazy generators so the result remains usable
                    # after this call returns.
                    if isinstance(value, types.GeneratorType):
                        value = list(value)
                except Exception as exc:
                    # Best-effort sweep: many library functions need extra
                    # arguments or reject this graph type. Skip them, but say so.
                    logger.warning("Statistic %r failed: %s", name, exc)
                    continue
                stats[name] = value
        return stats
if __name__ == "__main__":
    # Demo: compute and print every statistic for the Cora citation graph.
    # Guarded (and the dataset import kept local) so that importing this
    # utils module as a library does not run this slow, dataset-loading demo.
    from torch_geometric.datasets import KarateClub, Planetoid

    d = Planetoid(root="/tmp/", name="Cora")
    # For a much smaller and faster run, use instead: d = KarateClub()
    a = d[0]
    st = GraphStat()
    stat = st(a)
    for k, v in stat.items():
        print("---------")
        print("Name:", k)
        print("Type:", type(v))
        print("Val:", v)
        print("---------")