252 lines
8.1 KiB
Python
252 lines
8.1 KiB
Python
import functools
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import warnings
|
|
from collections.abc import Iterable
|
|
from dataclasses import asdict
|
|
from typing import Any
|
|
|
|
import torch_geometric.graphgym.register as register
|
|
from torch_geometric.data.makedirs import makedirs
|
|
|
|
# Define the global config object. ``yacs`` is optional: without it the
# config is ``None`` and a warning tells the user how to install it.
try:
    from yacs.config import CfgNode as CN
except ImportError:
    explaining_cfg = None
    warnings.warn(
        "Could not define global config object. Please install "
        "'yacs' for using the GraphGym experiment manager via "
        "'pip install yacs'."
    )
else:
    explaining_cfg = CN()
|
|
|
|
|
|
def set_cfg(explaining_cfg):
    r"""
    Populate :obj:`explaining_cfg` with the default explaining options.

    1) Note that for an experiment, only part of the arguments will be used.
       The remaining unused arguments won't affect anything, so feel free to
       register any argument in graphgym.contrib.config.
    2) Options are nested: most use two levels (e.g.
       ``explaining_cfg.dataset.name``), while the threshold options use
       three (e.g. ``explaining_cfg.threshold.value.topk``).

    Args:
        explaining_cfg (CfgNode or None): Configuration node, filled in place.
            ``None`` (the module-level value when ``yacs`` is missing) is
            passed through untouched.

    Returns:
        The same node, now holding the defaults, or ``None`` if ``None`` was
        passed.
    """
    if explaining_cfg is None:
        return explaining_cfg

    # ----------------------------------------------------------------------- #
    # Basic options
    # ----------------------------------------------------------------------- #

    # Set print destination: stdout / file / both
    explaining_cfg.print = "both"

    # Output directory
    explaining_cfg.out_dir = "./explanations"

    # File name for the dumped config
    explaining_cfg.cfg_dest = "explaining_config.yaml"

    explaining_cfg.seed = 0

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #

    explaining_cfg.dataset = CN()

    explaining_cfg.dataset.name = "Cora"

    explaining_cfg.dataset.item = None

    # ----------------------------------------------------------------------- #
    # Model options
    # ----------------------------------------------------------------------- #

    explaining_cfg.model = CN()

    # Whether to load the best model for the given dataset or a model from a path
    explaining_cfg.model.ckpt = "best"

    # Path of the models folder
    explaining_cfg.model.path = "path"

    # ----------------------------------------------------------------------- #
    # Explainer options
    # ----------------------------------------------------------------------- #

    explaining_cfg.explainer = CN()

    # Name of the explaining method
    explaining_cfg.explainer.name = "EiXGNN"

    # Whether to use a specific explaining-method configuration or the default one
    explaining_cfg.explainer.cfg = "default"

    # Whether to recompute explanations if they already exist
    explaining_cfg.explainer.force = False

    # ----------------------------------------------------------------------- #
    # Explaining options
    # ----------------------------------------------------------------------- #

    # ExplanationType: 'model' or 'phenomenon'
    explaining_cfg.explanation_type = "model"

    explaining_cfg.model_config = CN()

    # Do not modify it, will be handled by dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.mode = "regression"

    # Do not modify it, will be handled by dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.task_level = None

    # Do not modify it, we always assume here that model outputs are 'raw'
    explaining_cfg.model_config.return_type = "raw"

    # ----------------------------------------------------------------------- #
    # Thresholding options
    # ----------------------------------------------------------------------- #

    explaining_cfg.threshold = CN()

    explaining_cfg.threshold.config = CN()
    explaining_cfg.threshold.config.type = "all"

    explaining_cfg.threshold.value = CN()
    # 0.1, 0.2, ..., 0.9
    explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)]
    explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]

    # Which objective metrics to compute: "all" or one in particular, if implemented
    explaining_cfg.metrics = CN()
    explaining_cfg.metrics.sparsity = CN()
    explaining_cfg.metrics.sparsity.name = "all"
    explaining_cfg.metrics.fidelity = CN()
    explaining_cfg.metrics.fidelity.name = "all"
    explaining_cfg.metrics.accuracy = CN()
    explaining_cfg.metrics.accuracy.name = "all"

    # NOTE(review): a "whether to recompute metrics if they already exist"
    # option was described here, but no such key is registered — confirm
    # whether something like ``metrics.force`` was intended.

    explaining_cfg.adjust = CN()
    explaining_cfg.adjust.strategy = "rpns"

    explaining_cfg.attack = CN()
    explaining_cfg.attack.name = "all"

    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

    # Returned so the docstring's promise holds; existing callers ignore it.
    return explaining_cfg
|
|
|
|
|
|
def assert_cfg(explaining_cfg):
    r"""Post-process a loaded config: use the output directory as run directory.

    Args:
        explaining_cfg (CfgNode): Configuration node, modified in place.
    """
    output_directory = explaining_cfg.out_dir
    explaining_cfg.run_dir = output_directory
|
|
|
|
|
|
def dump_cfg(explaining_cfg):
    r"""
    Dump the config into the output directory specified in
    :obj:`explaining_cfg.out_dir`, under the file name
    :obj:`explaining_cfg.cfg_dest`.

    The output directory is created first if needed.

    Args:
        explaining_cfg (CfgNode): Configuration node
    """
    makedirs(explaining_cfg.out_dir)
    # ``cfg_dest`` is the key registered by set_cfg; the previously used
    # ``explaining_cfg_dest`` key was never defined and raised AttributeError.
    explaining_cfg_file = os.path.join(
        explaining_cfg.out_dir, explaining_cfg.cfg_dest
    )
    with open(explaining_cfg_file, "w", encoding="utf-8") as f:
        explaining_cfg.dump(stream=f)
|
|
|
|
|
|
def load_cfg(explaining_cfg, args):
    r"""
    Load configurations from file system and command line.

    The file named by ``args.explaining_cfg_file`` is merged first, then the
    list in ``args.opts`` is merged, so command-line values override the
    file. The result is post-processed by :func:`assert_cfg`.

    Args:
        explaining_cfg (CfgNode): Configuration node, modified in place
        args (ArgumentParser): Parsed command-line arguments
    """
    explaining_cfg.merge_from_file(args.explaining_cfg_file)
    explaining_cfg.merge_from_list(args.opts)
    # Previously called the undefined ``assert_explaining_cfg`` (NameError).
    assert_cfg(explaining_cfg)
|
|
|
|
|
|
def makedirs_rm_exist(dir):
    r"""
    Create ``dir`` from scratch, deleting it first if it already is a directory.

    Args:
        dir (string): Directory to (re)create; parent directories are created
            as needed.
    """
    target = dir
    if not os.path.isdir(target):
        os.makedirs(target, exist_ok=True)
        return
    # Wipe the existing tree, then recreate an empty directory.
    shutil.rmtree(target)
    os.makedirs(target, exist_ok=True)
|
|
|
|
|
|
def get_fname(fname):
    r"""
    Extract the bare config name from a config file path.

    The directory part is dropped (using the OS path rules, so Windows
    separators work too) and a trailing ``.yaml`` or ``.yml`` is removed.
    Any other extension is kept.

    Args:
        fname (string): Filename for the yaml format configuration file

    Returns:
        string: The base name without the YAML extension.
    """
    base = os.path.basename(fname)
    for ext in (".yaml", ".yml"):
        if base.endswith(ext):
            return base[: -len(ext)]
    return base
|
|
|
|
|
|
def set_out_dir(out_dir, fname):
    r"""
    Create the directory for the full experiment run.

    Sets :obj:`explaining_cfg.out_dir` to ``<out_dir>/<config name>`` (the
    name comes from :func:`get_fname`) and creates that directory.

    Args:
        out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
        fname (string): Filename for the yaml format configuration file
    """
    fname = get_fname(fname)
    explaining_cfg.out_dir = os.path.join(out_dir, fname)
    # The explaining config registers no ``train`` node, so the former
    # ``explaining_cfg.train.auto_resume`` check always raised. Create the
    # directory without wiping it, so existing content is kept.
    os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
|
|
|
|
|
def set_run_dir(out_dir):
    r"""
    Create the directory for each random seed experiment run.

    Sets :obj:`explaining_cfg.run_dir` to ``<out_dir>/<explaining_cfg.seed>``
    and creates that directory.

    Args:
        out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
    """
    explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
    # The explaining config registers no ``train`` node, so the former
    # ``explaining_cfg.train.auto_resume`` check always raised. Create the
    # directory without wiping it, so existing content is kept.
    os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
|
|
|
|
|
# Fill the module-level config with its defaults at import time; set_cfg
# returns early without doing anything when yacs was unavailable (None).
set_cfg(explaining_cfg)
|
|
|
|
|
|
def from_config(func):
    r"""
    Decorator letting ``func`` pull missing arguments from a config object.

    The wrapped callable accepts an extra keyword-only argument
    ``explaining_cfg``. When it is given, every parameter of ``func`` not
    already supplied positionally or by keyword is looked up by name in it.
    A parameter without a default that is also absent from the config raises
    :class:`ValueError`. Without ``explaining_cfg`` the call is forwarded
    unchanged.

    If ``func`` is a class, the parameters of its ``__init__`` (minus
    ``self``) are used.

    Args:
        func (callable or type): Function or class to wrap.

    Returns:
        callable: The wrapper.
    """
    if inspect.isclass(func):
        # Drop ``self`` from the constructor's signature.
        params = list(inspect.signature(func.__init__).parameters.values())[1:]
    else:
        params = list(inspect.signature(func).parameters.values())

    arg_names = [p.name for p in params]
    # Identity check: ``!=`` would call a default value's own ``__ne__``
    # (e.g. an array/tensor default), which may raise or not return a bool.
    has_defaults = [p.default is not inspect.Parameter.empty for p in params]

    @functools.wraps(func)
    def wrapper(*args, explaining_cfg: Any = None, **kwargs):
        if explaining_cfg is not None:
            # Iterables (e.g. dicts / CfgNode) go through dict(); other
            # objects are expected to be dataclass instances for asdict().
            explaining_cfg = (
                dict(explaining_cfg)
                if isinstance(explaining_cfg, Iterable)
                else asdict(explaining_cfg)
            )

            # Only parameters not already covered by positional arguments.
            iterator = zip(arg_names[len(args):], has_defaults[len(args):])
            for arg_name, has_default in iterator:
                if arg_name in kwargs:
                    continue
                elif arg_name in explaining_cfg:
                    kwargs[arg_name] = explaining_cfg[arg_name]
                elif not has_default:
                    raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
        return func(*args, **kwargs)

    return wrapper
|