New features
Parent: 7f540f53d7 · Commit: 495cde4c70

@@ -36,267 +36,125 @@ def set_eixgnn_cfg(eixgnn_cfg):
    if eixgnn_cfg is None:
        return eixgnn_cfg

    # ----------------------------------------------------------------------- #
    # Basic options
    # ----------------------------------------------------------------------- #

    # Set print destination: stdout / file / both
    eixgnn_cfg.print = "both"

    eixgnn_cfg.out_dir = "./explanations"

    eixgnn_cfg.cfg_dest = "explaining_config.yaml"

    eixgnn_cfg.seed = 0

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #

    eixgnn_cfg.dataset = CN()

    eixgnn_cfg.dataset.name = "Cora"

    eixgnn_cfg.run_topological_stat = True

    # ----------------------------------------------------------------------- #
    # Model options
    # ----------------------------------------------------------------------- #

    eixgnn_cfg.model = CN()

    # Whether to load the best model for the given dataset, or a path to a checkpoint
    eixgnn_cfg.model.ckpt = "best"

    # ----------------------------------------------------------------------- #
    # Explainer options
    # ----------------------------------------------------------------------- #

    eixgnn_cfg.explainer = CN()

    # Name of the explaining method
    eixgnn_cfg.explainer.name = "EiXGNN"

    # Whether to use a method-specific explainer configuration or the default one
    eixgnn_cfg.explainer.cfg = "default"

    # ----------------------------------------------------------------------- #
    # Explaining options
    # ----------------------------------------------------------------------- #

    # Explanation type: 'model' or 'phenomenon'
    eixgnn_cfg.explanation_type = "model"

    eixgnn_cfg.model_config = CN()

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    eixgnn_cfg.model_config.mode = None

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    eixgnn_cfg.model_config.task_level = None

    # Do not modify; model outputs are always assumed to be 'raw'
    eixgnn_cfg.model_config.return_type = "raw"

    eixgnn_cfg.threshold_config = CN()

    eixgnn_cfg.threshold_config.threshold_type = None

    eixgnn_cfg.threshold_config.value = 0.5

    # Set print destination: stdout / file / both
    eixgnn_cfg.print = "both"

    # Select device: 'cpu', 'cuda', 'auto'
    eixgnn_cfg.accelerator = "auto"

    # Output directory
    eixgnn_cfg.out_dir = "results"

    # Config name (in out_dir)
    eixgnn_cfg.eixgnn_cfg_dest = "config.yaml"

    # Random seed
    eixgnn_cfg.seed = 0
    # ----------------------------------------------------------------------- #
    # Globally shared variables:
    # These variables will be set dynamically based on the input dataset
    # Do not directly set them here or in .yaml files
    # ----------------------------------------------------------------------- #

    eixgnn_cfg.share = CN()

    # Size of input dimension
    eixgnn_cfg.share.dim_in = 1

    # Size of out dimension, i.e., number of labels to be predicted
    eixgnn_cfg.share.dim_out = 1

    # Number of dataset splits: train/val/test
    eixgnn_cfg.share.num_splits = 1

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #
    eixgnn_cfg.dataset = CN()

    # Name of the dataset
    eixgnn_cfg.dataset.name = "Cora"

    # if PyG: look for it in the PyTorch Geometric datasets
    # if NetworkX/nx: load data in NetworkX format

    # Dir to load the dataset. If the dataset is downloaded, this is the
    # cache dir
    eixgnn_cfg.dataset.dir = "./datasets"
    # ----------------------------------------------------------------------- #
    # Memory options
    # ----------------------------------------------------------------------- #
    eixgnn_cfg.mem = CN()

    # Perform ReLU inplace
    eixgnn_cfg.mem.inplace = False

    # Set user-customized eixgnn_cfgs
    for func in register.config_dict.values():
        func(eixgnn_cfg)

    eixgnn_cfg.L = 50
    eixgnn_cfg.p = 0.5
    eixgnn_cfg.importance_sampling_strategy = "node"
    eixgnn_cfg.domain_similarity = "relative_edge_density"
    eixgnn_cfg.signal_similarity = "KL"
    eixgnn_cfg.shapley_value_approx = 100

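Because these defaults live in a yacs CfgNode, they can be overridden after the fact. A minimal sketch, assuming only the yacs package and the option names defined above; the override values themselves are illustrative:

from yacs.config import CfgNode as CN

eixgnn_cfg = CN()
set_eixgnn_cfg(eixgnn_cfg)                        # fill in the defaults shown above
eixgnn_cfg.merge_from_list(["L", 30, "p", 0.25])  # override two EiXGNN hyperparameters
assert 0 <= eixgnn_cfg.p <= 1                     # same range the framework checks later
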
def assert_eixgnn_cfg(eixgnn_cfg):
    r"""Checks config values, do necessary post processing to the configs"""
    if eixgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
        raise ValueError(
            "Task {} not supported, must be one of node, "
            "edge, graph, link_pred".format(eixgnn_cfg.dataset.task)
        )
    if (
        "classification" in eixgnn_cfg.dataset.task_type
        and eixgnn_cfg.model.loss_fun == "mse"
    ):
        eixgnn_cfg.model.loss_fun = "cross_entropy"
        logging.warning("model.loss_fun changed to cross_entropy for classification.")
    if (
        eixgnn_cfg.dataset.task_type == "regression"
        and eixgnn_cfg.model.loss_fun == "cross_entropy"
    ):
        eixgnn_cfg.model.loss_fun = "mse"
        logging.warning("model.loss_fun changed to mse for regression.")
    if eixgnn_cfg.dataset.task == "graph" and eixgnn_cfg.dataset.transductive:
        eixgnn_cfg.dataset.transductive = False
        logging.warning("dataset.transductive changed to False for graph task.")
    if eixgnn_cfg.gnn.layers_post_mp < 1:
        eixgnn_cfg.gnn.layers_post_mp = 1
        logging.warning("Layers after message passing should be >=1")
    if eixgnn_cfg.gnn.head == "default":
        eixgnn_cfg.gnn.head = eixgnn_cfg.dataset.task
    eixgnn_cfg.run_dir = eixgnn_cfg.out_dir

def dump_eixgnn_cfg(eixgnn_cfg):
    r"""
    Dumps the config to the output directory specified in
    :obj:`eixgnn_cfg.out_dir`

    Args:
        eixgnn_cfg (CfgNode): Configuration node
    """
    makedirs(eixgnn_cfg.out_dir)
    eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
    with open(eixgnn_cfg_file, "w") as f:
        eixgnn_cfg.dump(stream=f)


def load_eixgnn_cfg(eixgnn_cfg, args):
    r"""
    Load configurations from file system and command line

    Args:
        eixgnn_cfg (CfgNode): Configuration node
        args (ArgumentParser): Command argument parser
    """
    eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
    eixgnn_cfg.merge_from_list(args.opts)
    assert_eixgnn_cfg(eixgnn_cfg)

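For context, a hedged sketch of how load_eixgnn_cfg and dump_eixgnn_cfg could be wired to a command-line entry point. The flag names simply mirror the attributes read above (args.eixgnn_cfg_file, args.opts) and are otherwise an assumption:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eixgnn_cfg_file", type=str, required=True)  # path to a YAML config
parser.add_argument("opts", nargs=argparse.REMAINDER, default=[])  # extra KEY VALUE overrides
args = parser.parse_args()

load_eixgnn_cfg(eixgnn_cfg, args)  # YAML file, then CLI overrides, then assert_eixgnn_cfg
dump_eixgnn_cfg(eixgnn_cfg)        # write the resolved config into eixgnn_cfg.out_dir
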
def makedirs_rm_exist(dir):
    if os.path.isdir(dir):
        shutil.rmtree(dir)
    os.makedirs(dir, exist_ok=True)


def get_fname(fname):
    r"""
    Extract the filename from a file path

    Args:
        fname (string): Filename for the yaml format configuration file
    """
    fname = fname.split("/")[-1]
    if fname.endswith(".yaml"):
        fname = fname[:-5]
    elif fname.endswith(".yml"):
        fname = fname[:-4]
    return fname


def set_out_dir(out_dir, fname):
    r"""
    Create the directory for the full experiment run

    Args:
        out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
        fname (string): Filename for the yaml format configuration file
    """
    fname = get_fname(fname)
    eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
    # Make output directory
    if eixgnn_cfg.train.auto_resume:
        os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
    else:
        makedirs_rm_exist(eixgnn_cfg.out_dir)


def set_run_dir(out_dir):
    r"""
    Create the directory for each random-seed experiment run

    Args:
        out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
    """
    eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
    # Make output directory
    if eixgnn_cfg.train.auto_resume:
        os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
    else:
        makedirs_rm_exist(eixgnn_cfg.run_dir)
    if not (0 <= eixgnn_cfg.p and eixgnn_cfg.p <= 1):
        raise ValueError("p needs to be between 0 and 1")


set_eixgnn_cfg(eixgnn_cfg)

def from_config(func):
    if inspect.isclass(func):
        params = list(inspect.signature(func.__init__).parameters.values())[1:]
    else:
        params = list(inspect.signature(func).parameters.values())

    arg_names = [p.name for p in params]
    has_defaults = [p.default != inspect.Parameter.empty for p in params]

    @functools.wraps(func)
    def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
        if eixgnn_cfg is not None:
            eixgnn_cfg = (
                dict(eixgnn_cfg)
                if isinstance(eixgnn_cfg, Iterable)
                else asdict(eixgnn_cfg)
            )

            iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
            for arg_name, has_default in iterator:
                if arg_name in kwargs:
                    continue
                elif arg_name in eixgnn_cfg:
                    kwargs[arg_name] = eixgnn_cfg[arg_name]
                elif not has_default:
                    raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
        return func(*args, **kwargs)

    return wrapper

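A small usage sketch of the from_config decorator defined above. The decorated function is purely illustrative, but the keys L, p and importance_sampling_strategy match defaults set in set_eixgnn_cfg:

@from_config
def build_explainer(L, p, importance_sampling_strategy="node"):
    # Arguments not supplied explicitly are looked up by name in the passed config.
    return {"L": L, "p": p, "strategy": importance_sampling_strategy}


explainer_args = build_explainer(eixgnn_cfg=eixgnn_cfg)  # uses eixgnn_cfg.L, eixgnn_cfg.p, ...
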
# def dump_eixgnn_cfg(eixgnn_cfg,path):
# r"""
# Dumps the config to the output directory specified in
# :obj:`eixgnn_cfg.out_dir`
# Args:
# eixgnn_cfg (CfgNode): Configuration node
# """
# makedirs(eixgnn_cfg.out_dir)
# eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest)
# with open(eixgnn_cfg_file, "w") as f:
# eixgnn_cfg.dump(stream=f)
#
#
# def load_eixgnn_cfg(eixgnn_cfg, args):
# r"""
# Load configurations from file system and command line
# Args:
# eixgnn_cfg (CfgNode): Configuration node
# args (ArgumentParser): Command argument parser
# """
# eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
# eixgnn_cfg.merge_from_list(args.opts)
# assert_eixgnn_cfg(eixgnn_cfg)
#
#
# def makedirs_rm_exist(dir):
# if os.path.isdir(dir):
# shutil.rmtree(dir)
# os.makedirs(dir, exist_ok=True)
#
#
# def get_fname(fname):
# r"""
# Extract filename from file name path
# Args:
# fname (string): Filename for the yaml format configuration file
# """
# fname = fname.split("/")[-1]
# if fname.endswith(".yaml"):
# fname = fname[:-5]
# elif fname.endswith(".yml"):
# fname = fname[:-4]
# return fname
#
#
# def set_out_dir(out_dir, fname):
# r"""
# Create the directory for full experiment run
# Args:
# out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
# fname (string): Filename for the yaml format configuration file
# """
# fname = get_fname(fname)
# eixgnn_cfg.out_dir = os.path.join(out_dir, fname)
# Make output directory
# if eixgnn_cfg.train.auto_resume:
# os.makedirs(eixgnn_cfg.out_dir, exist_ok=True)
# else:
# makedirs_rm_exist(eixgnn_cfg.out_dir)
#
#
# def set_run_dir(out_dir):
# r"""
# Create the directory for each random seed experiment run
# Args:
# out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir`
# fname (string): Filename for the yaml format configuration file
# """
# eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed))
# Make output directory
# if eixgnn_cfg.train.auto_resume:
# os.makedirs(eixgnn_cfg.run_dir, exist_ok=True)
# else:
# makedirs_rm_exist(eixgnn_cfg.run_dir)
#
#
# def from_config(func):
# if inspect.isclass(func):
# params = list(inspect.signature(func.__init__).parameters.values())[1:]
# else:
# params = list(inspect.signature(func).parameters.values())
#
# arg_names = [p.name for p in params]
# has_defaults = [p.default != inspect.Parameter.empty for p in params]
#
# @functools.wraps(func)
# def wrapper(*args, eixgnn_cfg: Any = None, **kwargs):
# if eixgnn_cfg is not None:
# eixgnn_cfg = (
# dict(eixgnn_cfg)
# if isinstance(eixgnn_cfg, Iterable)
# else asdict(eixgnn_cfg)
# )
#
# iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
# for arg_name, has_default in iterator:
# if arg_name in kwargs:
# continue
# elif arg_name in eixgnn_cfg:
# kwargs[arg_name] = eixgnn_cfg[arg_name]
# elif not has_default:
# raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined")
# return func(*args, **kwargs)
#
# return wrapper

@@ -14,9 +14,9 @@ from torch_geometric.data.makedirs import makedirs
try:  # Define global config object
    from yacs.config import CfgNode as CN

    explaining_cfg = CN()
    scgnn_cfg = CN()
except ImportError:
    explaining_cfg = None
    scgnn_cfg = None
    warnings.warn(
        "Could not define global config object. Please install "
        "'yacs' for using the GraphGym experiment manager via "
@@ -24,199 +24,64 @@ except ImportError:
    )


def set_explaining_cfg(explaining_cfg):
def set_scgnn_cfg(scgnn_cfg):
    r"""
    This function sets the default config value.
    1) Note that for an experiment, only part of the arguments will be used
       The remaining unused arguments won't affect anything.
       So feel free to register any argument in graphgym.contrib.config
    2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name
    2) We support *at most* two levels of configs, e.g., scgnn_cfg.dataset.name
    :return: configuration used by the experiment.
    """
    if explaining_cfg is None:
        return explaining_cfg
    if scgnn_cfg is None:
        return scgnn_cfg

    # ----------------------------------------------------------------------- #
    # Basic options
    # ----------------------------------------------------------------------- #

    # Set print destination: stdout / file / both
    explaining_cfg.print = "both"

    explaining_cfg.out_dir = "./explanations"

    explaining_cfg.cfg_dest = "explaining_config.yaml"

    explaining_cfg.seed = 0

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #

    explaining_cfg.dataset = CN()

    explaining_cfg.dataset.name = "Cora"

    explaining_cfg.run_topological_stat = True

    # ----------------------------------------------------------------------- #
    # Model options
    # ----------------------------------------------------------------------- #

    explaining_cfg.model = CN()

    # Whether to load the best model for the given dataset, or a path to a checkpoint
    explaining_cfg.model.ckpt = "best"

    # ----------------------------------------------------------------------- #
    # Explainer options
    # ----------------------------------------------------------------------- #

    explaining_cfg.explainer = CN()

    # Name of the explaining method
    explaining_cfg.explainer.name = "EiXGNN"

    # Whether to use a method-specific explainer configuration or the default one
    explaining_cfg.explainer.cfg = "default"

    # ----------------------------------------------------------------------- #
    # Explaining options
    # ----------------------------------------------------------------------- #

    # Explanation type: 'model' or 'phenomenon'
    explaining_cfg.explanation_type = "model"

    explaining_cfg.model_config = CN()

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.mode = None

    # Do not modify; set from the dataset, assuming one dataset = one learning task
    explaining_cfg.model_config.task_level = None

    # Do not modify; model outputs are always assumed to be 'raw'
    explaining_cfg.model_config.return_type = "raw"

    explaining_cfg.threshold_config = CN()

    explaining_cfg.threshold_config.threshold_type = None

    explaining_cfg.threshold_config.value = 0.5

    # Set print destination: stdout / file / both
    explaining_cfg.print = "both"

    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

    # Output directory
    explaining_cfg.out_dir = "results"

    # Config name (in out_dir)
    explaining_cfg.explaining_cfg_dest = "config.yaml"

    # Random seed
    explaining_cfg.seed = 0
    # ----------------------------------------------------------------------- #
    # Globally shared variables:
    # These variables will be set dynamically based on the input dataset
    # Do not directly set them here or in .yaml files
    # ----------------------------------------------------------------------- #

    explaining_cfg.share = CN()

    # Size of input dimension
    explaining_cfg.share.dim_in = 1

    # Size of out dimension, i.e., number of labels to be predicted
    explaining_cfg.share.dim_out = 1

    # Number of dataset splits: train/val/test
    explaining_cfg.share.num_splits = 1

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #
    explaining_cfg.dataset = CN()

    # Name of the dataset
    explaining_cfg.dataset.name = "Cora"

    # if PyG: look for it in the PyTorch Geometric datasets
    # if NetworkX/nx: load data in NetworkX format

    # Dir to load the dataset. If the dataset is downloaded, this is the
    # cache dir
    explaining_cfg.dataset.dir = "./datasets"
    # ----------------------------------------------------------------------- #
    # Memory options
    # ----------------------------------------------------------------------- #
    explaining_cfg.mem = CN()

    # Perform ReLU inplace
    explaining_cfg.mem.inplace = False

    # Set user-customized explaining_cfgs
    for func in register.config_dict.values():
        func(explaining_cfg)
    scgnn_cfg.depth = 'all'
    scgnn_cfg.interest_map_norm = True
    scgnn_cfg.score_map_norm = True

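These SCGNN options can equally come from a YAML file; a hedged sketch, where the file name and its contents are illustrative and merge_from_file is the standard yacs call used by load_scgnn_cfg below:

# "configs/scgnn.yaml" is a hypothetical file containing, e.g.:
#   depth: 'all'
#   interest_map_norm: false
#   score_map_norm: true
scgnn_cfg.merge_from_file("configs/scgnn.yaml")  # same mechanism as load_scgnn_cfg()
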
def assert_explaining_cfg(explaining_cfg):
    r"""Checks config values, do necessary post processing to the configs"""
    if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:


def assert_scgnn_cfg(scgnn_cfg):
    r"""Checks config values, do necessary post processing to the configs

    TODO

    """
    if scgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
        raise ValueError(
            "Task {} not supported, must be one of node, "
            "edge, graph, link_pred".format(explaining_cfg.dataset.task)
            "edge, graph, link_pred".format(scgnn_cfg.dataset.task)
        )
    if (
        "classification" in explaining_cfg.dataset.task_type
        and explaining_cfg.model.loss_fun == "mse"
    ):
        explaining_cfg.model.loss_fun = "cross_entropy"
        logging.warning("model.loss_fun changed to cross_entropy for classification.")
    if (
        explaining_cfg.dataset.task_type == "regression"
        and explaining_cfg.model.loss_fun == "cross_entropy"
    ):
        explaining_cfg.model.loss_fun = "mse"
        logging.warning("model.loss_fun changed to mse for regression.")
    if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
        explaining_cfg.dataset.transductive = False
        logging.warning("dataset.transductive changed to False for graph task.")
    if explaining_cfg.gnn.layers_post_mp < 1:
        explaining_cfg.gnn.layers_post_mp = 1
        logging.warning("Layers after message passing should be >=1")
    if explaining_cfg.gnn.head == "default":
        explaining_cfg.gnn.head = explaining_cfg.dataset.task
    explaining_cfg.run_dir = explaining_cfg.out_dir
    scgnn_cfg.run_dir = scgnn_cfg.out_dir


def dump_explaining_cfg(explaining_cfg):
def dump_scgnn_cfg(scgnn_cfg, path):
    r"""
    TODO

    Dumps the config to the output directory specified in
    :obj:`explaining_cfg.out_dir`
    :obj:`scgnn_cfg.out_dir`

    Args:
        explaining_cfg (CfgNode): Configuration node
        scgnn_cfg (CfgNode): Configuration node
    """
    makedirs(explaining_cfg.out_dir)
    explaining_cfg_file = os.path.join(
        explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest
    makedirs(scgnn_cfg.out_dir)
    scgnn_cfg_file = os.path.join(
        scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest
    )
    with open(explaining_cfg_file, "w") as f:
        explaining_cfg.dump(stream=f)
    with open(scgnn_cfg_file, "w") as f:
        scgnn_cfg.dump(stream=f)


def load_explaining_cfg(explaining_cfg, args):
def load_scgnn_cfg(scgnn_cfg, args):
    r"""
    Load configurations from file system and command line

    Args:
        explaining_cfg (CfgNode): Configuration node
        scgnn_cfg (CfgNode): Configuration node
        args (ArgumentParser): Command argument parser
    """
    explaining_cfg.merge_from_file(args.explaining_cfg_file)
    explaining_cfg.merge_from_list(args.opts)
    assert_explaining_cfg(explaining_cfg)
    scgnn_cfg.merge_from_file(args.scgnn_cfg_file)
    scgnn_cfg.merge_from_list(args.opts)
    assert_scgnn_cfg(scgnn_cfg)


def makedirs_rm_exist(dir):

@@ -243,34 +108,34 @@ def set_out_dir(out_dir, fname):
    r"""
    Create the directory for full experiment run

    Args:
        out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
        out_dir (string): Directory for output, specified in :obj:`scgnn_cfg.out_dir`
        fname (string): Filename for the yaml format configuration file
    """
    fname = get_fname(fname)
    explaining_cfg.out_dir = os.path.join(out_dir, fname)
    scgnn_cfg.out_dir = os.path.join(out_dir, fname)
    # Make output directory
    if explaining_cfg.train.auto_resume:
        os.makedirs(explaining_cfg.out_dir, exist_ok=True)
    if scgnn_cfg.train.auto_resume:
        os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
    else:
        makedirs_rm_exist(explaining_cfg.out_dir)
        makedirs_rm_exist(scgnn_cfg.out_dir)


def set_run_dir(out_dir):
    r"""
    Create the directory for each random seed experiment run

    Args:
        out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir`
        out_dir (string): Directory for output, specified in :obj:`scgnn_cfg.out_dir`
    """
    explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed))
    scgnn_cfg.run_dir = os.path.join(out_dir, str(scgnn_cfg.seed))
    # Make output directory
    if explaining_cfg.train.auto_resume:
        os.makedirs(explaining_cfg.run_dir, exist_ok=True)
    if scgnn_cfg.train.auto_resume:
        os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
    else:
        makedirs_rm_exist(explaining_cfg.run_dir)
        makedirs_rm_exist(scgnn_cfg.run_dir)


set_explaining_cfg(explaining_cfg)
set_scgnn_cfg(scgnn_cfg)


def from_config(func):
@@ -283,22 +148,22 @@ def from_config(func):
    has_defaults = [p.default != inspect.Parameter.empty for p in params]

    @functools.wraps(func)
    def wrapper(*args, explaining_cfg: Any = None, **kwargs):
        if explaining_cfg is not None:
            explaining_cfg = (
                dict(explaining_cfg)
                if isinstance(explaining_cfg, Iterable)
                else asdict(explaining_cfg)
    def wrapper(*args, scgnn_cfg: Any = None, **kwargs):
        if scgnn_cfg is not None:
            scgnn_cfg = (
                dict(scgnn_cfg)
                if isinstance(scgnn_cfg, Iterable)
                else asdict(scgnn_cfg)
            )

            iterator = zip(arg_names[len(args) :], has_defaults[len(args) :])
            for arg_name, has_default in iterator:
                if arg_name in kwargs:
                    continue
                elif arg_name in explaining_cfg:
                    kwargs[arg_name] = explaining_cfg[arg_name]
                elif arg_name in scgnn_cfg:
                    kwargs[arg_name] = scgnn_cfg[arg_name]
                elif not has_default:
                    raise ValueError(f"'explaining_cfg.{arg_name}' undefined")
                    raise ValueError(f"'scgnn_cfg.{arg_name}' undefined")
        return func(*args, **kwargs)

    return wrapper

@@ -24,7 +24,7 @@ except ImportError:
    )


def set_explaining_cfg(explaining_cfg):
def set_cfg(explaining_cfg):
    r"""
    This function sets the default config value.
    1) Note that for an experiment, only part of the arguments will be used
@@ -110,89 +110,23 @@ def set_explaining_cfg(explaining_cfg):
    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

    # Output directory
    explaining_cfg.out_dir = "results"

    # Config name (in out_dir)
    explaining_cfg.explaining_cfg_dest = "config.yaml"

    # Random seed
    explaining_cfg.seed = 0
    # ----------------------------------------------------------------------- #
    # Globally shared variables:
    # These variables will be set dynamically based on the input dataset
    # Do not directly set them here or in .yaml files
    # ----------------------------------------------------------------------- #

    explaining_cfg.share = CN()

    # Size of input dimension
    explaining_cfg.share.dim_in = 1

    # Size of out dimension, i.e., number of labels to be predicted
    explaining_cfg.share.dim_out = 1

    # Number of dataset splits: train/val/test
    explaining_cfg.share.num_splits = 1

    # ----------------------------------------------------------------------- #
    # Dataset options
    # ----------------------------------------------------------------------- #
    explaining_cfg.dataset = CN()

    # Name of the dataset
    explaining_cfg.dataset.name = "Cora"

    # if PyG: look for it in the PyTorch Geometric datasets
    # if NetworkX/nx: load data in NetworkX format

    # Dir to load the dataset. If the dataset is downloaded, this is the
    # cache dir
    explaining_cfg.dataset.dir = "./datasets"
    # ----------------------------------------------------------------------- #
    # Memory options
    # ----------------------------------------------------------------------- #
    explaining_cfg.mem = CN()

    # Perform ReLU inplace
    explaining_cfg.mem.inplace = False

    # Set user-customized explaining_cfgs
    for func in register.config_dict.values():
        func(explaining_cfg)
    explaining_cfg.relu_and_normalize = True


def assert_explaining_cfg(explaining_cfg):
def assert_cfg(explaining_cfg):
    r"""Checks config values, do necessary post processing to the configs"""
    if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]:
        raise ValueError(
            "Task {} not supported, must be one of node, "
            "edge, graph, link_pred".format(explaining_cfg.dataset.task)
        )
    if (
        "classification" in explaining_cfg.dataset.task_type
        and explaining_cfg.model.loss_fun == "mse"
    ):
        explaining_cfg.model.loss_fun = "cross_entropy"
        logging.warning("model.loss_fun changed to cross_entropy for classification.")
    if (
        explaining_cfg.dataset.task_type == "regression"
        and explaining_cfg.model.loss_fun == "cross_entropy"
    ):
        explaining_cfg.model.loss_fun = "mse"
        logging.warning("model.loss_fun changed to mse for regression.")
    if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive:
        explaining_cfg.dataset.transductive = False
        logging.warning("dataset.transductive changed to False for graph task.")
    if explaining_cfg.gnn.layers_post_mp < 1:
        explaining_cfg.gnn.layers_post_mp = 1
        logging.warning("Layers after message passing should be >=1")
    if explaining_cfg.gnn.head == "default":
        explaining_cfg.gnn.head = explaining_cfg.dataset.task
    explaining_cfg.run_dir = explaining_cfg.out_dir


def dump_explaining_cfg(explaining_cfg):
def dump_cfg(explaining_cfg):
    r"""
    Dumps the config to the output directory specified in
    :obj:`explaining_cfg.out_dir`

@@ -207,7 +141,7 @@ def dump_explaining_cfg(explaining_cfg):
        explaining_cfg.dump(stream=f)


def load_explaining_cfg(explaining_cfg, args):
def load_cfg(explaining_cfg, args):
    r"""
    Load configurations from file system and command line

    Args:

@@ -270,7 +204,7 @@ def set_run_dir(out_dir):
        makedirs_rm_exist(explaining_cfg.run_dir)


set_explaining_cfg(explaining_cfg)
set_cfg(explaining_cfg)


def from_config(func):

explaining_framework/utils/explaining/load_model.py (new file, 41 lines)
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
import json
import glob

import torch
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule

MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"


def load_ckpt(
    model: torch.nn.Module,
    ckpt_path: str,
) -> torch.nn.Module:
    r"""Loads the model weights from the given checkpoint path."""

    if not os.path.exists(ckpt_path):
        return None

    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt[MODEL_STATE])

    return model

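A hedged usage sketch for load_ckpt; the model here is a stand-in (the surrounding framework would presumably build it with create_model) and the checkpoint path is illustrative:

model = torch.nn.Linear(16, 2)  # placeholder for a GraphGym-built model
restored = load_ckpt(model, "results/0/ckpt/epoch=99.ckpt")  # path is an assumption
if restored is None:
    logging.warning("No checkpoint found at the given path")
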
def load_best_given_exp(path_to_xp: str, wrt_metric: str = "val") -> str:
    path = os.path.normpath(path_to_xp)
    path.split(os.sep)
    for path in glob.glob(
        os.path.join(path_to_xp, "[0-9]" * 10, wrt_metric, "stats.json")
    ):
        print(path)

@@ -43,6 +43,7 @@ def load_explanation(path: str) -> Explanation:
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
    exp = copy.copy(exp)
    data = exp.to_dict()
    for k, v in data.items():
        if "_mask" in k and isinstance(v, torch.FloatTensor):
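The hunk is truncated here, but the visible signature already shows the intended call pattern; a hedged usage sketch, where the explanation file path is illustrative and load_explanation is the function named in the hunk header:

exp = load_explanation("explanations/Cora/explanation_0.pt")  # path is an assumption
exp = normalize_explanation_masks(exp, p="inf")  # operates on every *_mask FloatTensor in the Explanation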