Adding XGWT method
This commit is contained in:
parent
3f4839fa59
commit
27e8a8a4d8
169
explaining_framework/config/explainer_config/xgwt_config.py
Normal file
169
explaining_framework/config/explainer_config/xgwt_config.py
Normal file
@ -0,0 +1,169 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import torch_geometric.graphgym.register as register
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
|
||||
try: # Define global config object
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
xgwt_cfg = CN()
|
||||
except ImportError:
|
||||
xgwt_cfg = None
|
||||
warnings.warn(
|
||||
"Could not define global config object. Please install "
|
||||
"'yacs' for using the GraphGym experiment manager via "
|
||||
"'pip install yacs'."
|
||||
)
|
||||
|
||||
|
||||
def set_xgwt_cfg(xgwt_cfg):
|
||||
r"""
|
||||
This function sets the default config value.
|
||||
1) Note that for an experiment, only part of the arguments will be used
|
||||
The remaining unused arguments won't affect anything.
|
||||
So feel free to register any argument in graphgym.contrib.config
|
||||
2) We support *at most* two levels of configs, e.g., xgwt_cfg.dataset.name
|
||||
:return: configuration use by the experiment.
|
||||
"""
|
||||
if xgwt_cfg is None:
|
||||
return xgwt_cfg
|
||||
|
||||
xgwt_cfg.wav_approx = False
|
||||
xgwt_cfg.wav_passband = "heat"
|
||||
xgwt_cfg.wav_normalization = True
|
||||
xgwt_cfg.num_candidates = 30
|
||||
xgwt_cfg.num_samples = 10
|
||||
xgwt_cfg.c_procedure = "auto"
|
||||
xgwt_cfg.pred_thres_strat = "regular"
|
||||
xgwt_cfg.CI_threshold = 0.05
|
||||
xgwt_cfg.mixing = "uniform"
|
||||
xgwt_cfg.scales = [3]
|
||||
xgwt_cfg.pred_thres = 0.1
|
||||
xgwt_cfg.incl_prob = 0.4
|
||||
xgwt_cfg.top_k = 5
|
||||
|
||||
|
||||
def assert_cfg(xgwt_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs
|
||||
TODO
|
||||
|
||||
"""
|
||||
# if xgwt_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||
# raise ValueError(
|
||||
# "Task {} not supported, must be one of node, "
|
||||
# "edge, graph, link_pred".format(xgwt_cfg.dataset.task)
|
||||
# )
|
||||
# xgwt_cfg.run_dir = xgwt_cfg.out_dir
|
||||
|
||||
|
||||
def dump_cfg(xgwt_cfg, path):
|
||||
r"""
|
||||
TODO
|
||||
Dumps the config to the output directory specified in
|
||||
:obj:`xgwt_cfg.out_dir`
|
||||
Args:
|
||||
xgwt_cfg (CfgNode): Configuration node
|
||||
"""
|
||||
makedirs(xgwt_cfg.out_dir)
|
||||
xgwt_cfg_file = os.path.join(xgwt_cfg.out_dir, xgwt_cfg.xgwt_cfg_dest)
|
||||
with open(xgwt_cfg_file, "w") as f:
|
||||
xgwt_cfg.dump(stream=f)
|
||||
|
||||
|
||||
def load_cfg(xgwt_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
xgwt_cfg (CfgNode): Configuration node
|
||||
args (ArgumentParser): Command argument parser
|
||||
"""
|
||||
xgwt_cfg.merge_from_file(args.xgwt_cfg_file)
|
||||
xgwt_cfg.merge_from_list(args.opts)
|
||||
assert_xgwt_cfg(xgwt_cfg)
|
||||
|
||||
|
||||
def makedirs_rm_exist(dir):
|
||||
if os.path.isdir(dir):
|
||||
shutil.rmtree(dir)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
|
||||
def get_fname(fname):
|
||||
r"""
|
||||
Extract filename from file name path
|
||||
Args:
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = fname.split("/")[-1]
|
||||
if fname.endswith(".yaml"):
|
||||
fname = fname[:-5]
|
||||
elif fname.endswith(".yml"):
|
||||
fname = fname[:-4]
|
||||
return fname
|
||||
|
||||
|
||||
def set_out_dir(out_dir, fname):
|
||||
r"""
|
||||
Create the directory for full experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = get_fname(fname)
|
||||
xgwt_cfg.out_dir = os.path.join(out_dir, fname)
|
||||
# Make output directory
|
||||
if xgwt_cfg.train.auto_resume:
|
||||
os.makedirs(xgwt_cfg.out_dir, exist_ok=True)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
r"""
|
||||
Create the directory for each random seed experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
xgwt_cfg.run_dir = os.path.join(out_dir, str(xgwt_cfg.seed))
|
||||
# Make output directory
|
||||
if xgwt_cfg.train.auto_resume:
|
||||
os.makedirs(xgwt_cfg.run_dir, exist_ok=True)
|
||||
|
||||
|
||||
set_xgwt_cfg(xgwt_cfg)
|
||||
|
||||
|
||||
def from_config(func):
|
||||
if inspect.isclass(func):
|
||||
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||
else:
|
||||
params = list(inspect.signature(func).parameters.values())
|
||||
|
||||
arg_names = [p.name for p in params]
|
||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, xgwt_cfg: Any = None, **kwargs):
|
||||
if xgwt_cfg is not None:
|
||||
xgwt_cfg = (
|
||||
dict(xgwt_cfg) if isinstance(xgwt_cfg, Iterable) else asdict(xgwt_cfg)
|
||||
)
|
||||
|
||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||
for arg_name, has_default in iterator:
|
||||
if arg_name in kwargs:
|
||||
continue
|
||||
elif arg_name in xgwt_cfg:
|
||||
kwargs[arg_name] = xgwt_cfg[arg_name]
|
||||
elif not has_default:
|
||||
raise ValueError(f"'xgwt_cfg.{arg_name}' undefined")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
Loading…
Reference in New Issue
Block a user