Adding XGWT method

This commit is contained in:
araison 2023-02-12 12:58:22 +01:00
parent 3f4839fa59
commit 27e8a8a4d8

View File

@ -0,0 +1,169 @@
import functools
import inspect
import logging
import os
import shutil
import warnings
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any
import torch_geometric.graphgym.register as register
from torch_geometric.data.makedirs import makedirs
try: # Define global config object
from yacs.config import CfgNode as CN
xgwt_cfg = CN()
except ImportError:
xgwt_cfg = None
warnings.warn(
"Could not define global config object. Please install "
"'yacs' for using the GraphGym experiment manager via "
"'pip install yacs'."
)
def set_xgwt_cfg(xgwt_cfg):
r"""
This function sets the default config value.
1) Note that for an experiment, only part of the arguments will be used
The remaining unused arguments won't affect anything.
So feel free to register any argument in graphgym.contrib.config
2) We support *at most* two levels of configs, e.g., xgwt_cfg.dataset.name
:return: configuration use by the experiment.
"""
if xgwt_cfg is None:
return xgwt_cfg
xgwt_cfg.wav_approx = False
xgwt_cfg.wav_passband = "heat"
xgwt_cfg.wav_normalization = True
xgwt_cfg.num_candidates = 30
xgwt_cfg.num_samples = 10
xgwt_cfg.c_procedure = "auto"
xgwt_cfg.pred_thres_strat = "regular"
xgwt_cfg.CI_threshold = 0.05
xgwt_cfg.mixing = "uniform"
xgwt_cfg.scales = [3]
xgwt_cfg.pred_thres = 0.1
xgwt_cfg.incl_prob = 0.4
xgwt_cfg.top_k = 5
def assert_cfg(xgwt_cfg):
r"""Checks config values, do necessary post processing to the configs
TODO
"""
# if xgwt_cfg. not in ["node", "edge", "graph", "link_pred"]:
# raise ValueError(
# "Task {} not supported, must be one of node, "
# "edge, graph, link_pred".format(xgwt_cfg.dataset.task)
# )
# xgwt_cfg.run_dir = xgwt_cfg.out_dir
def dump_cfg(xgwt_cfg, path):
r"""
TODO
Dumps the config to the output directory specified in
:obj:`xgwt_cfg.out_dir`
Args:
xgwt_cfg (CfgNode): Configuration node
"""
makedirs(xgwt_cfg.out_dir)
xgwt_cfg_file = os.path.join(xgwt_cfg.out_dir, xgwt_cfg.xgwt_cfg_dest)
with open(xgwt_cfg_file, "w") as f:
xgwt_cfg.dump(stream=f)
def load_cfg(xgwt_cfg, args):
r"""
Load configurations from file system and command line
Args:
xgwt_cfg (CfgNode): Configuration node
args (ArgumentParser): Command argument parser
"""
xgwt_cfg.merge_from_file(args.xgwt_cfg_file)
xgwt_cfg.merge_from_list(args.opts)
assert_xgwt_cfg(xgwt_cfg)
def makedirs_rm_exist(dir):
if os.path.isdir(dir):
shutil.rmtree(dir)
os.makedirs(dir, exist_ok=True)
def get_fname(fname):
r"""
Extract filename from file name path
Args:
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split("/")[-1]
if fname.endswith(".yaml"):
fname = fname[:-5]
elif fname.endswith(".yml"):
fname = fname[:-4]
return fname
def set_out_dir(out_dir, fname):
r"""
Create the directory for full experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = get_fname(fname)
xgwt_cfg.out_dir = os.path.join(out_dir, fname)
# Make output directory
if xgwt_cfg.train.auto_resume:
os.makedirs(xgwt_cfg.out_dir, exist_ok=True)
def set_run_dir(out_dir):
r"""
Create the directory for each random seed experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
xgwt_cfg.run_dir = os.path.join(out_dir, str(xgwt_cfg.seed))
# Make output directory
if xgwt_cfg.train.auto_resume:
os.makedirs(xgwt_cfg.run_dir, exist_ok=True)
set_xgwt_cfg(xgwt_cfg)
def from_config(func):
if inspect.isclass(func):
params = list(inspect.signature(func.__init__).parameters.values())[1:]
else:
params = list(inspect.signature(func).parameters.values())
arg_names = [p.name for p in params]
has_defaults = [p.default != inspect.Parameter.empty for p in params]
@functools.wraps(func)
def wrapper(*args, xgwt_cfg: Any = None, **kwargs):
if xgwt_cfg is not None:
xgwt_cfg = (
dict(xgwt_cfg) if isinstance(xgwt_cfg, Iterable) else asdict(xgwt_cfg)
)
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
for arg_name, has_default in iterator:
if arg_name in kwargs:
continue
elif arg_name in xgwt_cfg:
kwargs[arg_name] = xgwt_cfg[arg_name]
elif not has_default:
raise ValueError(f"'xgwt_cfg.{arg_name}' undefined")
return func(*args, **kwargs)
return wrapper