Adding config files
This commit is contained in:
parent
55570a26e5
commit
7fe935dbad
302
explaining_framework/config/explainer_config/eixgnn_config.py
Normal file
302
explaining_framework/config/explainer_config/eixgnn_config.py
Normal file
@ -0,0 +1,302 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import torch_geometric.graphgym.register as register
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
|
||||
try: # Define global config object
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
eixgnn_cfg = CN()
|
||||
except ImportError:
|
||||
eixgnn_cfg = None
|
||||
warnings.warn(
|
||||
"Could not define global config object. Please install "
|
||||
"'yacs' for using the GraphGym experiment manager via "
|
||||
"'pip install yacs'."
|
||||
)
|
||||
|
||||
|
||||
def set_eixgnn_cfg(eixgnn_cfg):
|
||||
r"""
|
||||
This function sets the default config value.
|
||||
1) Note that for an experiment, only part of the arguments will be used
|
||||
The remaining unused arguments won't affect anything.
|
||||
So feel free to register any argument in graphgym.contrib.config
|
||||
2) We support *at most* two levels of configs, e.g., eixgnn_cfg.dataset.name
|
||||
:return: configuration use by the experiment.
|
||||
"""
|
||||
if eixgnn_cfg is None:
|
||||
return eixgnn_cfg
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Basic options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
eixgnn_cfg.print = "both"
|
||||
|
||||
eixgnn_cfg.out_dir = "./explanations"
|
||||
|
||||
eixgnn_cfg.cfg_dest = "explaining_config.yaml"
|
||||
|
||||
eixgnn_cfg.seed = 0
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
eixgnn_cfg.dataset = CN()
|
||||
|
||||
eixgnn_cfg.dataset.name = "Cora"
|
||||
|
||||
eixgnn_cfg.run_topological_stat = True
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Model options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
eixgnn_cfg.model = CN()
|
||||
|
||||
# Set wether or not load the best model for given dataset or a path
|
||||
eixgnn_cfg.model.ckpt = "best"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explainer options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
eixgnn_cfg.explainer = CN()
|
||||
|
||||
# Name of the explaining method
|
||||
eixgnn_cfg.explainer.name = "EiXGNN"
|
||||
|
||||
# Whether or not to provide specific explaining methods configuration or default configuration
|
||||
eixgnn_cfg.explainer.cfg = "default"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explaining options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# 'ExplanationType : 'model' or 'phenomenon'
|
||||
eixgnn_cfg.explanation_type = "model"
|
||||
|
||||
eixgnn_cfg.model_config = CN()
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
eixgnn_cfg.model_config.mode = None
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
eixgnn_cfg.model_config.task_level = None
|
||||
|
||||
# Do not modify it, we always assume here that model output are 'raw'
|
||||
eixgnn_cfg.model_config.return_type = "raw"
|
||||
|
||||
eixgnn_cfg.threshold_config = CN()
|
||||
|
||||
eixgnn_cfg.threshold_config.threshold_type = None
|
||||
|
||||
eixgnn_cfg.threshold_config.value = 0.5
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
eixgnn_cfg.print = "both"
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
eixgnn_cfg.accelerator = "auto"
|
||||
|
||||
# Output directory
|
||||
eixgnn_cfg.out_dir = "results"
|
||||
|
||||
# Config name (in out_dir)
|
||||
eixgnn_cfg.eixgnn_cfg_dest = "config.yaml"
|
||||
|
||||
# Random seed
|
||||
eixgnn_cfg.seed = 0
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Globally shared variables:
|
||||
# These variables will be set dynamically based on the input dataset
|
||||
# Do not directly set them here or in .yaml files
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
eixgnn_cfg.share = CN()
|
||||
|
||||
# Size of input dimension
|
||||
eixgnn_cfg.share.dim_in = 1
|
||||
|
||||
# Size of out dimension, i.e., number of labels to be predicted
|
||||
eixgnn_cfg.share.dim_out = 1
|
||||
|
||||
# Number of dataset splits: train/val/test
|
||||
eixgnn_cfg.share.num_splits = 1
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
eixgnn_cfg.dataset = CN()
|
||||
|
||||
# Name of the dataset
|
||||
eixgnn_cfg.dataset.name = "Cora"
|
||||
|
||||
# if PyG: look for it in Pytorch Geometric dataset
|
||||
# if NetworkX/nx: load data in NetworkX format
|
||||
|
||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
||||
# cache dir
|
||||
eixgnn_cfg.dataset.dir = "./datasets"
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Memory options
|
||||
# ----------------------------------------------------------------------- #
|
||||
eixgnn_cfg.mem = CN()
|
||||
|
||||
# Perform ReLU inplace
|
||||
eixgnn_cfg.mem.inplace = False
|
||||
|
||||
# Set user customized eixgnn_cfgs
|
||||
for func in register.config_dict.values():
|
||||
func(eixgnn_cfg)
|
||||
|
||||
|
||||
def assert_eixgnn_cfg(eixgnn_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs"""
|
||||
if eixgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
||||
raise ValueError(
|
||||
"Task {} not supported, must be one of node, "
|
||||
"edge, graph, link_pred".format(eixgnn_cfg.dataset.task)
|
||||
)
|
||||
if (
|
||||
"classification" in eixgnn_cfg.dataset.task_type
|
||||
and eixgnn_cfg.model.loss_fun == "mse"
|
||||
):
|
||||
eixgnn_cfg.model.loss_fun = "cross_entropy"
|
||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
||||
if (
|
||||
eixgnn_cfg.dataset.task_type == "regression"
|
||||
and eixgnn_cfg.model.loss_fun == "cross_entropy"
|
||||
):
|
||||
eixgnn_cfg.model.loss_fun = "mse"
|
||||
logging.warning("model.loss_fun changed to mse for regression.")
|
||||
if eixgnn_cfg.dataset.task == "graph" and eixgnn_cfg.dataset.transductive:
|
||||
eixgnn_cfg.dataset.transductive = False
|
||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
||||
if eixgnn_cfg.gnn.layers_post_mp < 1:
|
||||
eixgnn_cfg.gnn.layers_post_mp = 1
|
||||
logging.warning("Layers after message passing should be >=1")
|
||||
if eixgnn_cfg.gnn.head == "default":
|
||||
eixgnn_cfg.gnn.head = eixgnn_cfg.dataset.task
|
||||
eixgnn_cfg.run_dir = eixgnn_cfg.out_dir
|
||||
|
||||
|
||||
def dump_eixgnn_cfg(eixgnn_cfg):
|
||||
r"""
|
||||
Dumps the config to the output directory specified in
|
||||
:obj:`eixgnn_cfg.out_dir`
|
||||
Args:
|
||||
eixgnn_cfg (CfgNode): Configuration node
|
||||
"""
|
||||
makedirs(eixgnn_cfg.out_dir)
|
||||
eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
|
||||
with open(eixgnn_cfg_file, "w") as f:
|
||||
eixgnn_cfg.dump(stream=f)
|
||||
|
||||
|
||||
def load_eixgnn_cfg(eixgnn_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
eixgnn_cfg (CfgNode): Configuration node
|
||||
args (ArgumentParser): Command argument parser
|
||||
"""
|
||||
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
||||
eixgnn_cfg.merge_from_list(args.opts)
|
||||
assert_eixgnn_cfg(eixgnn_cfg)
|
||||
|
||||
|
||||
def makedirs_rm_exist(dir):
|
||||
if os.path.isdir(dir):
|
||||
shutil.rmtree(dir)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
|
||||
def get_fname(fname):
|
||||
r"""
|
||||
Extract filename from file name path
|
||||
Args:
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = fname.split("/")[-1]
|
||||
if fname.endswith(".yaml"):
|
||||
fname = fname[:-5]
|
||||
elif fname.endswith(".yml"):
|
||||
fname = fname[:-4]
|
||||
return fname
|
||||
|
||||
|
||||
def set_out_dir(out_dir, fname):
|
||||
r"""
|
||||
Create the directory for full experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = get_fname(fname)
|
||||
eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
|
||||
# Make output directory
|
||||
if eixgnn_cfg.train.auto_resume:
|
||||
os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(eixgnn_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
r"""
|
||||
Create the directory for each random seed experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
|
||||
# Make output directory
|
||||
if eixgnn_cfg.train.auto_resume:
|
||||
os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(eixgnn_cfg.run_dir)
|
||||
|
||||
|
||||
set_eixgnn_cfg(eixgnn_cfg)
|
||||
|
||||
|
||||
def from_config(func):
|
||||
if inspect.isclass(func):
|
||||
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||
else:
|
||||
params = list(inspect.signature(func).parameters.values())
|
||||
|
||||
arg_names = [p.name for p in params]
|
||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
|
||||
if eixgnn_cfg is not None:
|
||||
eixgnn_cfg = (
|
||||
dict(eixgnn_cfg)
|
||||
if isinstance(eixgnn_cfg, Iterable)
|
||||
else asdict(eixgnn_cfg)
|
||||
)
|
||||
|
||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||
for arg_name, has_default in iterator:
|
||||
if arg_name in kwargs:
|
||||
continue
|
||||
elif arg_name in eixgnn_cfg:
|
||||
kwargs[arg_name] = eixgnn_cfg[arg_name]
|
||||
elif not has_default:
|
||||
raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
304
explaining_framework/config/explainer_config/scgnn_config.py
Normal file
304
explaining_framework/config/explainer_config/scgnn_config.py
Normal file
@ -0,0 +1,304 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import torch_geometric.graphgym.register as register
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
|
||||
try: # Define global config object
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
explaining_cfg = CN()
|
||||
except ImportError:
|
||||
explaining_cfg = None
|
||||
warnings.warn(
|
||||
"Could not define global config object. Please install "
|
||||
"'yacs' for using the GraphGym experiment manager via "
|
||||
"'pip install yacs'."
|
||||
)
|
||||
|
||||
|
||||
def set_explaining_cfg(explaining_cfg):
|
||||
r"""
|
||||
This function sets the default config value.
|
||||
1) Note that for an experiment, only part of the arguments will be used
|
||||
The remaining unused arguments won't affect anything.
|
||||
So feel free to register any argument in graphgym.contrib.config
|
||||
2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
|
||||
:return: configuration use by the experiment.
|
||||
"""
|
||||
if explaining_cfg is None:
|
||||
return explaining_cfg
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Basic options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
|
||||
explaining_cfg.out_dir = "./explanations"
|
||||
|
||||
explaining_cfg.cfg_dest = "explaining_config.yaml"
|
||||
|
||||
explaining_cfg.seed = 0
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.dataset = CN()
|
||||
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
explaining_cfg.run_topological_stat = True
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Model options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.model = CN()
|
||||
|
||||
# Set wether or not load the best model for given dataset or a path
|
||||
explaining_cfg.model.ckpt = "best"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explainer options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.explainer = CN()
|
||||
|
||||
# Name of the explaining method
|
||||
explaining_cfg.explainer.name = "EiXGNN"
|
||||
|
||||
# Whether or not to provide specific explaining methods configuration or default configuration
|
||||
explaining_cfg.explainer.cfg = "default"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explaining options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# 'ExplanationType : 'model' or 'phenomenon'
|
||||
explaining_cfg.explanation_type = "model"
|
||||
|
||||
explaining_cfg.model_config = CN()
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
explaining_cfg.model_config.mode = None
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
explaining_cfg.model_config.task_level = None
|
||||
|
||||
# Do not modify it, we always assume here that model output are 'raw'
|
||||
explaining_cfg.model_config.return_type = "raw"
|
||||
|
||||
explaining_cfg.threshold_config = CN()
|
||||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
|
||||
explaining_cfg.threshold_config.value = 0.5
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# Output directory
|
||||
explaining_cfg.out_dir = "results"
|
||||
|
||||
# Config name (in out_dir)
|
||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
||||
|
||||
# Random seed
|
||||
explaining_cfg.seed = 0
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Globally shared variables:
|
||||
# These variables will be set dynamically based on the input dataset
|
||||
# Do not directly set them here or in .yaml files
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.share = CN()
|
||||
|
||||
# Size of input dimension
|
||||
explaining_cfg.share.dim_in = 1
|
||||
|
||||
# Size of out dimension, i.e., number of labels to be predicted
|
||||
explaining_cfg.share.dim_out = 1
|
||||
|
||||
# Number of dataset splits: train/val/test
|
||||
explaining_cfg.share.num_splits = 1
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
explaining_cfg.dataset = CN()
|
||||
|
||||
# Name of the dataset
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
# if PyG: look for it in Pytorch Geometric dataset
|
||||
# if NetworkX/nx: load data in NetworkX format
|
||||
|
||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
||||
# cache dir
|
||||
explaining_cfg.dataset.dir = "./datasets"
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Memory options
|
||||
# ----------------------------------------------------------------------- #
|
||||
explaining_cfg.mem = CN()
|
||||
|
||||
# Perform ReLU inplace
|
||||
explaining_cfg.mem.inplace = False
|
||||
|
||||
# Set user customized explaining_cfgs
|
||||
for func in register.config_dict.values():
|
||||
func(explaining_cfg)
|
||||
|
||||
|
||||
def assert_explaining_cfg(explaining_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs"""
|
||||
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
||||
raise ValueError(
|
||||
"Task {} not supported, must be one of node, "
|
||||
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
|
||||
)
|
||||
if (
|
||||
"classification" in explaining_cfg.dataset.task_type
|
||||
and explaining_cfg.model.loss_fun == "mse"
|
||||
):
|
||||
explaining_cfg.model.loss_fun = "cross_entropy"
|
||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
||||
if (
|
||||
explaining_cfg.dataset.task_type == "regression"
|
||||
and explaining_cfg.model.loss_fun == "cross_entropy"
|
||||
):
|
||||
explaining_cfg.model.loss_fun = "mse"
|
||||
logging.warning("model.loss_fun changed to mse for regression.")
|
||||
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
|
||||
explaining_cfg.dataset.transductive = False
|
||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
||||
if explaining_cfg.gnn.layers_post_mp < 1:
|
||||
explaining_cfg.gnn.layers_post_mp = 1
|
||||
logging.warning("Layers after message passing should be >=1")
|
||||
if explaining_cfg.gnn.head == "default":
|
||||
explaining_cfg.gnn.head = explaining_cfg.dataset.task
|
||||
explaining_cfg.run_dir = explaining_cfg.out_dir
|
||||
|
||||
|
||||
def dump_explaining_cfg(explaining_cfg):
|
||||
r"""
|
||||
Dumps the config to the output directory specified in
|
||||
:obj:`explaining_cfg.out_dir`
|
||||
Args:
|
||||
explaining_cfg (CfgNode): Configuration node
|
||||
"""
|
||||
makedirs(explaining_cfg.out_dir)
|
||||
explaining_cfg_file = os.path.join(
|
||||
explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
|
||||
)
|
||||
with open(explaining_cfg_file, "w") as f:
|
||||
explaining_cfg.dump(stream=f)
|
||||
|
||||
|
||||
def load_explaining_cfg(explaining_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
explaining_cfg (CfgNode): Configuration node
|
||||
args (ArgumentParser): Command argument parser
|
||||
"""
|
||||
explaining_cfg.merge_from_file(args.explaining_cfg_file)
|
||||
explaining_cfg.merge_from_list(args.opts)
|
||||
assert_explaining_cfg(explaining_cfg)
|
||||
|
||||
|
||||
def makedirs_rm_exist(dir):
|
||||
if os.path.isdir(dir):
|
||||
shutil.rmtree(dir)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
|
||||
def get_fname(fname):
|
||||
r"""
|
||||
Extract filename from file name path
|
||||
Args:
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = fname.split("/")[-1]
|
||||
if fname.endswith(".yaml"):
|
||||
fname = fname[:-5]
|
||||
elif fname.endswith(".yml"):
|
||||
fname = fname[:-4]
|
||||
return fname
|
||||
|
||||
|
||||
def set_out_dir(out_dir, fname):
|
||||
r"""
|
||||
Create the directory for full experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = get_fname(fname)
|
||||
explaining_cfg.out_dir = os.path.join(out_dir, fname)
|
||||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
r"""
|
||||
Create the directory for each random seed experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
|
||||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
||||
|
||||
|
||||
set_explaining_cfg(explaining_cfg)
|
||||
|
||||
|
||||
def from_config(func):
|
||||
if inspect.isclass(func):
|
||||
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||
else:
|
||||
params = list(inspect.signature(func).parameters.values())
|
||||
|
||||
arg_names = [p.name for p in params]
|
||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, explaining_cfg: Any = None, **kwargs):
|
||||
if explaining_cfg is not None:
|
||||
explaining_cfg = (
|
||||
dict(explaining_cfg)
|
||||
if isinstance(explaining_cfg, Iterable)
|
||||
else asdict(explaining_cfg)
|
||||
)
|
||||
|
||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||
for arg_name, has_default in iterator:
|
||||
if arg_name in kwargs:
|
||||
continue
|
||||
elif arg_name in explaining_cfg:
|
||||
kwargs[arg_name] = explaining_cfg[arg_name]
|
||||
elif not has_default:
|
||||
raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
304
explaining_framework/config/explaining_config.py
Normal file
304
explaining_framework/config/explaining_config.py
Normal file
@ -0,0 +1,304 @@
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import torch_geometric.graphgym.register as register
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
|
||||
try: # Define global config object
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
explaining_cfg = CN()
|
||||
except ImportError:
|
||||
explaining_cfg = None
|
||||
warnings.warn(
|
||||
"Could not define global config object. Please install "
|
||||
"'yacs' for using the GraphGym experiment manager via "
|
||||
"'pip install yacs'."
|
||||
)
|
||||
|
||||
|
||||
def set_explaining_cfg(explaining_cfg):
|
||||
r"""
|
||||
This function sets the default config value.
|
||||
1) Note that for an experiment, only part of the arguments will be used
|
||||
The remaining unused arguments won't affect anything.
|
||||
So feel free to register any argument in graphgym.contrib.config
|
||||
2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
|
||||
:return: configuration use by the experiment.
|
||||
"""
|
||||
if explaining_cfg is None:
|
||||
return explaining_cfg
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Basic options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
|
||||
explaining_cfg.out_dir = "./explanations"
|
||||
|
||||
explaining_cfg.cfg_dest = "explaining_config.yaml"
|
||||
|
||||
explaining_cfg.seed = 0
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.dataset = CN()
|
||||
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
explaining_cfg.run_topological_stat = True
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Model options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.model = CN()
|
||||
|
||||
# Set wether or not load the best model for given dataset or a path
|
||||
explaining_cfg.model.ckpt = "best"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explainer options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.explainer = CN()
|
||||
|
||||
# Name of the explaining method
|
||||
explaining_cfg.explainer.name = "EiXGNN"
|
||||
|
||||
# Whether or not to provide specific explaining methods configuration or default configuration
|
||||
explaining_cfg.explainer.cfg = "default"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explaining options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
# 'ExplanationType : 'model' or 'phenomenon'
|
||||
explaining_cfg.explanation_type = "model"
|
||||
|
||||
explaining_cfg.model_config = CN()
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
explaining_cfg.model_config.mode = None
|
||||
|
||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||
explaining_cfg.model_config.task_level = None
|
||||
|
||||
# Do not modify it, we always assume here that model output are 'raw'
|
||||
explaining_cfg.model_config.return_type = "raw"
|
||||
|
||||
explaining_cfg.threshold_config = CN()
|
||||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
|
||||
explaining_cfg.threshold_config.value = 0.5
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# Output directory
|
||||
explaining_cfg.out_dir = "results"
|
||||
|
||||
# Config name (in out_dir)
|
||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
||||
|
||||
# Random seed
|
||||
explaining_cfg.seed = 0
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Globally shared variables:
|
||||
# These variables will be set dynamically based on the input dataset
|
||||
# Do not directly set them here or in .yaml files
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.share = CN()
|
||||
|
||||
# Size of input dimension
|
||||
explaining_cfg.share.dim_in = 1
|
||||
|
||||
# Size of out dimension, i.e., number of labels to be predicted
|
||||
explaining_cfg.share.dim_out = 1
|
||||
|
||||
# Number of dataset splits: train/val/test
|
||||
explaining_cfg.share.num_splits = 1
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Dataset options
|
||||
# ----------------------------------------------------------------------- #
|
||||
explaining_cfg.dataset = CN()
|
||||
|
||||
# Name of the dataset
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
# if PyG: look for it in Pytorch Geometric dataset
|
||||
# if NetworkX/nx: load data in NetworkX format
|
||||
|
||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
||||
# cache dir
|
||||
explaining_cfg.dataset.dir = "./datasets"
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Memory options
|
||||
# ----------------------------------------------------------------------- #
|
||||
explaining_cfg.mem = CN()
|
||||
|
||||
# Perform ReLU inplace
|
||||
explaining_cfg.mem.inplace = False
|
||||
|
||||
# Set user customized explaining_cfgs
|
||||
for func in register.config_dict.values():
|
||||
func(explaining_cfg)
|
||||
|
||||
|
||||
def assert_explaining_cfg(explaining_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs"""
|
||||
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
||||
raise ValueError(
|
||||
"Task {} not supported, must be one of node, "
|
||||
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
|
||||
)
|
||||
if (
|
||||
"classification" in explaining_cfg.dataset.task_type
|
||||
and explaining_cfg.model.loss_fun == "mse"
|
||||
):
|
||||
explaining_cfg.model.loss_fun = "cross_entropy"
|
||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
||||
if (
|
||||
explaining_cfg.dataset.task_type == "regression"
|
||||
and explaining_cfg.model.loss_fun == "cross_entropy"
|
||||
):
|
||||
explaining_cfg.model.loss_fun = "mse"
|
||||
logging.warning("model.loss_fun changed to mse for regression.")
|
||||
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
|
||||
explaining_cfg.dataset.transductive = False
|
||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
||||
if explaining_cfg.gnn.layers_post_mp < 1:
|
||||
explaining_cfg.gnn.layers_post_mp = 1
|
||||
logging.warning("Layers after message passing should be >=1")
|
||||
if explaining_cfg.gnn.head == "default":
|
||||
explaining_cfg.gnn.head = explaining_cfg.dataset.task
|
||||
explaining_cfg.run_dir = explaining_cfg.out_dir
|
||||
|
||||
|
||||
def dump_explaining_cfg(explaining_cfg):
|
||||
r"""
|
||||
Dumps the config to the output directory specified in
|
||||
:obj:`explaining_cfg.out_dir`
|
||||
Args:
|
||||
explaining_cfg (CfgNode): Configuration node
|
||||
"""
|
||||
makedirs(explaining_cfg.out_dir)
|
||||
explaining_cfg_file = os.path.join(
|
||||
explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
|
||||
)
|
||||
with open(explaining_cfg_file, "w") as f:
|
||||
explaining_cfg.dump(stream=f)
|
||||
|
||||
|
||||
def load_explaining_cfg(explaining_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
explaining_cfg (CfgNode): Configuration node
|
||||
args (ArgumentParser): Command argument parser
|
||||
"""
|
||||
explaining_cfg.merge_from_file(args.explaining_cfg_file)
|
||||
explaining_cfg.merge_from_list(args.opts)
|
||||
assert_explaining_cfg(explaining_cfg)
|
||||
|
||||
|
||||
def makedirs_rm_exist(dir):
|
||||
if os.path.isdir(dir):
|
||||
shutil.rmtree(dir)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
|
||||
def get_fname(fname):
|
||||
r"""
|
||||
Extract filename from file name path
|
||||
Args:
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = fname.split("/")[-1]
|
||||
if fname.endswith(".yaml"):
|
||||
fname = fname[:-5]
|
||||
elif fname.endswith(".yml"):
|
||||
fname = fname[:-4]
|
||||
return fname
|
||||
|
||||
|
||||
def set_out_dir(out_dir, fname):
|
||||
r"""
|
||||
Create the directory for full experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
fname = get_fname(fname)
|
||||
explaining_cfg.out_dir = os.path.join(out_dir, fname)
|
||||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
r"""
|
||||
Create the directory for each random seed experiment run
|
||||
Args:
|
||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
||||
fname (string): Filename for the yaml format configuration file
|
||||
"""
|
||||
explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
|
||||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
||||
|
||||
|
||||
set_explaining_cfg(explaining_cfg)
|
||||
|
||||
|
||||
def from_config(func):
|
||||
if inspect.isclass(func):
|
||||
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||
else:
|
||||
params = list(inspect.signature(func).parameters.values())
|
||||
|
||||
arg_names = [p.name for p in params]
|
||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, explaining_cfg: Any = None, **kwargs):
|
||||
if explaining_cfg is not None:
|
||||
explaining_cfg = (
|
||||
dict(explaining_cfg)
|
||||
if isinstance(explaining_cfg, Iterable)
|
||||
else asdict(explaining_cfg)
|
||||
)
|
||||
|
||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||
for arg_name, has_default in iterator:
|
||||
if arg_name in kwargs:
|
||||
continue
|
||||
elif arg_name in explaining_cfg:
|
||||
kwargs[arg_name] = explaining_cfg[arg_name]
|
||||
elif not has_default:
|
||||
raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
@ -1,16 +0,0 @@
|
||||
import copy
|
||||
|
||||
from torch import FloatTensor
|
||||
from torch.nn import ReLU
|
||||
|
||||
|
||||
def relu_mask(explanation: Explanation) -> Explanation:
|
||||
relu = ReLU()
|
||||
explanation_store = explanation._store
|
||||
raw_data = copy.copy(explanation._store)
|
||||
for k, v in explanation_store.items():
|
||||
if "mask" in k:
|
||||
explanation_store[k] = relu(v)
|
||||
explanation.__setattr__("raw_explanation", raw_data)
|
||||
explanation.__setattr__("raw_explanation_transform", "relu")
|
||||
return explanation
|
@ -1,7 +0,0 @@
|
||||
import copy
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
|
||||
from torch_geometric.explain.explanation import Explanation
|
Loading…
Reference in New Issue
Block a user