Fixing bugs and adding new features

This commit is contained in:
araison 2022-12-06 18:25:54 +01:00
parent 9295e91181
commit 09a20b891f
1 changed files with 69 additions and 34 deletions

View File

@ -2,24 +2,23 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import logging import logging
import multiprocessing as mp
import os import os
import threading import threading
import time
import types
from inspect import getmembers, isfunction, signature from inspect import getmembers, isfunction, signature
# import custom_graphgym # noqa, register custom modules
import networkx as nx import networkx as nx
import pandas as pd
from docstring_parser import parse
from torch_geometric.data import Data from torch_geometric.data import Data
# from torch_geometric.explain import Explanation
from torch_geometric.utils import to_networkx from torch_geometric.utils import to_networkx
GRAPH_STAT_TYPE = ["int", "float", "bool", "boolean", "dict", "dictionary"] with open("mapping_nx.txt", "r") as file:
BLACK_LIST = [line.rstrip() for line in file]
class GraphStat(object): class GraphStat(object):
def __init__(self): def __init__(self):
self.stat = {}
self.maps = { self.maps = {
"networkx": self.available_map_networkx(), "networkx": self.available_map_networkx(),
"torch_geometric": self.available_map_torch_geometric(), "torch_geometric": self.available_map_torch_geometric(),
@ -27,18 +26,41 @@ class GraphStat(object):
def available_map_networkx(self): def available_map_networkx(self):
functions_list = getmembers(nx.algorithms, isfunction) functions_list = getmembers(nx.algorithms, isfunction)
MANUALLY_ADDED = [
"algebraic_connectivity",
"adjacency_spectrum",
"degree",
"density",
"laplacian_spectrum",
"normalized_laplacian_spectrum",
"number_of_selfloops",
"number_of_edges",
"number_of_nodes",
]
MANUALLY_ADDED_LIST = [
item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED
]
functions_list = functions_list + MANUALLY_ADDED_LIST
maps = {} maps = {}
for func in functions_list: for func in functions_list:
name, f = func name, f = func
if "all_" in name: if (
name in BLACK_LIST
or name == "recursive_simple_cycles"
or "triad" in name
or "weisfeiler" in name
or "dfs" in name
or "trophic" in name
or "recursive" in name
or "scipy" in name
or "numpy" in name
or "sigma" in name
or "omega" in name
or "all_" in name
):
continue continue
docstring = parse(f.__doc__) else:
try:
# rt = docstring.returns.type_name
# if rt in GRAPH_STAT_TYPE:
maps[name] = f maps[name] = f
except AttributeError:
continue
return maps return maps
def available_map_torch_geometric(self): def available_map_torch_geometric(self):
@ -47,19 +69,18 @@ class GraphStat(object):
"num_edges", "num_edges",
"has_self_loops", "has_self_loops",
"has_isolated_nodes", "has_isolated_nodes",
# "num_nodes_features", "num_nodes_features",
"y", "y",
] ]
maps = {name:lambda x,name=name: x.__getattr__(name) for name in names} maps = {
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
for name in names
}
return maps return maps
def to_series(self, name, val):
self.stat.append(pd.Series(data={name: val}))
def __call__(self, data): def __call__(self, data):
data_ = data.__copy__() data_ = data.__copy__()
self.stat = [] stats = {}
process = []
for k, v in self.maps.items(): for k, v in self.maps.items():
if k == "networkx": if k == "networkx":
_data_ = to_networkx(data) _data_ = to_networkx(data)
@ -67,24 +88,38 @@ class GraphStat(object):
elif k == "torch_geometric": elif k == "torch_geometric":
_data_ = data.__copy__() _data_ = data.__copy__()
for name, f in v.items(): for name, f in v.items():
try: if f is None:
proc = f(_data_) stats[name] = None
if callable(proc) and k == "torch_geometric":
proc = proc()
self.to_series(name, proc)
except:
continue continue
return self.stat else:
try:
t0 = time.time()
val = f(_data_)
t1 = time.time()
delta = t1 - t0
except Exception as e:
print(name, e)
with open(f"{name}.txt", "w") as f:
f.write(str(e))
# print(name, round(delta, 4))
# if callable(val) and k == "torch_geometric":
# val = val()
# if isinstance(val, types.GeneratorType):
# val = list(val)
# stats[name] = val
return stats
from torch_geometric.datasets import KarateClub from torch_geometric.datasets import KarateClub, Planetoid
d = KarateClub() d = Planetoid(root="/tmp/", name="Cora")
# d = KarateClub()
a = d[0] a = d[0]
st = GraphStat() st = GraphStat()
stat = st(a) stat = st(a)
for item in stat: for k, v in stat.items():
if item.dtypes == 'int' or item.dtypes == 'float': print("---------")
continue print("Name:", k)
else: print("Type:", type(v))
print(item) print("Val:", v)
print("---------")