"""Global configuration for GNN-explainability experiments.

Mirrors the GraphGym configuration system: a module-level
``explaining_cfg`` :class:`yacs.config.CfgNode` is populated with
defaults by :func:`set_cfg`, merged from file/CLI by :func:`load_cfg`,
and serialized by :func:`dump_cfg`.
"""
import functools
import inspect
import logging
import os
import shutil
import warnings
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any

import torch_geometric.graphgym.register as register
from torch_geometric.data.makedirs import makedirs

try:
    # Define global config object; left as ``None`` when yacs is missing so
    # the module stays importable and helpers degrade gracefully.
    from yacs.config import CfgNode as CN

    explaining_cfg = CN()
except ImportError:
    explaining_cfg = None
    warnings.warn(
        "Could not define global config object. Please install "
        "'yacs' for using the GraphGym experiment manager via "
        "'pip install yacs'."
    )


def set_cfg(explaining_cfg):
    r"""Set the default configuration values in-place.

    1) Note that for an experiment, only part of the arguments will be
       used. The remaining unused arguments won't affect anything, so
       feel free to register any argument in ``graphgym.contrib.config``.
    2) We support *at most* two levels of configs, e.g.,
       ``explaining_cfg.dataset.name``.

    Args:
        explaining_cfg (CfgNode): Configuration node to populate; may be
            ``None`` when yacs is not installed, in which case this is a
            no-op.

    :return: configuration used by the experiment.
    """
    if explaining_cfg is None:
        return explaining_cfg

    # ----------------------------------------------------------------------- #
    # Basic options
    # ----------------------------------------------------------------------- #

    # Set print destination: stdout / file / both
    explaining_cfg.print = "both"

    explaining_cfg.out_dir = "./explanations"

    explaining_cfg.cfg_dest = "explaining_config.yaml"

    explaining_cfg.seed = 0

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #
    explaining_cfg.dataset = CN()

    explaining_cfg.dataset.name = "Cora"

    explaining_cfg.dataset.specific_items = None

    explaining_cfg.run_topological_stat = True

    # ----------------------------------------------------------------------- #
    # Model options
    # ----------------------------------------------------------------------- #
    explaining_cfg.model = CN()

    # Set whether or not to load the best model for given dataset or a path
    explaining_cfg.model.ckpt = "best"

    # Setting the path of models folder
    explaining_cfg.model.path = "path"

    # ----------------------------------------------------------------------- #
    # Explainer options
    # ----------------------------------------------------------------------- #
    explaining_cfg.explainer = CN()

    # Name of the explaining method
    explaining_cfg.explainer.name = "EiXGNN"

    # Whether or not to provide specific explaining methods configuration
    # or default configuration
    explaining_cfg.explainer.cfg = "default"

    # Whether or not to recompute explanations if they already exist
    explaining_cfg.explainer.force = False

    # ----------------------------------------------------------------------- #
    # Explaining options
    # ----------------------------------------------------------------------- #

    # 'ExplanationType': 'model' or 'phenomenon'
    explaining_cfg.explanation_type = "model"

    explaining_cfg.model_config = CN()

    # Do not modify it; will be handled by dataset, assuming one dataset =
    # one learning task
    explaining_cfg.model_config.mode = None

    # Do not modify it; will be handled by dataset, assuming one dataset =
    # one learning task
    explaining_cfg.model_config.task_level = None

    # Do not modify it; we always assume here that model outputs are 'raw'
    explaining_cfg.model_config.return_type = "raw"

    # ----------------------------------------------------------------------- #
    # Thresholding options
    # ----------------------------------------------------------------------- #
    explaining_cfg.threshold_config = CN()

    explaining_cfg.threshold_config.threshold_type = None

    explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]

    explaining_cfg.threshold_config.relu_and_normalize = True

    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

    # BUG FIX: 'metrics' was previously assigned the plain string "all" and
    # then given a '.force' attribute, which raises AttributeError on str.
    # It is now a proper sub-node.
    explaining_cfg.metrics = CN()

    # Which objective metrics to compute: either 'all' or one in
    # particular, if implemented
    explaining_cfg.metrics.name = "all"

    # Whether or not to recompute metrics if they already exist
    explaining_cfg.metrics.force = False


def assert_cfg(explaining_cfg):
    r"""Check config values and do necessary post-processing of the configs."""
    explaining_cfg.run_dir = explaining_cfg.out_dir


def dump_cfg(explaining_cfg):
    r"""Dump the config to the output directory specified in
    :obj:`explaining_cfg.out_dir`.

    Args:
        explaining_cfg (CfgNode): Configuration node
    """
    makedirs(explaining_cfg.out_dir)
    # BUG FIX: the destination filename key is 'cfg_dest' (defined in
    # set_cfg); 'explaining_cfg_dest' does not exist.
    explaining_cfg_file = os.path.join(
        explaining_cfg.out_dir, explaining_cfg.cfg_dest
    )
    with open(explaining_cfg_file, "w") as f:
        explaining_cfg.dump(stream=f)


def load_cfg(explaining_cfg, args):
    r"""Load configurations from the file system and the command line.

    Args:
        explaining_cfg (CfgNode): Configuration node
        args (ArgumentParser): Command argument parser
    """
    explaining_cfg.merge_from_file(args.explaining_cfg_file)
    explaining_cfg.merge_from_list(args.opts)
    # BUG FIX: was 'assert_explaining_cfg', which is undefined (NameError).
    assert_cfg(explaining_cfg)


def makedirs_rm_exist(dir):
    r"""Create directory *dir*, removing it first if it already exists."""
    if os.path.isdir(dir):
        shutil.rmtree(dir)
    os.makedirs(dir, exist_ok=True)


def get_fname(fname):
    r"""Extract the filename (without YAML extension) from a file path.

    Args:
        fname (string): Filename or path of the yaml-format configuration file
    """
    base = os.path.basename(fname)
    root, ext = os.path.splitext(base)
    # Only strip recognized YAML extensions; keep anything else untouched.
    return root if ext in (".yaml", ".yml") else base


def set_out_dir(out_dir, fname):
    r"""Create the directory for a full experiment run.

    Args:
        out_dir (string): Directory for output, specified in
            :obj:`explaining_cfg.out_dir`
        fname (string): Filename for the yaml format configuration file
    """
    fname = get_fname(fname)
    explaining_cfg.out_dir = os.path.join(out_dir, fname)
    # Make output directory.
    # BUG FIX: previously guarded by 'explaining_cfg.train.auto_resume',
    # but no 'train' section is ever defined in this config (AttributeError).
    os.makedirs(explaining_cfg.out_dir, exist_ok=True)


def set_run_dir(out_dir):
    r"""Create the directory for a single random-seed experiment run.

    Args:
        out_dir (string): Directory for output, specified in
            :obj:`explaining_cfg.out_dir`
    """
    explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
    # Make output directory.
    # BUG FIX: same undefined 'explaining_cfg.train.auto_resume' guard as in
    # set_out_dir; creation is now unconditional.
    os.makedirs(explaining_cfg.run_dir, exist_ok=True)


set_cfg(explaining_cfg)


def from_config(func):
    r"""Decorator that fills missing keyword arguments of *func* from a
    configuration object passed as the extra ``explaining_cfg`` keyword.

    The config may be any mapping/iterable of pairs or a dataclass. For
    each parameter not supplied by the caller, the value is looked up in
    the config; a missing, default-less parameter raises ``ValueError``.
    """
    if inspect.isclass(func):
        # For classes, inspect __init__ and skip the implicit 'self'.
        params = list(inspect.signature(func.__init__).parameters.values())[1:]
    else:
        params = list(inspect.signature(func).parameters.values())

    arg_names = [p.name for p in params]
    has_defaults = [p.default is not inspect.Parameter.empty for p in params]

    @functools.wraps(func)
    def wrapper(*args, explaining_cfg: Any = None, **kwargs):
        if explaining_cfg is not None:
            explaining_cfg = (
                dict(explaining_cfg)
                if isinstance(explaining_cfg, Iterable)
                else asdict(explaining_cfg)
            )
            # Only parameters not already bound positionally are eligible.
            iterator = zip(arg_names[len(args):], has_defaults[len(args):])
            for arg_name, has_default in iterator:
                if arg_name in kwargs:
                    continue
                elif arg_name in explaining_cfg:
                    kwargs[arg_name] = explaining_cfg[arg_name]
                elif not has_default:
                    raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
        return func(*args, **kwargs)

    return wrapper