Adding config files

This commit is contained in:
araison 2022-12-20 16:07:11 +01:00
parent 55570a26e5
commit 7fe935dbad
5 changed files with 910 additions and 23 deletions

View File

@ -0,0 +1,302 @@
import functools
import inspect
import logging
import os
import shutil
import warnings
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any
import torch_geometric.graphgym.register as register
from torch_geometric.data.makedirs import makedirs
try: # Define global config object
from yacs.config import CfgNode as CN
eixgnn_cfg = CN()
except ImportError:
eixgnn_cfg = None
warnings.warn(
"Could not define global config object. Please install "
"'yacs' for using the GraphGym experiment manager via "
"'pip install yacs'."
)
def set_eixgnn_cfg(eixgnn_cfg):
r"""
This function sets the default config value.
1) Note that for an experiment, only part of the arguments will be used
The remaining unused arguments won't affect anything.
So feel free to register any argument in graphgym.contrib.config
2) We support *at most* two levels of configs, e.g., eixgnn_cfg.dataset.name
:return: configuration use by the experiment.
"""
if eixgnn_cfg is None:
return eixgnn_cfg
# ----------------------------------------------------------------------- #
# Basic options
# ----------------------------------------------------------------------- #
# Set print destination: stdout / file / both
eixgnn_cfg.print = "both"
eixgnn_cfg.out_dir = "./explanations"
eixgnn_cfg.cfg_dest = "explaining_config.yaml"
eixgnn_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
eixgnn_cfg.dataset = CN()
eixgnn_cfg.dataset.name = "Cora"
eixgnn_cfg.run_topological_stat = True
# ----------------------------------------------------------------------- #
# Model options
# ----------------------------------------------------------------------- #
eixgnn_cfg.model = CN()
# Set wether or not load the best model for given dataset or a path
eixgnn_cfg.model.ckpt = "best"
# ----------------------------------------------------------------------- #
# Explainer options
# ----------------------------------------------------------------------- #
eixgnn_cfg.explainer = CN()
# Name of the explaining method
eixgnn_cfg.explainer.name = "EiXGNN"
# Whether or not to provide specific explaining methods configuration or default configuration
eixgnn_cfg.explainer.cfg = "default"
# ----------------------------------------------------------------------- #
# Explaining options
# ----------------------------------------------------------------------- #
# 'ExplanationType : 'model' or 'phenomenon'
eixgnn_cfg.explanation_type = "model"
eixgnn_cfg.model_config = CN()
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
eixgnn_cfg.model_config.mode = None
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
eixgnn_cfg.model_config.task_level = None
# Do not modify it, we always assume here that model output are 'raw'
eixgnn_cfg.model_config.return_type = "raw"
eixgnn_cfg.threshold_config = CN()
eixgnn_cfg.threshold_config.threshold_type = None
eixgnn_cfg.threshold_config.value = 0.5
# Set print destination: stdout / file / both
eixgnn_cfg.print = "both"
# Select device: 'cpu', 'cuda', 'auto'
eixgnn_cfg.accelerator = "auto"
# Output directory
eixgnn_cfg.out_dir = "results"
# Config name (in out_dir)
eixgnn_cfg.eixgnn_cfg_dest = "config.yaml"
# Random seed
eixgnn_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Globally shared variables:
# These variables will be set dynamically based on the input dataset
# Do not directly set them here or in .yaml files
# ----------------------------------------------------------------------- #
eixgnn_cfg.share = CN()
# Size of input dimension
eixgnn_cfg.share.dim_in = 1
# Size of out dimension, i.e., number of labels to be predicted
eixgnn_cfg.share.dim_out = 1
# Number of dataset splits: train/val/test
eixgnn_cfg.share.num_splits = 1
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
eixgnn_cfg.dataset = CN()
# Name of the dataset
eixgnn_cfg.dataset.name = "Cora"
# if PyG: look for it in Pytorch Geometric dataset
# if NetworkX/nx: load data in NetworkX format
# Dir to load the dataset. If the dataset is downloaded, this is the
# cache dir
eixgnn_cfg.dataset.dir = "./datasets"
# ----------------------------------------------------------------------- #
# Memory options
# ----------------------------------------------------------------------- #
eixgnn_cfg.mem = CN()
# Perform ReLU inplace
eixgnn_cfg.mem.inplace = False
# Set user customized eixgnn_cfgs
for func in register.config_dict.values():
func(eixgnn_cfg)
def assert_eixgnn_cfg(eixgnn_cfg):
r"""Checks config values, do necessary post processing to the configs"""
if eixgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
raise ValueError(
"Task {} not supported, must be one of node, "
"edge, graph, link_pred".format(eixgnn_cfg.dataset.task)
)
if (
"classification" in eixgnn_cfg.dataset.task_type
and eixgnn_cfg.model.loss_fun == "mse"
):
eixgnn_cfg.model.loss_fun = "cross_entropy"
logging.warning("model.loss_fun changed to cross_entropy for classification.")
if (
eixgnn_cfg.dataset.task_type == "regression"
and eixgnn_cfg.model.loss_fun == "cross_entropy"
):
eixgnn_cfg.model.loss_fun = "mse"
logging.warning("model.loss_fun changed to mse for regression.")
if eixgnn_cfg.dataset.task == "graph" and eixgnn_cfg.dataset.transductive:
eixgnn_cfg.dataset.transductive = False
logging.warning("dataset.transductive changed " "to False for graph task.")
if eixgnn_cfg.gnn.layers_post_mp < 1:
eixgnn_cfg.gnn.layers_post_mp = 1
logging.warning("Layers after message passing should be >=1")
if eixgnn_cfg.gnn.head == "default":
eixgnn_cfg.gnn.head = eixgnn_cfg.dataset.task
eixgnn_cfg.run_dir = eixgnn_cfg.out_dir
def dump_eixgnn_cfg(eixgnn_cfg):
r"""
Dumps the config to the output directory specified in
:obj:`eixgnn_cfg.out_dir`
Args:
eixgnn_cfg (CfgNode): Configuration node
"""
makedirs(eixgnn_cfg.out_dir)
eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
with open(eixgnn_cfg_file, "w") as f:
eixgnn_cfg.dump(stream=f)
def load_eixgnn_cfg(eixgnn_cfg, args):
r"""
Load configurations from file system and command line
Args:
eixgnn_cfg (CfgNode): Configuration node
args (ArgumentParser): Command argument parser
"""
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
eixgnn_cfg.merge_from_list(args.opts)
assert_eixgnn_cfg(eixgnn_cfg)
def makedirs_rm_exist(dir):
if os.path.isdir(dir):
shutil.rmtree(dir)
os.makedirs(dir, exist_ok=True)
def get_fname(fname):
r"""
Extract filename from file name path
Args:
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split("/")[-1]
if fname.endswith(".yaml"):
fname = fname[:-5]
elif fname.endswith(".yml"):
fname = fname[:-4]
return fname
def set_out_dir(out_dir, fname):
r"""
Create the directory for full experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = get_fname(fname)
eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
# Make output directory
if eixgnn_cfg.train.auto_resume:
os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(eixgnn_cfg.out_dir)
def set_run_dir(out_dir):
r"""
Create the directory for each random seed experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
# Make output directory
if eixgnn_cfg.train.auto_resume:
os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(eixgnn_cfg.run_dir)
set_eixgnn_cfg(eixgnn_cfg)
def from_config(func):
if inspect.isclass(func):
params = list(inspect.signature(func.__init__).parameters.values())[1:]
else:
params = list(inspect.signature(func).parameters.values())
arg_names = [p.name for p in params]
has_defaults = [p.default != inspect.Parameter.empty for p in params]
@functools.wraps(func)
def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
if eixgnn_cfg is not None:
eixgnn_cfg = (
dict(eixgnn_cfg)
if isinstance(eixgnn_cfg, Iterable)
else asdict(eixgnn_cfg)
)
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
for arg_name, has_default in iterator:
if arg_name in kwargs:
continue
elif arg_name in eixgnn_cfg:
kwargs[arg_name] = eixgnn_cfg[arg_name]
elif not has_default:
raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
return func(*args, **kwargs)
return wrapper

View File

@ -0,0 +1,304 @@
import functools
import inspect
import logging
import os
import shutil
import warnings
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any
import torch_geometric.graphgym.register as register
from torch_geometric.data.makedirs import makedirs
try: # Define global config object
from yacs.config import CfgNode as CN
explaining_cfg = CN()
except ImportError:
explaining_cfg = None
warnings.warn(
"Could not define global config object. Please install "
"'yacs' for using the GraphGym experiment manager via "
"'pip install yacs'."
)
def set_explaining_cfg(explaining_cfg):
r"""
This function sets the default config value.
1) Note that for an experiment, only part of the arguments will be used
The remaining unused arguments won't affect anything.
So feel free to register any argument in graphgym.contrib.config
2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
:return: configuration use by the experiment.
"""
if explaining_cfg is None:
return explaining_cfg
# ----------------------------------------------------------------------- #
# Basic options
# ----------------------------------------------------------------------- #
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
explaining_cfg.out_dir = "./explanations"
explaining_cfg.cfg_dest = "explaining_config.yaml"
explaining_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
explaining_cfg.dataset = CN()
explaining_cfg.dataset.name = "Cora"
explaining_cfg.run_topological_stat = True
# ----------------------------------------------------------------------- #
# Model options
# ----------------------------------------------------------------------- #
explaining_cfg.model = CN()
# Set wether or not load the best model for given dataset or a path
explaining_cfg.model.ckpt = "best"
# ----------------------------------------------------------------------- #
# Explainer options
# ----------------------------------------------------------------------- #
explaining_cfg.explainer = CN()
# Name of the explaining method
explaining_cfg.explainer.name = "EiXGNN"
# Whether or not to provide specific explaining methods configuration or default configuration
explaining_cfg.explainer.cfg = "default"
# ----------------------------------------------------------------------- #
# Explaining options
# ----------------------------------------------------------------------- #
# 'ExplanationType : 'model' or 'phenomenon'
explaining_cfg.explanation_type = "model"
explaining_cfg.model_config = CN()
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
explaining_cfg.model_config.mode = None
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
explaining_cfg.model_config.task_level = None
# Do not modify it, we always assume here that model output are 'raw'
explaining_cfg.model_config.return_type = "raw"
explaining_cfg.threshold_config = CN()
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = 0.5
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"
# Output directory
explaining_cfg.out_dir = "results"
# Config name (in out_dir)
explaining_cfg.explaining_cfg_dest = "config.yaml"
# Random seed
explaining_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Globally shared variables:
# These variables will be set dynamically based on the input dataset
# Do not directly set them here or in .yaml files
# ----------------------------------------------------------------------- #
explaining_cfg.share = CN()
# Size of input dimension
explaining_cfg.share.dim_in = 1
# Size of out dimension, i.e., number of labels to be predicted
explaining_cfg.share.dim_out = 1
# Number of dataset splits: train/val/test
explaining_cfg.share.num_splits = 1
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
explaining_cfg.dataset = CN()
# Name of the dataset
explaining_cfg.dataset.name = "Cora"
# if PyG: look for it in Pytorch Geometric dataset
# if NetworkX/nx: load data in NetworkX format
# Dir to load the dataset. If the dataset is downloaded, this is the
# cache dir
explaining_cfg.dataset.dir = "./datasets"
# ----------------------------------------------------------------------- #
# Memory options
# ----------------------------------------------------------------------- #
explaining_cfg.mem = CN()
# Perform ReLU inplace
explaining_cfg.mem.inplace = False
# Set user customized explaining_cfgs
for func in register.config_dict.values():
func(explaining_cfg)
def assert_explaining_cfg(explaining_cfg):
r"""Checks config values, do necessary post processing to the configs"""
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
raise ValueError(
"Task {} not supported, must be one of node, "
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
)
if (
"classification" in explaining_cfg.dataset.task_type
and explaining_cfg.model.loss_fun == "mse"
):
explaining_cfg.model.loss_fun = "cross_entropy"
logging.warning("model.loss_fun changed to cross_entropy for classification.")
if (
explaining_cfg.dataset.task_type == "regression"
and explaining_cfg.model.loss_fun == "cross_entropy"
):
explaining_cfg.model.loss_fun = "mse"
logging.warning("model.loss_fun changed to mse for regression.")
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
explaining_cfg.dataset.transductive = False
logging.warning("dataset.transductive changed " "to False for graph task.")
if explaining_cfg.gnn.layers_post_mp < 1:
explaining_cfg.gnn.layers_post_mp = 1
logging.warning("Layers after message passing should be >=1")
if explaining_cfg.gnn.head == "default":
explaining_cfg.gnn.head = explaining_cfg.dataset.task
explaining_cfg.run_dir = explaining_cfg.out_dir
def dump_explaining_cfg(explaining_cfg):
r"""
Dumps the config to the output directory specified in
:obj:`explaining_cfg.out_dir`
Args:
explaining_cfg (CfgNode): Configuration node
"""
makedirs(explaining_cfg.out_dir)
explaining_cfg_file = os.path.join(
explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
)
with open(explaining_cfg_file, "w") as f:
explaining_cfg.dump(stream=f)
def load_explaining_cfg(explaining_cfg, args):
r"""
Load configurations from file system and command line
Args:
explaining_cfg (CfgNode): Configuration node
args (ArgumentParser): Command argument parser
"""
explaining_cfg.merge_from_file(args.explaining_cfg_file)
explaining_cfg.merge_from_list(args.opts)
assert_explaining_cfg(explaining_cfg)
def makedirs_rm_exist(dir):
if os.path.isdir(dir):
shutil.rmtree(dir)
os.makedirs(dir, exist_ok=True)
def get_fname(fname):
r"""
Extract filename from file name path
Args:
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split("/")[-1]
if fname.endswith(".yaml"):
fname = fname[:-5]
elif fname.endswith(".yml"):
fname = fname[:-4]
return fname
def set_out_dir(out_dir, fname):
r"""
Create the directory for full experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = get_fname(fname)
explaining_cfg.out_dir = os.path.join(out_dir, fname)
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.out_dir)
def set_run_dir(out_dir):
r"""
Create the directory for each random seed experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.run_dir)
set_explaining_cfg(explaining_cfg)
def from_config(func):
if inspect.isclass(func):
params = list(inspect.signature(func.__init__).parameters.values())[1:]
else:
params = list(inspect.signature(func).parameters.values())
arg_names = [p.name for p in params]
has_defaults = [p.default != inspect.Parameter.empty for p in params]
@functools.wraps(func)
def wrapper(*args, explaining_cfg: Any = None, **kwargs):
if explaining_cfg is not None:
explaining_cfg = (
dict(explaining_cfg)
if isinstance(explaining_cfg, Iterable)
else asdict(explaining_cfg)
)
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
for arg_name, has_default in iterator:
if arg_name in kwargs:
continue
elif arg_name in explaining_cfg:
kwargs[arg_name] = explaining_cfg[arg_name]
elif not has_default:
raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
return func(*args, **kwargs)
return wrapper

View File

@ -0,0 +1,304 @@
import functools
import inspect
import logging
import os
import shutil
import warnings
from collections.abc import Iterable
from dataclasses import asdict
from typing import Any
import torch_geometric.graphgym.register as register
from torch_geometric.data.makedirs import makedirs
try: # Define global config object
from yacs.config import CfgNode as CN
explaining_cfg = CN()
except ImportError:
explaining_cfg = None
warnings.warn(
"Could not define global config object. Please install "
"'yacs' for using the GraphGym experiment manager via "
"'pip install yacs'."
)
def set_explaining_cfg(explaining_cfg):
r"""
This function sets the default config value.
1) Note that for an experiment, only part of the arguments will be used
The remaining unused arguments won't affect anything.
So feel free to register any argument in graphgym.contrib.config
2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
:return: configuration use by the experiment.
"""
if explaining_cfg is None:
return explaining_cfg
# ----------------------------------------------------------------------- #
# Basic options
# ----------------------------------------------------------------------- #
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
explaining_cfg.out_dir = "./explanations"
explaining_cfg.cfg_dest = "explaining_config.yaml"
explaining_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
explaining_cfg.dataset = CN()
explaining_cfg.dataset.name = "Cora"
explaining_cfg.run_topological_stat = True
# ----------------------------------------------------------------------- #
# Model options
# ----------------------------------------------------------------------- #
explaining_cfg.model = CN()
# Set wether or not load the best model for given dataset or a path
explaining_cfg.model.ckpt = "best"
# ----------------------------------------------------------------------- #
# Explainer options
# ----------------------------------------------------------------------- #
explaining_cfg.explainer = CN()
# Name of the explaining method
explaining_cfg.explainer.name = "EiXGNN"
# Whether or not to provide specific explaining methods configuration or default configuration
explaining_cfg.explainer.cfg = "default"
# ----------------------------------------------------------------------- #
# Explaining options
# ----------------------------------------------------------------------- #
# 'ExplanationType : 'model' or 'phenomenon'
explaining_cfg.explanation_type = "model"
explaining_cfg.model_config = CN()
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
explaining_cfg.model_config.mode = None
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
explaining_cfg.model_config.task_level = None
# Do not modify it, we always assume here that model output are 'raw'
explaining_cfg.model_config.return_type = "raw"
explaining_cfg.threshold_config = CN()
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = 0.5
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"
# Output directory
explaining_cfg.out_dir = "results"
# Config name (in out_dir)
explaining_cfg.explaining_cfg_dest = "config.yaml"
# Random seed
explaining_cfg.seed = 0
# ----------------------------------------------------------------------- #
# Globally shared variables:
# These variables will be set dynamically based on the input dataset
# Do not directly set them here or in .yaml files
# ----------------------------------------------------------------------- #
explaining_cfg.share = CN()
# Size of input dimension
explaining_cfg.share.dim_in = 1
# Size of out dimension, i.e., number of labels to be predicted
explaining_cfg.share.dim_out = 1
# Number of dataset splits: train/val/test
explaining_cfg.share.num_splits = 1
# ----------------------------------------------------------------------- #
# Dataset options
# ----------------------------------------------------------------------- #
explaining_cfg.dataset = CN()
# Name of the dataset
explaining_cfg.dataset.name = "Cora"
# if PyG: look for it in Pytorch Geometric dataset
# if NetworkX/nx: load data in NetworkX format
# Dir to load the dataset. If the dataset is downloaded, this is the
# cache dir
explaining_cfg.dataset.dir = "./datasets"
# ----------------------------------------------------------------------- #
# Memory options
# ----------------------------------------------------------------------- #
explaining_cfg.mem = CN()
# Perform ReLU inplace
explaining_cfg.mem.inplace = False
# Set user customized explaining_cfgs
for func in register.config_dict.values():
func(explaining_cfg)
def assert_explaining_cfg(explaining_cfg):
r"""Checks config values, do necessary post processing to the configs"""
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
raise ValueError(
"Task {} not supported, must be one of node, "
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
)
if (
"classification" in explaining_cfg.dataset.task_type
and explaining_cfg.model.loss_fun == "mse"
):
explaining_cfg.model.loss_fun = "cross_entropy"
logging.warning("model.loss_fun changed to cross_entropy for classification.")
if (
explaining_cfg.dataset.task_type == "regression"
and explaining_cfg.model.loss_fun == "cross_entropy"
):
explaining_cfg.model.loss_fun = "mse"
logging.warning("model.loss_fun changed to mse for regression.")
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
explaining_cfg.dataset.transductive = False
logging.warning("dataset.transductive changed " "to False for graph task.")
if explaining_cfg.gnn.layers_post_mp < 1:
explaining_cfg.gnn.layers_post_mp = 1
logging.warning("Layers after message passing should be >=1")
if explaining_cfg.gnn.head == "default":
explaining_cfg.gnn.head = explaining_cfg.dataset.task
explaining_cfg.run_dir = explaining_cfg.out_dir
def dump_explaining_cfg(explaining_cfg):
r"""
Dumps the config to the output directory specified in
:obj:`explaining_cfg.out_dir`
Args:
explaining_cfg (CfgNode): Configuration node
"""
makedirs(explaining_cfg.out_dir)
explaining_cfg_file = os.path.join(
explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
)
with open(explaining_cfg_file, "w") as f:
explaining_cfg.dump(stream=f)
def load_explaining_cfg(explaining_cfg, args):
r"""
Load configurations from file system and command line
Args:
explaining_cfg (CfgNode): Configuration node
args (ArgumentParser): Command argument parser
"""
explaining_cfg.merge_from_file(args.explaining_cfg_file)
explaining_cfg.merge_from_list(args.opts)
assert_explaining_cfg(explaining_cfg)
def makedirs_rm_exist(dir):
if os.path.isdir(dir):
shutil.rmtree(dir)
os.makedirs(dir, exist_ok=True)
def get_fname(fname):
r"""
Extract filename from file name path
Args:
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split("/")[-1]
if fname.endswith(".yaml"):
fname = fname[:-5]
elif fname.endswith(".yml"):
fname = fname[:-4]
return fname
def set_out_dir(out_dir, fname):
r"""
Create the directory for full experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = get_fname(fname)
explaining_cfg.out_dir = os.path.join(out_dir, fname)
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.out_dir)
def set_run_dir(out_dir):
r"""
Create the directory for each random seed experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.run_dir)
set_explaining_cfg(explaining_cfg)
def from_config(func):
if inspect.isclass(func):
params = list(inspect.signature(func.__init__).parameters.values())[1:]
else:
params = list(inspect.signature(func).parameters.values())
arg_names = [p.name for p in params]
has_defaults = [p.default != inspect.Parameter.empty for p in params]
@functools.wraps(func)
def wrapper(*args, explaining_cfg: Any = None, **kwargs):
if explaining_cfg is not None:
explaining_cfg = (
dict(explaining_cfg)
if isinstance(explaining_cfg, Iterable)
else asdict(explaining_cfg)
)
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
for arg_name, has_default in iterator:
if arg_name in kwargs:
continue
elif arg_name in explaining_cfg:
kwargs[arg_name] = explaining_cfg[arg_name]
elif not has_default:
raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
return func(*args, **kwargs)
return wrapper

View File

@ -1,16 +0,0 @@
import copy
from torch import FloatTensor
from torch.nn import ReLU
def relu_mask(explanation: Explanation) -> Explanation:
relu = ReLU()
explanation_store = explanation._store
raw_data = copy.copy(explanation._store)
for k, v in explanation_store.items():
if "mask" in k:
explanation_store[k] = relu(v)
explanation.__setattr__("raw_explanation", raw_data)
explanation.__setattr__("raw_explanation_transform", "relu")
return explanation

View File

@ -1,7 +0,0 @@
import copy
from typing import Dict, List, Optional, Union
import torch
from torch import Tensor
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
from torch_geometric.explain.explanation import Explanation