91 lines
2.5 KiB
Python
91 lines
2.5 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import logging
|
||
|
import os
|
||
|
import threading
|
||
|
from inspect import getmembers, isfunction, signature
|
||
|
|
||
|
# import custom_graphgym # noqa, register custom modules
|
||
|
import networkx as nx
|
||
|
import pandas as pd
|
||
|
from docstring_parser import parse
|
||
|
from torch_geometric.data import Data
|
||
|
# from torch_geometric.explain import Explanation
|
||
|
from torch_geometric.utils import to_networkx
|
||
|
|
||
|
# Docstring return-type names that were meant to mark a networkx function as a
# usable graph statistic (compared against `docstring.returns.type_name`).
# NOTE(review): not referenced by any active code in this file -- the
# return-type filter that used it is disabled; confirm before relying on it.
GRAPH_STAT_TYPE = ["int", "float", "bool", "boolean", "dict", "dictionary"]
|
||
|
|
||
|
|
||
|
class GraphStat(object):
    """Collect statistics of a graph from several backends, best-effort.

    ``self.maps`` maps a backend name to ``{stat_name: callable}``:

    * ``"networkx"`` -- every function exposed on ``networkx.algorithms``
      whose name does not contain ``"all_"``; each is called with an
      undirected networkx graph built via ``to_networkx``.
    * ``"torch_geometric"`` -- getters for a fixed list of attribute names,
      applied to a shallow copy (``__copy__``) of the input data object.

    Calling the instance runs every statistic and returns a list of
    one-entry ``pandas.Series`` (``{name: value}``), also kept on ``self.stat``.
    """

    def __init__(self):
        # One pd.Series per successfully computed statistic (see __call__).
        self.stat = []
        self.maps = {
            "networkx": self.available_map_networkx(),
            "torch_geometric": self.available_map_torch_geometric(),
        }

    def available_map_networkx(self):
        """Return ``{name: function}`` for the functions of ``networkx.algorithms``.

        Functions whose name contains ``"all_"`` are skipped.

        NOTE(review): an earlier version meant to keep only functions whose
        parsed docstring return type is in ``GRAPH_STAT_TYPE``; that filter
        is disabled, so every remaining function is included. Presumably many
        need extra arguments and are simply skipped by ``__call__``.
        """
        return {
            name: func
            for name, func in getmembers(nx.algorithms, isfunction)
            if "all_" not in name
        }

    def available_map_torch_geometric(self):
        """Return ``{name: getter}`` reading attributes of a graph data object.

        Each getter returns ``getattr(data, name)``. Values that come back
        callable (e.g. a bound ``has_self_loops`` method) are invoked later by
        ``__call__``.
        """
        names = [
            "num_nodes",
            "num_edges",
            "has_self_loops",
            "has_isolated_nodes",
            # NOTE(review): "num_nodes_features" was left disabled here; PyG
            # appears to spell it "num_node_features" -- confirm before enabling.
            "y",
        ]
        # `name=name` binds each name eagerly (avoids the late-binding closure trap);
        # getattr works on any object, unlike calling `__getattr__` directly.
        return {name: (lambda x, name=name: getattr(x, name)) for name in names}

    def to_series(self, name, val):
        """Append ``val`` to ``self.stat`` as a one-entry Series indexed by ``name``."""
        self.stat.append(pd.Series(data={name: val}))

    @staticmethod
    def _backend_view(backend, data):
        """Return the object the ``backend`` statistics run on, or None if unknown."""
        if backend == "networkx":
            return to_networkx(data).to_undirected()
        if backend == "torch_geometric":
            # Getters receive a shallow copy rather than the caller's object.
            return data.__copy__()
        return None

    def __call__(self, data):
        """Compute every registered statistic of ``data``.

        Args:
            data: graph object accepted by ``to_networkx`` and supporting
                ``__copy__`` (presumably a ``torch_geometric.data.Data``).

        Returns:
            list of ``pandas.Series``, one per statistic that succeeded, each
            holding a single ``{name: value}`` entry. The same list is stored
            on ``self.stat``.

        Statistics whose callable raises an ``Exception`` are skipped and
        logged at DEBUG level; ``KeyboardInterrupt``/``SystemExit`` propagate.
        Backends other than ``"networkx"`` / ``"torch_geometric"`` are skipped.
        """
        log = logging.getLogger(__name__)
        self.stat = []
        for backend, funcs in self.maps.items():
            view = self._backend_view(backend, data)
            if view is None:
                log.debug("GraphStat: skipping unknown backend %r", backend)
                continue
            for name, func in funcs.items():
                try:
                    value = func(view)
                    # PyG getters may return bound methods (e.g. has_self_loops).
                    if callable(value) and backend == "torch_geometric":
                        value = value()
                    self.to_series(name, value)
                except Exception:
                    # Best-effort by design: most stats may not apply to this graph.
                    log.debug("GraphStat: %s.%s failed", backend, name, exc_info=True)
                    continue
        return self.stat
|
||
|
|
||
|
|
||
|
def is_int_or_float(series):
    """Return True if ``series`` has an integer or floating-point dtype.

    Uses pandas dtype predicates instead of comparing ``dtypes`` to the strings
    ``'int'``/``'float'``: ``'int'`` only matches the platform's default int
    width, so e.g. an int32 Series would slip through. Bool and object dtypes
    return False.
    """
    dtype = series.dtype
    return pd.api.types.is_integer_dtype(dtype) or pd.api.types.is_float_dtype(dtype)


def main():
    """Print every statistic of the Karate Club graph that is not an int/float."""
    from torch_geometric.datasets import KarateClub

    graph = KarateClub()[0]
    for item in GraphStat()(graph):
        if not is_int_or_float(item):
            print(item)


# Guard so importing this module does not build the dataset or run every stat.
if __name__ == "__main__":
    main()
|