Adding some graph stats

This commit is contained in:
araison 2022-12-06 01:04:27 +01:00
parent 68c103db4d
commit 9295e91181
1 changed file with 90 additions and 0 deletions

90
utils/stat.py Normal file
View File

@ -0,0 +1,90 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import os
import threading
from inspect import getmembers, isfunction, signature
# import custom_graphgym # noqa, register custom modules
import networkx as nx
import pandas as pd
from docstring_parser import parse
from torch_geometric.data import Data
# from torch_geometric.explain import Explanation
from torch_geometric.utils import to_networkx
# Names of value types, written as strings. Presumably the docstring
# return-type names that qualify a function as a graph statistic -- TODO
# confirm; no active code in this file reads this list.
GRAPH_STAT_TYPE = ["int", "float", "bool", "boolean", "dict", "dictionary"]
class GraphStat(object):
    """Collect statistics about a graph from several backends.

    ``maps`` associates a backend name (``"networkx"`` or
    ``"torch_geometric"``) with a ``{stat_name: callable}`` table. Calling the
    instance on a graph applies every callable to it and keeps the results of
    those that succeed, each wrapped in a one-element ``pandas.Series``.
    """

    def __init__(self):
        # A list from the start: ``to_series`` appends to it, so a dict here
        # would break any ``to_series`` call made before ``__call__``.
        self.stat = []
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: function}`` for the functions of ``nx.algorithms``.

        Functions whose name contains ``"all_"`` are left out. No filtering
        on the documented return type is applied; functions that cannot be
        called with a single graph argument are skipped later, when
        ``__call__`` tries them.
        """
        return {
            name: func
            for name, func in getmembers(nx.algorithms, isfunction)
            if "all_" not in name
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` for a fixed set of data attributes.

        Each getter takes the data object and returns the attribute called
        ``name``; the result may itself be a callable (a bound method), which
        ``__call__`` invokes.
        """
        names = (
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            "y",
        )
        # ``name=name`` freezes the current name in each lambda; without it
        # every getter would read the last name of the loop.
        return {name: (lambda x, name=name: getattr(x, name)) for name in names}

    def to_series(self, name, val):
        """Append ``val`` to ``self.stat`` as a Series with index ``name``."""
        self.stat.append(pd.Series(data={name: val}))

    def __call__(self, data):
        """Compute every available statistic for ``data``.

        Args:
            data: graph object accepted by ``to_networkx`` and providing
                ``__copy__``.

        Returns:
            list: one ``pandas.Series`` per statistic that could be computed,
            in the iteration order of ``self.maps``. The same list is stored
            in ``self.stat`` and is rebuilt on every call.
        """
        self.stat = []
        for backend, funcs in self.maps.items():
            if backend == "networkx":
                graph = to_networkx(data).to_undirected()
            elif backend == "torch_geometric":
                graph = data.__copy__()
            else:
                # No conversion is defined for this backend: skip it rather
                # than run its functions on another backend's data.
                continue
            for name, func in funcs.items():
                try:
                    value = func(graph)
                    # Attribute getters may hand back a bound method.
                    if callable(value) and backend == "torch_geometric":
                        value = value()
                    self.to_series(name, value)
                except Exception as exc:
                    # Best effort: many functions need extra arguments or do
                    # not apply to this graph. Unlike a bare ``except``, this
                    # lets KeyboardInterrupt / SystemExit propagate.
                    logging.debug(
                        "Skipping %s statistic %r: %s", backend, name, exc
                    )
        return self.stat
if __name__ == "__main__":
    # Demo run on the KarateClub dataset. Guarded so that importing this
    # module does not load a dataset or print anything.
    from torch_geometric.datasets import KarateClub

    d = KarateClub()
    a = d[0]
    st = GraphStat()
    stat = st(a)
    for item in stat:
        # Each item is a one-element Series; only non-numeric results are
        # shown. The dtype helpers match every integer/float width, whereas
        # comparing against the strings 'int' / 'float' only matches the
        # platform's default widths.
        if pd.api.types.is_integer_dtype(item.dtype) or pd.api.types.is_float_dtype(
            item.dtype
        ):
            continue
        print(item)