diff --git a/explaining_framework/config/explainer_config/xgwt_config.py b/explaining_framework/config/explainer_config/xgwt_config.py new file mode 100644 index 0000000..a62b257 --- /dev/null +++ b/explaining_framework/config/explainer_config/xgwt_config.py @@ -0,0 +1,169 @@ +import functools +import inspect +import logging +import os +import shutil +import warnings +from collections.abc import Iterable +from dataclasses import asdict +from typing import Any + +import torch_geometric.graphgym.register as register +from torch_geometric.data.makedirs import makedirs + +try: # Define global config object + from yacs.config import CfgNode as CN + + xgwt_cfg = CN() +except ImportError: + xgwt_cfg = None + warnings.warn( + "Could not define global config object. Please install " + "'yacs' for using the GraphGym experiment manager via " + "'pip install yacs'." + ) + + +def set_xgwt_cfg(xgwt_cfg): + r""" + This function sets the default config value. + 1) Note that for an experiment, only part of the arguments will be used + The remaining unused arguments won't affect anything. + So feel free to register any argument in graphgym.contrib.config + 2) We support *at most* two levels of configs, e.g., xgwt_cfg.dataset.name + :return: configuration use by the experiment. + """ + if xgwt_cfg is None: + return xgwt_cfg + + xgwt_cfg.wav_approx = False + xgwt_cfg.wav_passband = "heat" + xgwt_cfg.wav_normalization = True + xgwt_cfg.num_candidates = 30 + xgwt_cfg.num_samples = 10 + xgwt_cfg.c_procedure = "auto" + xgwt_cfg.pred_thres_strat = "regular" + xgwt_cfg.CI_threshold = 0.05 + xgwt_cfg.mixing = "uniform" + xgwt_cfg.scales = [3] + xgwt_cfg.pred_thres = 0.1 + xgwt_cfg.incl_prob = 0.4 + xgwt_cfg.top_k = 5 + + +def assert_cfg(xgwt_cfg): + r"""Checks config values, do necessary post processing to the configs + TODO + + """ + # if xgwt_cfg. not in ["node", "edge", "graph", "link_pred"]: + # raise ValueError( + # "Task {} not supported, must be one of node, " + # "edge, graph, link_pred".format(xgwt_cfg.dataset.task) + # ) + # xgwt_cfg.run_dir = xgwt_cfg.out_dir + + +def dump_cfg(xgwt_cfg, path): + r""" + TODO + Dumps the config to the output directory specified in + :obj:`xgwt_cfg.out_dir` + Args: + xgwt_cfg (CfgNode): Configuration node + """ + makedirs(xgwt_cfg.out_dir) + xgwt_cfg_file = os.path.join(xgwt_cfg.out_dir, xgwt_cfg.xgwt_cfg_dest) + with open(xgwt_cfg_file, "w") as f: + xgwt_cfg.dump(stream=f) + + +def load_cfg(xgwt_cfg, args): + r""" + Load configurations from file system and command line + Args: + xgwt_cfg (CfgNode): Configuration node + args (ArgumentParser): Command argument parser + """ + xgwt_cfg.merge_from_file(args.xgwt_cfg_file) + xgwt_cfg.merge_from_list(args.opts) + assert_xgwt_cfg(xgwt_cfg) + + +def makedirs_rm_exist(dir): + if os.path.isdir(dir): + shutil.rmtree(dir) + os.makedirs(dir, exist_ok=True) + + +def get_fname(fname): + r""" + Extract filename from file name path + Args: + fname (string): Filename for the yaml format configuration file + """ + fname = fname.split("/")[-1] + if fname.endswith(".yaml"): + fname = fname[:-5] + elif fname.endswith(".yml"): + fname = fname[:-4] + return fname + + +def set_out_dir(out_dir, fname): + r""" + Create the directory for full experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + fname = get_fname(fname) + xgwt_cfg.out_dir = os.path.join(out_dir, fname) + # Make output directory + if xgwt_cfg.train.auto_resume: + os.makedirs(xgwt_cfg.out_dir, exist_ok=True) + + +def set_run_dir(out_dir): + r""" + Create the directory for each random seed experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`xgwt_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + xgwt_cfg.run_dir = os.path.join(out_dir, str(xgwt_cfg.seed)) + # Make output directory + if xgwt_cfg.train.auto_resume: + os.makedirs(xgwt_cfg.run_dir, exist_ok=True) + + +set_xgwt_cfg(xgwt_cfg) + + +def from_config(func): + if inspect.isclass(func): + params = list(inspect.signature(func.__init__).parameters.values())[1:] + else: + params = list(inspect.signature(func).parameters.values()) + + arg_names = [p.name for p in params] + has_defaults = [p.default != inspect.Parameter.empty for p in params] + + @functools.wraps(func) + def wrapper(*args, xgwt_cfg: Any = None, **kwargs): + if xgwt_cfg is not None: + xgwt_cfg = ( + dict(xgwt_cfg) if isinstance(xgwt_cfg, Iterable) else asdict(xgwt_cfg) + ) + + iterator = zip(arg_names[len(args) :], has_defaults[len(args) :]) + for arg_name, has_default in iterator: + if arg_name in kwargs: + continue + elif arg_name in xgwt_cfg: + kwargs[arg_name] = xgwt_cfg[arg_name] + elif not has_default: + raise ValueError(f"'xgwt_cfg.{arg_name}' undefined") + return func(*args, **kwargs) + + return wrapper