New features
This commit is contained in:
parent
7f540f53d7
commit
495cde4c70
5 changed files with 215 additions and 516 deletions
|
@ -36,267 +36,125 @@ def set_eixgnn_cfg(eixgnn_cfg):
|
||||||
if eixgnn_cfg is None:
|
if eixgnn_cfg is None:
|
||||||
return eixgnn_cfg
|
return eixgnn_cfg
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Basic options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
# Set print destination: stdout / file / both
|
|
||||||
eixgnn_cfg.print = "both"
|
|
||||||
|
|
||||||
eixgnn_cfg.out_dir = "./explanations"
|
|
||||||
|
|
||||||
eixgnn_cfg.cfg_dest = "explaining_config.yaml"
|
|
||||||
|
|
||||||
eixgnn_cfg.seed = 0
|
eixgnn_cfg.seed = 0
|
||||||
|
eixgnn_cfg.L = 50
|
||||||
# ----------------------------------------------------------------------- #
|
eixgnn_cfg.p = 0.5
|
||||||
# Dataset options
|
eixgnn_cfg.importance_sampling_strategy = "node"
|
||||||
# ----------------------------------------------------------------------- #
|
eixgnn_cfg.domain_similarity = "relative_edge_density"
|
||||||
|
eixgnn_cfg.signal_similarity = "KL"
|
||||||
eixgnn_cfg.dataset = CN()
|
eixgnn_cfg.shapley_value_approx = 100
|
||||||
|
|
||||||
eixgnn_cfg.dataset.name = "Cora"
|
|
||||||
|
|
||||||
eixgnn_cfg.run_topological_stat = True
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Model options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
eixgnn_cfg.model = CN()
|
|
||||||
|
|
||||||
# Set whether to load the best model for the given dataset, or a model from a given path
|
|
||||||
eixgnn_cfg.model.ckpt = "best"
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Explainer options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
eixgnn_cfg.explainer = CN()
|
|
||||||
|
|
||||||
# Name of the explaining method
|
|
||||||
eixgnn_cfg.explainer.name = "EiXGNN"
|
|
||||||
|
|
||||||
# Whether or not to provide specific explaining methods configuration or default configuration
|
|
||||||
eixgnn_cfg.explainer.cfg = "default"
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Explaining options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
# ExplanationType: 'model' or 'phenomenon'
|
|
||||||
eixgnn_cfg.explanation_type = "model"
|
|
||||||
|
|
||||||
eixgnn_cfg.model_config = CN()
|
|
||||||
|
|
||||||
# Do not modify it; it will be handled by the dataset, assuming one dataset = one learning task
|
|
||||||
eixgnn_cfg.model_config.mode = None
|
|
||||||
|
|
||||||
# Do not modify it; it will be handled by the dataset, assuming one dataset = one learning task
|
|
||||||
eixgnn_cfg.model_config.task_level = None
|
|
||||||
|
|
||||||
# Do not modify it; we always assume here that the model outputs are 'raw'
|
|
||||||
eixgnn_cfg.model_config.return_type = "raw"
|
|
||||||
|
|
||||||
eixgnn_cfg.threshold_config = CN()
|
|
||||||
|
|
||||||
eixgnn_cfg.threshold_config.threshold_type = None
|
|
||||||
|
|
||||||
eixgnn_cfg.threshold_config.value = 0.5
|
|
||||||
|
|
||||||
# Set print destination: stdout / file / both
|
|
||||||
eixgnn_cfg.print = "both"
|
|
||||||
|
|
||||||
# Select device: 'cpu', 'cuda', 'auto'
|
|
||||||
eixgnn_cfg.accelerator = "auto"
|
|
||||||
|
|
||||||
# Output directory
|
|
||||||
eixgnn_cfg.out_dir = "results"
|
|
||||||
|
|
||||||
# Config name (in out_dir)
|
|
||||||
eixgnn_cfg.eixgnn_cfg_dest = "config.yaml"
|
|
||||||
|
|
||||||
# Random seed
|
|
||||||
eixgnn_cfg.seed = 0
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Globally shared variables:
|
|
||||||
# These variables will be set dynamically based on the input dataset
|
|
||||||
# Do not directly set them here or in .yaml files
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
eixgnn_cfg.share = CN()
|
|
||||||
|
|
||||||
# Size of input dimension
|
|
||||||
eixgnn_cfg.share.dim_in = 1
|
|
||||||
|
|
||||||
# Size of out dimension, i.e., number of labels to be predicted
|
|
||||||
eixgnn_cfg.share.dim_out = 1
|
|
||||||
|
|
||||||
# Number of dataset splits: train/val/test
|
|
||||||
eixgnn_cfg.share.num_splits = 1
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Dataset options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
eixgnn_cfg.dataset = CN()
|
|
||||||
|
|
||||||
# Name of the dataset
|
|
||||||
eixgnn_cfg.dataset.name = "Cora"
|
|
||||||
|
|
||||||
# if PyG: look for it in Pytorch Geometric dataset
|
|
||||||
# if NetworkX/nx: load data in NetworkX format
|
|
||||||
|
|
||||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
|
||||||
# cache dir
|
|
||||||
eixgnn_cfg.dataset.dir = "./datasets"
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Memory options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
eixgnn_cfg.mem = CN()
|
|
||||||
|
|
||||||
# Perform ReLU inplace
|
|
||||||
eixgnn_cfg.mem.inplace = False
|
|
||||||
|
|
||||||
# Set user customized eixgnn_cfgs
|
|
||||||
for func in register.config_dict.values():
|
|
||||||
func(eixgnn_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_eixgnn_cfg(eixgnn_cfg):
|
def assert_eixgnn_cfg(eixgnn_cfg):
|
||||||
r"""Checks config values, do necessary post processing to the configs"""
|
r"""Checks config values, do necessary post processing to the configs"""
|
||||||
if eixgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
if not (0 <= eixgnn_cfg.p and eixgnn_cfg.p <= 1):
|
||||||
raise ValueError(
|
raise ValueError("p needs to be between 0 and 1")
|
||||||
"Task {} not supported, must be one of node, "
|
|
||||||
"edge, graph, link_pred".format(eixgnn_cfg.dataset.task)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"classification" in eixgnn_cfg.dataset.task_type
|
|
||||||
and eixgnn_cfg.model.loss_fun == "mse"
|
|
||||||
):
|
|
||||||
eixgnn_cfg.model.loss_fun = "cross_entropy"
|
|
||||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
|
||||||
if (
|
|
||||||
eixgnn_cfg.dataset.task_type == "regression"
|
|
||||||
and eixgnn_cfg.model.loss_fun == "cross_entropy"
|
|
||||||
):
|
|
||||||
eixgnn_cfg.model.loss_fun = "mse"
|
|
||||||
logging.warning("model.loss_fun changed to mse for regression.")
|
|
||||||
if eixgnn_cfg.dataset.task == "graph" and eixgnn_cfg.dataset.transductive:
|
|
||||||
eixgnn_cfg.dataset.transductive = False
|
|
||||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
|
||||||
if eixgnn_cfg.gnn.layers_post_mp < 1:
|
|
||||||
eixgnn_cfg.gnn.layers_post_mp = 1
|
|
||||||
logging.warning("Layers after message passing should be >=1")
|
|
||||||
if eixgnn_cfg.gnn.head == "default":
|
|
||||||
eixgnn_cfg.gnn.head = eixgnn_cfg.dataset.task
|
|
||||||
eixgnn_cfg.run_dir = eixgnn_cfg.out_dir
|
|
||||||
|
|
||||||
|
|
||||||
def dump_eixgnn_cfg(eixgnn_cfg):
|
|
||||||
r"""
|
|
||||||
Dumps the config to the output directory specified in
|
|
||||||
:obj:`eixgnn_cfg.out_dir`
|
|
||||||
Args:
|
|
||||||
eixgnn_cfg (CfgNode): Configuration node
|
|
||||||
"""
|
|
||||||
makedirs(eixgnn_cfg.out_dir)
|
|
||||||
eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
|
|
||||||
with open(eixgnn_cfg_file, "w") as f:
|
|
||||||
eixgnn_cfg.dump(stream=f)
|
|
||||||
|
|
||||||
|
|
||||||
def load_eixgnn_cfg(eixgnn_cfg, args):
|
|
||||||
r"""
|
|
||||||
Load configurations from file system and command line
|
|
||||||
Args:
|
|
||||||
eixgnn_cfg (CfgNode): Configuration node
|
|
||||||
args (ArgumentParser): Command argument parser
|
|
||||||
"""
|
|
||||||
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
|
||||||
eixgnn_cfg.merge_from_list(args.opts)
|
|
||||||
assert_eixgnn_cfg(eixgnn_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def makedirs_rm_exist(dir):
|
|
||||||
if os.path.isdir(dir):
|
|
||||||
shutil.rmtree(dir)
|
|
||||||
os.makedirs(dir, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_fname(fname):
|
|
||||||
r"""
|
|
||||||
Extract filename from file name path
|
|
||||||
Args:
|
|
||||||
fname (string): Filename for the yaml format configuration file
|
|
||||||
"""
|
|
||||||
fname = fname.split("/")[-1]
|
|
||||||
if fname.endswith(".yaml"):
|
|
||||||
fname = fname[:-5]
|
|
||||||
elif fname.endswith(".yml"):
|
|
||||||
fname = fname[:-4]
|
|
||||||
return fname
|
|
||||||
|
|
||||||
|
|
||||||
def set_out_dir(out_dir, fname):
|
|
||||||
r"""
|
|
||||||
Create the directory for full experiment run
|
|
||||||
Args:
|
|
||||||
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
|
||||||
fname (string): Filename for the yaml format configuration file
|
|
||||||
"""
|
|
||||||
fname = get_fname(fname)
|
|
||||||
eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
|
|
||||||
# Make output directory
|
|
||||||
if eixgnn_cfg.train.auto_resume:
|
|
||||||
os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
|
|
||||||
else:
|
|
||||||
makedirs_rm_exist(eixgnn_cfg.out_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def set_run_dir(out_dir):
|
|
||||||
r"""
|
|
||||||
Create the directory for each random seed experiment run
|
|
||||||
Args:
|
|
||||||
out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
|
||||||
fname (string): Filename for the yaml format configuration file
|
|
||||||
"""
|
|
||||||
eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
|
|
||||||
# Make output directory
|
|
||||||
if eixgnn_cfg.train.auto_resume:
|
|
||||||
os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
|
|
||||||
else:
|
|
||||||
makedirs_rm_exist(eixgnn_cfg.run_dir)
|
|
||||||
|
|
||||||
|
|
||||||
set_eixgnn_cfg(eixgnn_cfg)
|
set_eixgnn_cfg(eixgnn_cfg)
|
||||||
|
|
||||||
|
# def dump_eixgnn_cfg(eixgnn_cfg,path):
|
||||||
def from_config(func):
|
# r"""
|
||||||
if inspect.isclass(func):
|
# Dumps the config to the output directory specified in
|
||||||
params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
# :obj:`eixgnn_cfg.out_dir`
|
||||||
else:
|
# Args:
|
||||||
params = list(inspect.signature(func).parameters.values())
|
# eixgnn_cfg (CfgNode): Configuration node
|
||||||
|
# """
|
||||||
arg_names = [p.name for p in params]
|
# makedirs(eixgnn_cfg.out_dir)
|
||||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
# eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
|
||||||
|
# with open(eixgnn_cfg_file, "w") as f:
|
||||||
@functools.wraps(func)
|
# eixgnn_cfg.dump(stream=f)
|
||||||
def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
|
#
|
||||||
if eixgnn_cfg is not None:
|
#
|
||||||
eixgnn_cfg = (
|
# def load_eixgnn_cfg(eixgnn_cfg, args):
|
||||||
dict(eixgnn_cfg)
|
# r"""
|
||||||
if isinstance(eixgnn_cfg, Iterable)
|
# Load configurations from file system and command line
|
||||||
else asdict(eixgnn_cfg)
|
# Args:
|
||||||
)
|
# eixgnn_cfg (CfgNode): Configuration node
|
||||||
|
# args (ArgumentParser): Command argument parser
|
||||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
# """
|
||||||
for arg_name, has_default in iterator:
|
# eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
||||||
if arg_name in kwargs:
|
# eixgnn_cfg.merge_from_list(args.opts)
|
||||||
continue
|
# assert_eixgnn_cfg(eixgnn_cfg)
|
||||||
elif arg_name in eixgnn_cfg:
|
#
|
||||||
kwargs[arg_name] = eixgnn_cfg[arg_name]
|
#
|
||||||
elif not has_default:
|
# def makedirs_rm_exist(dir):
|
||||||
raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
|
# if os.path.isdir(dir):
|
||||||
return func(*args, **kwargs)
|
# shutil.rmtree(dir)
|
||||||
|
# os.makedirs(dir, exist_ok=True)
|
||||||
return wrapper
|
#
|
||||||
|
#
|
||||||
|
# def get_fname(fname):
|
||||||
|
# r"""
|
||||||
|
# Extract filename from file name path
|
||||||
|
# Args:
|
||||||
|
# fname (string): Filename for the yaml format configuration file
|
||||||
|
# """
|
||||||
|
# fname = fname.split("/")[-1]
|
||||||
|
# if fname.endswith(".yaml"):
|
||||||
|
# fname = fname[:-5]
|
||||||
|
# elif fname.endswith(".yml"):
|
||||||
|
# fname = fname[:-4]
|
||||||
|
# return fname
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def set_out_dir(out_dir, fname):
|
||||||
|
# r"""
|
||||||
|
# Create the directory for full experiment run
|
||||||
|
# Args:
|
||||||
|
# out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
||||||
|
# fname (string): Filename for the yaml format configuration file
|
||||||
|
# """
|
||||||
|
# fname = get_fname(fname)
|
||||||
|
# eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
|
||||||
|
# Make output directory
|
||||||
|
# if eixgnn_cfg.train.auto_resume:
|
||||||
|
# os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
|
||||||
|
# else:
|
||||||
|
# makedirs_rm_exist(eixgnn_cfg.out_dir)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def set_run_dir(out_dir):
|
||||||
|
# r"""
|
||||||
|
# Create the directory for each random seed experiment run
|
||||||
|
# Args:
|
||||||
|
# out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
|
||||||
|
# fname (string): Filename for the yaml format configuration file
|
||||||
|
# """
|
||||||
|
# eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
|
||||||
|
# Make output directory
|
||||||
|
# if eixgnn_cfg.train.auto_resume:
|
||||||
|
# os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
|
||||||
|
# else:
|
||||||
|
# makedirs_rm_exist(eixgnn_cfg.run_dir)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def from_config(func):
|
||||||
|
# if inspect.isclass(func):
|
||||||
|
# params = list(inspect.signature(func.__init__).parameters.values())[1:]
|
||||||
|
# else:
|
||||||
|
# params = list(inspect.signature(func).parameters.values())
|
||||||
|
#
|
||||||
|
# arg_names = [p.name for p in params]
|
||||||
|
# has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||||
|
#
|
||||||
|
# @functools.wraps(func)
|
||||||
|
# def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
|
||||||
|
# if eixgnn_cfg is not None:
|
||||||
|
# eixgnn_cfg = (
|
||||||
|
# dict(eixgnn_cfg)
|
||||||
|
# if isinstance(eixgnn_cfg, Iterable)
|
||||||
|
# else asdict(eixgnn_cfg)
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||||
|
# for arg_name, has_default in iterator:
|
||||||
|
# if arg_name in kwargs:
|
||||||
|
# continue
|
||||||
|
# elif arg_name in eixgnn_cfg:
|
||||||
|
# kwargs[arg_name] = eixgnn_cfg[arg_name]
|
||||||
|
# elif not has_default:
|
||||||
|
# raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
|
||||||
|
# return func(*args, **kwargs)
|
||||||
|
#
|
||||||
|
# return wrapper
|
||||||
|
|
|
@ -14,9 +14,9 @@ from torch_geometric.data.makedirs import makedirs
|
||||||
try: # Define global config object
|
try: # Define global config object
|
||||||
from yacs.config import CfgNode as CN
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
explaining_cfg = CN()
|
scgnn_cfg = CN()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
explaining_cfg = None
|
scgnn_cfg = None
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Could not define global config object. Please install "
|
"Could not define global config object. Please install "
|
||||||
"'yacs' for using the GraphGym experiment manager via "
|
"'yacs' for using the GraphGym experiment manager via "
|
||||||
|
@ -24,199 +24,64 @@ except ImportError:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_explaining_cfg(explaining_cfg):
|
def set_scgnn_cfg(scgnn_cfg):
|
||||||
r"""
|
r"""
|
||||||
This function sets the default config value.
|
This function sets the default config value.
|
||||||
1) Note that for an experiment, only part of the arguments will be used
|
1) Note that for an experiment, only part of the arguments will be used
|
||||||
The remaining unused arguments won't affect anything.
|
The remaining unused arguments won't affect anything.
|
||||||
So feel free to register any argument in graphgym.contrib.config
|
So feel free to register any argument in graphgym.contrib.config
|
||||||
2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
|
2) We support *at most* two levels of configs, e.g., scgnn_cfg.dataset.name
|
||||||
:return: configuration used by the experiment.
|
:return: configuration used by the experiment.
|
||||||
"""
|
"""
|
||||||
if explaining_cfg is None:
|
if scgnn_cfg is None:
|
||||||
return explaining_cfg
|
return scgnn_cfg
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
scgnn_cfg.depth = 'all'
|
||||||
# Basic options
|
scgnn_cfg.interest_map_norm = True
|
||||||
# ----------------------------------------------------------------------- #
|
scgnn_cfg.score_map_norm = True
|
||||||
|
|
||||||
# Set print destination: stdout / file / both
|
|
||||||
explaining_cfg.print = "both"
|
|
||||||
|
|
||||||
explaining_cfg.out_dir = "./explanations"
|
|
||||||
|
|
||||||
explaining_cfg.cfg_dest = "explaining_config.yaml"
|
|
||||||
|
|
||||||
explaining_cfg.seed = 0
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Dataset options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
explaining_cfg.dataset = CN()
|
|
||||||
|
|
||||||
explaining_cfg.dataset.name = "Cora"
|
|
||||||
|
|
||||||
explaining_cfg.run_topological_stat = True
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Model options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
explaining_cfg.model = CN()
|
|
||||||
|
|
||||||
# Set whether to load the best model for the given dataset, or a model from a given path
|
|
||||||
explaining_cfg.model.ckpt = "best"
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Explainer options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
explaining_cfg.explainer = CN()
|
|
||||||
|
|
||||||
# Name of the explaining method
|
|
||||||
explaining_cfg.explainer.name = "EiXGNN"
|
|
||||||
|
|
||||||
# Whether or not to provide specific explaining methods configuration or default configuration
|
|
||||||
explaining_cfg.explainer.cfg = "default"
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Explaining options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
# ExplanationType: 'model' or 'phenomenon'
|
|
||||||
explaining_cfg.explanation_type = "model"
|
|
||||||
|
|
||||||
explaining_cfg.model_config = CN()
|
|
||||||
|
|
||||||
# Do not modify it; it will be handled by the dataset, assuming one dataset = one learning task
|
|
||||||
explaining_cfg.model_config.mode = None
|
|
||||||
|
|
||||||
# Do not modify it; it will be handled by the dataset, assuming one dataset = one learning task
|
|
||||||
explaining_cfg.model_config.task_level = None
|
|
||||||
|
|
||||||
# Do not modify it; we always assume here that the model outputs are 'raw'
|
|
||||||
explaining_cfg.model_config.return_type = "raw"
|
|
||||||
|
|
||||||
explaining_cfg.threshold_config = CN()
|
|
||||||
|
|
||||||
explaining_cfg.threshold_config.threshold_type = None
|
|
||||||
|
|
||||||
explaining_cfg.threshold_config.value = 0.5
|
|
||||||
|
|
||||||
# Set print destination: stdout / file / both
|
|
||||||
explaining_cfg.print = "both"
|
|
||||||
|
|
||||||
# Select device: 'cpu', 'cuda', 'auto'
|
|
||||||
explaining_cfg.accelerator = "auto"
|
|
||||||
|
|
||||||
# Output directory
|
|
||||||
explaining_cfg.out_dir = "results"
|
|
||||||
|
|
||||||
# Config name (in out_dir)
|
|
||||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
|
||||||
|
|
||||||
# Random seed
|
|
||||||
explaining_cfg.seed = 0
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Globally shared variables:
|
|
||||||
# These variables will be set dynamically based on the input dataset
|
|
||||||
# Do not directly set them here or in .yaml files
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
explaining_cfg.share = CN()
|
|
||||||
|
|
||||||
# Size of input dimension
|
|
||||||
explaining_cfg.share.dim_in = 1
|
|
||||||
|
|
||||||
# Size of out dimension, i.e., number of labels to be predicted
|
|
||||||
explaining_cfg.share.dim_out = 1
|
|
||||||
|
|
||||||
# Number of dataset splits: train/val/test
|
|
||||||
explaining_cfg.share.num_splits = 1
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Dataset options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
explaining_cfg.dataset = CN()
|
|
||||||
|
|
||||||
# Name of the dataset
|
|
||||||
explaining_cfg.dataset.name = "Cora"
|
|
||||||
|
|
||||||
# if PyG: look for it in Pytorch Geometric dataset
|
|
||||||
# if NetworkX/nx: load data in NetworkX format
|
|
||||||
|
|
||||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
|
||||||
# cache dir
|
|
||||||
explaining_cfg.dataset.dir = "./datasets"
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Memory options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
explaining_cfg.mem = CN()
|
|
||||||
|
|
||||||
# Perform ReLU inplace
|
|
||||||
explaining_cfg.mem.inplace = False
|
|
||||||
|
|
||||||
# Set user customized explaining_cfgs
|
|
||||||
for func in register.config_dict.values():
|
|
||||||
func(explaining_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_explaining_cfg(explaining_cfg):
|
|
||||||
r"""Checks config values, do necessary post processing to the configs"""
|
|
||||||
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
def assert_scgnn_cfg(scgnn_cfg):
|
||||||
|
r"""Checks config values, do necessary post processing to the configs
|
||||||
|
TODO
|
||||||
|
|
||||||
|
"""
|
||||||
|
if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Task {} not supported, must be one of node, "
|
"Task {} not supported, must be one of node, "
|
||||||
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
|
"edge, graph, link_pred".format(scgnn_cfg.dataset.task)
|
||||||
)
|
)
|
||||||
if (
|
scgnn_cfg.run_dir = scgnn_cfg.out_dir
|
||||||
"classification" in explaining_cfg.dataset.task_type
|
|
||||||
and explaining_cfg.model.loss_fun == "mse"
|
|
||||||
):
|
|
||||||
explaining_cfg.model.loss_fun = "cross_entropy"
|
|
||||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
|
||||||
if (
|
|
||||||
explaining_cfg.dataset.task_type == "regression"
|
|
||||||
and explaining_cfg.model.loss_fun == "cross_entropy"
|
|
||||||
):
|
|
||||||
explaining_cfg.model.loss_fun = "mse"
|
|
||||||
logging.warning("model.loss_fun changed to mse for regression.")
|
|
||||||
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
|
|
||||||
explaining_cfg.dataset.transductive = False
|
|
||||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
|
||||||
if explaining_cfg.gnn.layers_post_mp < 1:
|
|
||||||
explaining_cfg.gnn.layers_post_mp = 1
|
|
||||||
logging.warning("Layers after message passing should be >=1")
|
|
||||||
if explaining_cfg.gnn.head == "default":
|
|
||||||
explaining_cfg.gnn.head = explaining_cfg.dataset.task
|
|
||||||
explaining_cfg.run_dir = explaining_cfg.out_dir
|
|
||||||
|
|
||||||
|
|
||||||
def dump_explaining_cfg(explaining_cfg):
|
def dump_scgnn_cfg(scgnn_cfg,path):
|
||||||
r"""
|
r"""
|
||||||
|
TODO
|
||||||
Dumps the config to the output directory specified in
|
Dumps the config to the output directory specified in
|
||||||
:obj:`explaining_cfg.out_dir`
|
:obj:`scgnn_cfg.out_dir`
|
||||||
Args:
|
Args:
|
||||||
explaining_cfg (CfgNode): Configuration node
|
scgnn_cfg (CfgNode): Configuration node
|
||||||
"""
|
"""
|
||||||
makedirs(explaining_cfg.out_dir)
|
makedirs(scgnn_cfg.out_dir)
|
||||||
explaining_cfg_file = os.path.join(
|
scgnn_cfg_file = os.path.join(
|
||||||
explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
|
scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest
|
||||||
)
|
)
|
||||||
with open(explaining_cfg_file, "w") as f:
|
with open(scgnn_cfg_file, "w") as f:
|
||||||
explaining_cfg.dump(stream=f)
|
scgnn_cfg.dump(stream=f)
|
||||||
|
|
||||||
|
|
||||||
def load_explaining_cfg(explaining_cfg, args):
|
def load_scgnn_cfg(scgnn_cfg, args):
|
||||||
r"""
|
r"""
|
||||||
Load configurations from file system and command line
|
Load configurations from file system and command line
|
||||||
Args:
|
Args:
|
||||||
explaining_cfg (CfgNode): Configuration node
|
scgnn_cfg (CfgNode): Configuration node
|
||||||
args (ArgumentParser): Command argument parser
|
args (ArgumentParser): Command argument parser
|
||||||
"""
|
"""
|
||||||
explaining_cfg.merge_from_file(args.explaining_cfg_file)
|
scgnn_cfg.merge_from_file(args.scgnn_cfg_file)
|
||||||
explaining_cfg.merge_from_list(args.opts)
|
scgnn_cfg.merge_from_list(args.opts)
|
||||||
assert_explaining_cfg(explaining_cfg)
|
assert_scgnn_cfg(scgnn_cfg)
|
||||||
|
|
||||||
|
|
||||||
def makedirs_rm_exist(dir):
|
def makedirs_rm_exist(dir):
|
||||||
|
@ -243,34 +108,34 @@ def set_out_dir(out_dir, fname):
|
||||||
r"""
|
r"""
|
||||||
Create the directory for full experiment run
|
Create the directory for full experiment run
|
||||||
Args:
|
Args:
|
||||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
out_dir (string): Directory for output, specified in :obj:`scgnn_cfg.out_dir`
|
||||||
fname (string): Filename for the yaml format configuration file
|
fname (string): Filename for the yaml format configuration file
|
||||||
"""
|
"""
|
||||||
fname = get_fname(fname)
|
fname = get_fname(fname)
|
||||||
explaining_cfg.out_dir = os.path.join(out_dir, fname)
|
scgnn_cfg.out_dir = os.path.join(out_dir, fname)
|
||||||
# Make output directory
|
# Make output directory
|
||||||
if explaining_cfg.train.auto_resume:
|
if scgnn_cfg.train.auto_resume:
|
||||||
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
makedirs_rm_exist(explaining_cfg.out_dir)
|
makedirs_rm_exist(scgnn_cfg.out_dir)
|
||||||
|
|
||||||
|
|
||||||
def set_run_dir(out_dir):
|
def set_run_dir(out_dir):
|
||||||
r"""
|
r"""
|
||||||
Create the directory for each random seed experiment run
|
Create the directory for each random seed experiment run
|
||||||
Args:
|
Args:
|
||||||
out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
|
out_dir (string): Directory for output, specified in :obj:`scgnn_cfg.out_dir`
|
||||||
fname (string): Filename for the yaml format configuration file
|
fname (string): Filename for the yaml format configuration file
|
||||||
"""
|
"""
|
||||||
explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
|
scgnn_cfg.run_dir = os.path.join(out_dir, str(scgnn_cfg.seed))
|
||||||
# Make output directory
|
# Make output directory
|
||||||
if explaining_cfg.train.auto_resume:
|
if scgnn_cfg.train.auto_resume:
|
||||||
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
makedirs_rm_exist(scgnn_cfg.run_dir)
|
||||||
|
|
||||||
|
|
||||||
set_explaining_cfg(explaining_cfg)
|
set_scgnn_cfg(scgnn_cfg)
|
||||||
|
|
||||||
|
|
||||||
def from_config(func):
|
def from_config(func):
|
||||||
|
@ -283,22 +148,22 @@ def from_config(func):
|
||||||
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
has_defaults = [p.default != inspect.Parameter.empty for p in params]
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, explaining_cfg: Any = None, **kwargs):
|
def wrapper(*args, scgnn_cfg: Any = None, **kwargs):
|
||||||
if explaining_cfg is not None:
|
if scgnn_cfg is not None:
|
||||||
explaining_cfg = (
|
scgnn_cfg = (
|
||||||
dict(explaining_cfg)
|
dict(scgnn_cfg)
|
||||||
if isinstance(explaining_cfg, Iterable)
|
if isinstance(scgnn_cfg, Iterable)
|
||||||
else asdict(explaining_cfg)
|
else asdict(scgnn_cfg)
|
||||||
)
|
)
|
||||||
|
|
||||||
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
|
||||||
for arg_name, has_default in iterator:
|
for arg_name, has_default in iterator:
|
||||||
if arg_name in kwargs:
|
if arg_name in kwargs:
|
||||||
continue
|
continue
|
||||||
elif arg_name in explaining_cfg:
|
elif arg_name in scgnn_cfg:
|
||||||
kwargs[arg_name] = explaining_cfg[arg_name]
|
kwargs[arg_name] = scgnn_cfg[arg_name]
|
||||||
elif not has_default:
|
elif not has_default:
|
||||||
raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
|
raise ValueError(f"'scgnn_cfg.{arg_name}' undefined")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -24,7 +24,7 @@ except ImportError:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_explaining_cfg(explaining_cfg):
|
def set_cfg(explaining_cfg):
|
||||||
r"""
|
r"""
|
||||||
This function sets the default config value.
|
This function sets the default config value.
|
||||||
1) Note that for an experiment, only part of the arguments will be used
|
1) Note that for an experiment, only part of the arguments will be used
|
||||||
|
@ -110,89 +110,23 @@ def set_explaining_cfg(explaining_cfg):
|
||||||
# Select device: 'cpu', 'cuda', 'auto'
|
# Select device: 'cpu', 'cuda', 'auto'
|
||||||
explaining_cfg.accelerator = "auto"
|
explaining_cfg.accelerator = "auto"
|
||||||
|
|
||||||
# Output directory
|
|
||||||
explaining_cfg.out_dir = "results"
|
|
||||||
|
|
||||||
# Config name (in out_dir)
|
# Config name (in out_dir)
|
||||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
||||||
|
|
||||||
# Random seed
|
|
||||||
explaining_cfg.seed = 0
|
explaining_cfg.seed = 0
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Globally shared variables:
|
|
||||||
# These variables will be set dynamically based on the input dataset
|
|
||||||
# Do not directly set them here or in .yaml files
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
|
|
||||||
explaining_cfg.share = CN()
|
|
||||||
|
|
||||||
# Size of input dimension
|
|
||||||
explaining_cfg.share.dim_in = 1
|
|
||||||
|
|
||||||
# Size of out dimension, i.e., number of labels to be predicted
|
|
||||||
explaining_cfg.share.dim_out = 1
|
|
||||||
|
|
||||||
# Number of dataset splits: train/val/test
|
|
||||||
explaining_cfg.share.num_splits = 1
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Dataset options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
explaining_cfg.dataset = CN()
|
|
||||||
|
|
||||||
# Name of the dataset
|
|
||||||
explaining_cfg.dataset.name = "Cora"
|
|
||||||
|
|
||||||
# if PyG: look for it in Pytorch Geometric dataset
|
|
||||||
# if NetworkX/nx: load data in NetworkX format
|
|
||||||
|
|
||||||
# Dir to load the dataset. If the dataset is downloaded, this is the
|
|
||||||
# cache dir
|
|
||||||
explaining_cfg.dataset.dir = "./datasets"
|
explaining_cfg.dataset.dir = "./datasets"
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# Memory options
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
explaining_cfg.mem = CN()
|
|
||||||
|
|
||||||
# Perform ReLU inplace
|
explaining_cfg.relu_and_normalize = True
|
||||||
explaining_cfg.mem.inplace = False
|
|
||||||
|
|
||||||
# Set user customized explaining_cfgs
|
|
||||||
for func in register.config_dict.values():
|
|
||||||
func(explaining_cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_explaining_cfg(explaining_cfg):
|
|
||||||
|
def assert_cfg(explaining_cfg):
|
||||||
r"""Checks config values, do necessary post processing to the configs"""
|
r"""Checks config values, do necessary post processing to the configs"""
|
||||||
if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
|
|
||||||
raise ValueError(
|
|
||||||
"Task {} not supported, must be one of node, "
|
|
||||||
"edge, graph, link_pred".format(explaining_cfg.dataset.task)
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"classification" in explaining_cfg.dataset.task_type
|
|
||||||
and explaining_cfg.model.loss_fun == "mse"
|
|
||||||
):
|
|
||||||
explaining_cfg.model.loss_fun = "cross_entropy"
|
|
||||||
logging.warning("model.loss_fun changed to cross_entropy for classification.")
|
|
||||||
if (
|
|
||||||
explaining_cfg.dataset.task_type == "regression"
|
|
||||||
and explaining_cfg.model.loss_fun == "cross_entropy"
|
|
||||||
):
|
|
||||||
explaining_cfg.model.loss_fun = "mse"
|
|
||||||
logging.warning("model.loss_fun changed to mse for regression.")
|
|
||||||
if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
|
|
||||||
explaining_cfg.dataset.transductive = False
|
|
||||||
logging.warning("dataset.transductive changed " "to False for graph task.")
|
|
||||||
if explaining_cfg.gnn.layers_post_mp < 1:
|
|
||||||
explaining_cfg.gnn.layers_post_mp = 1
|
|
||||||
logging.warning("Layers after message passing should be >=1")
|
|
||||||
if explaining_cfg.gnn.head == "default":
|
|
||||||
explaining_cfg.gnn.head = explaining_cfg.dataset.task
|
|
||||||
explaining_cfg.run_dir = explaining_cfg.out_dir
|
explaining_cfg.run_dir = explaining_cfg.out_dir
|
||||||
|
|
||||||
|
|
||||||
def dump_explaining_cfg(explaining_cfg):
|
def dump_cfg(explaining_cfg):
|
||||||
r"""
|
r"""
|
||||||
Dumps the config to the output directory specified in
|
Dumps the config to the output directory specified in
|
||||||
:obj:`explaining_cfg.out_dir`
|
:obj:`explaining_cfg.out_dir`
|
||||||
|
@ -207,7 +141,7 @@ def dump_explaining_cfg(explaining_cfg):
|
||||||
explaining_cfg.dump(stream=f)
|
explaining_cfg.dump(stream=f)
|
||||||
|
|
||||||
|
|
||||||
def load_explaining_cfg(explaining_cfg, args):
|
def load_cfg(explaining_cfg, args):
|
||||||
r"""
|
r"""
|
||||||
Load configurations from file system and command line
|
Load configurations from file system and command line
|
||||||
Args:
|
Args:
|
||||||
|
@ -270,7 +204,7 @@ def set_run_dir(out_dir):
|
||||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
makedirs_rm_exist(explaining_cfg.run_dir)
|
||||||
|
|
||||||
|
|
||||||
set_explaining_cfg(explaining_cfg)
|
set_cfg(explaining_cfg)
|
||||||
|
|
||||||
|
|
||||||
def from_config(func):
|
def from_config(func):
|
||||||
|
|
41
explaining_framework/utils/explaining/load_model.py
Normal file
41
explaining_framework/utils/explaining/load_model.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
|
|
||||||
|
MODEL_STATE = "model_state"
|
||||||
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
|
||||||
|
|
||||||
|
def load_ckpt(
    model: torch.nn.Module,
    ckpt_path: str,
) -> torch.nn.Module:
    r"""Loads the model at given checkpoint.

    Args:
        model: Module that receives the stored weights; it is updated in
            place through ``load_state_dict``.
        ckpt_path: Path of the checkpoint file. The file is read with
            ``torch.load`` and the weights are looked up under the
            ``MODEL_STATE`` key.

    Returns:
        The same ``model`` object with the checkpoint weights loaded, or
        ``None`` when ``ckpt_path`` does not exist.
    """
    # A missing checkpoint is reported with None instead of an exception,
    # so callers must check the result before using it.
    # The check goes through `os.path` (the module imports `os`, not `osp`)
    # and tests the `ckpt_path` argument.
    if not os.path.exists(ckpt_path):
        return None

    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt[MODEL_STATE])

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_best_given_exp(path_to_xp: str, wrt_metric: str = 'val') -> str:
    r"""Prints the ``stats.json`` files found for an experiment.

    The experiment directory is searched for entries laid out as
    ``<path_to_xp>/<ten digits>/<wrt_metric>/stats.json`` and every match
    is printed, one per line.

    Args:
        path_to_xp: Root directory of the experiment.
        wrt_metric: Name of the sub-directory holding the statistics file
            (defaults to ``'val'``).

    NOTE(review): the signature announces a ``str`` result, but no
    selection of a "best" entry is implemented yet, so nothing is
    returned. Confirm the intended selection rule before relying on the
    return value.
    """
    # Normalise once so the printed paths carry no doubled or trailing
    # separators.
    path_to_xp = os.path.normpath(path_to_xp)
    # '[0-9]' repeated ten times only matches directory names made of
    # exactly ten digits.
    pattern = os.path.join(path_to_xp, '[0-9]' * 10, wrt_metric, 'stats.json')
    for stats_path in glob.glob(pattern):
        print(stats_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ def load_explanation(path: str) -> Explanation:
|
||||||
|
|
||||||
|
|
||||||
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
|
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
|
||||||
|
exp = copy.copy(exp)
|
||||||
data = exp.to_dict()
|
data = exp.to_dict()
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||||
|
|
Loading…
Add table
Reference in a new issue