Adding XGWT method
This commit is contained in:
parent
3f4839fa59
commit
27e8a8a4d8
|
@ -0,0 +1,169 @@
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch_geometric.graphgym.register as register
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
|
|
||||||
|
try: # Define global config object
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
|
xgwt_cfg = CN()
|
||||||
|
except ImportError:
|
||||||
|
xgwt_cfg = None
|
||||||
|
warnings.warn(
|
||||||
|
"Could not define global config object. Please install "
|
||||||
|
"'yacs' for using the GraphGym experiment manager via "
|
||||||
|
"'pip install yacs'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_xgwt_cfg(xgwt_cfg):
|
||||||
|
r"""
|
||||||
|
This function sets the default config value.
|
||||||
|
1) Note that for an experiment, only part of the arguments will be used
|
||||||
|
The remaining unused arguments won't affect anything.
|
||||||
|
So feel free to register any argument in graphgym.contrib.config
|
||||||
|
2) We support *at most* two levels of configs, e.g., xgwt_cfg.dataset.name
|
||||||
|
:return: configuration use by the experiment.
|
||||||
|
"""
|
||||||
|
if xgwt_cfg is None:
|
||||||
|
return xgwt_cfg
|
||||||
|
|
||||||
|
xgwt_cfg.wav_approx = False
|
||||||
|
xgwt_cfg.wav_passband = "heat"
|
||||||
|
xgwt_cfg.wav_normalization = True
|
||||||
|
xgwt_cfg.num_candidates = 30
|
||||||
|
xgwt_cfg.num_samples = 10
|
||||||
|
xgwt_cfg.c_procedure = "auto"
|
||||||
|
xgwt_cfg.pred_thres_strat = "regular"
|
||||||
|
xgwt_cfg.CI_threshold = 0.05
|
||||||
|
xgwt_cfg.mixing = "uniform"
|
||||||
|
xgwt_cfg.scales = [3]
|
||||||
|
xgwt_cfg.pred_thres = 0.1
|
||||||
|
xgwt_cfg.incl_prob = 0.4
|
||||||
|
xgwt_cfg.top_k = 5
|
||||||
|
|
||||||
|
|
||||||
|
def assert_cfg(xgwt_cfg):
|
||||||
|
r"""Checks config values, do necessary post processing to the configs
|
||||||
|
TODO
|
||||||
|
|
||||||
|
"""
|
||||||
|
# if xgwt_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||||
|
# raise ValueError(
|
||||||
|
# "Task {} not supported, must be one of node, "
|
||||||
|
# "edge, graph, link_pred".format(xgwt_cfg.dataset.task)
|
||||||
|
# )
|
||||||
|
# xgwt_cfg.run_dir = xgwt_cfg.out_dir
|
||||||
|
|
||||||
|
|
||||||
|
def dump_cfg(xgwt_cfg, path):
|
||||||
|
r"""
|
||||||
|
TODO
|
||||||
|
Dumps the config to the output directory specified in
|
||||||
|
:obj:`xgwt_cfg.out_dir`
|
||||||
|
Args:
|
||||||
|
xgwt_cfg (CfgNode): Configuration node
|
||||||
|
"""
|
||||||
|
makedirs(xgwt_cfg.out_dir)
|
||||||
|
xgwt_cfg_file = os.path.join(xgwt_cfg.out_dir, xgwt_cfg.xgwt_cfg_dest)
|
||||||
|
with open(xgwt_cfg_file, "w") as f:
|
||||||
|
xgwt_cfg.dump(stream=f)
|
||||||
|
|
||||||
|
|
||||||
|
def load_cfg(xgwt_cfg, args):
|
||||||
|
r"""
|
||||||
|
Load configurations from file system and command line
|
||||||
|
Args:
|
||||||
|
xgwt_cfg (CfgNode): Configuration node
|
||||||
|
args (ArgumentParser): Command argument parser
|
||||||
|
"""
|
||||||
|
xgwt_cfg.merge_from_file(args.xgwt_cfg_file)
|
||||||
|
xgwt_cfg.merge_from_list(args.opts)
|
||||||
|
assert_xgwt_cfg(xgwt_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def makedirs_rm_exist(dir):
|
||||||
|
if os.path.isdir(dir):
|
||||||
|
shutil.rmtree(dir)
|
||||||
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_fname(fname):
|
||||||
|
r"""
|
||||||
|
Extract filename from file name path
|
||||||
|
Args:
|
||||||
|
fname (string): Filename for the yaml format configuration file
|
||||||
|
"""
|
||||||
|
fname = fname.split("/")[-1]
|
||||||
|
if fname.endswith(".yaml"):
|
||||||
|
fname = fname[:-5]
|
||||||
|
elif fname.endswith(".yml"):
|
||||||
|
fname = fname[:-4]
|
||||||
|
return fname
|
||||||
|
|
||||||
|
|
||||||
|
def set_out_dir(out_dir, fname):
|
||||||
|
r"""
|
||||||
|
Create the directory for full experiment run
|
||||||
|
Args:
|
||||||
|
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
|
||||||
|
fname (string): Filename for the yaml format configuration file
|
||||||
|
"""
|
||||||
|
fname = get_fname(fname)
|
||||||
|
xgwt_cfg.out_dir = os.path.join(out_dir, fname)
|
||||||
|
# Make output directory
|
||||||
|
if xgwt_cfg.train.auto_resume:
|
||||||
|
os.makedirs(xgwt_cfg.out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def set_run_dir(out_dir):
|
||||||
|
r"""
|
||||||
|
Create the directory for each random seed experiment run
|
||||||
|
Args:
|
||||||
|
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
|
||||||
|
fname (string): Filename for the yaml format configuration file
|
||||||
|
"""
|
||||||
|
xgwt_cfg.run_dir = os.path.join(out_dir, str(xgwt_cfg.seed))
|
||||||
|
# Make output directory
|
||||||
|
if xgwt_cfg.train.auto_resume:
|
||||||
|
os.makedirs(xgwt_cfg.run_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
set_xgwt_cfg(xgwt_cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def from_config(func):
|
||||||
|
if inspect.isclass(func):
|
||||||
|
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||||
|
else:
|
||||||
|
params = list(inspect.signature(func).parameters.values())
|
||||||
|
|
||||||
|
arg_names = [p.name for p in params]
|
||||||
|
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, xgwt_cfg: Any = None, **kwargs):
|
||||||
|
if xgwt_cfg is not None:
|
||||||
|
xgwt_cfg = (
|
||||||
|
dict(xgwt_cfg) if isinstance(xgwt_cfg, Iterable) else asdict(xgwt_cfg)
|
||||||
|
)
|
||||||
|
|
||||||
|
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||||
|
for arg_name, has_default in iterator:
|
||||||
|
if arg_name in kwargs:
|
||||||
|
continue
|
||||||
|
elif arg_name in xgwt_cfg:
|
||||||
|
kwargs[arg_name] = xgwt_cfg[arg_name]
|
||||||
|
elif not has_default:
|
||||||
|
raise ValueError(f"'xgwt_cfg.{arg_name}' undefined")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
Loading…
Reference in New Issue