diff --git a/explaining_framework/config/explainer_config/eixgnn_config.py b/explaining_framework/config/explainer_config/eixgnn_config.py new file mode 100644 index 0000000..d5106d8 --- /dev/null +++ b/explaining_framework/config/explainer_config/eixgnn_config.py @@ -0,0 +1,302 @@ +import functools +import inspect +import logging +import os +import shutil +import warnings +from collections.abc import Iterable +from dataclasses import asdict +from typing import Any + +import torch_geometric.graphgym.register as register +from torch_geometric.data.makedirs import makedirs + +try: # Define global config object + from yacs.config import CfgNode as CN + + eixgnn_cfg = CN() +except ImportError: + eixgnn_cfg = None + warnings.warn( + "Could not define global config object. Please install " + "'yacs' for using the GraphGym experiment manager via " + "'pip install yacs'." + ) + + +def set_eixgnn_cfg(eixgnn_cfg): + r""" + This function sets the default config value. + 1) Note that for an experiment, only part of the arguments will be used + The remaining unused arguments won't affect anything. + So feel free to register any argument in graphgym.contrib.config + 2) We support *at most* two levels of configs, e.g., eixgnn_cfg.dataset.name + :return: configuration use by the experiment. 
+ """ + if eixgnn_cfg is None: + return eixgnn_cfg + + # ----------------------------------------------------------------------- # + # Basic options + # ----------------------------------------------------------------------- # + + # Set print destination: stdout / file / both + eixgnn_cfg.print = "both" + + eixgnn_cfg.out_dir = "./explanations" + + eixgnn_cfg.cfg_dest = "explaining_config.yaml" + + eixgnn_cfg.seed = 0 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + + eixgnn_cfg.dataset = CN() + + eixgnn_cfg.dataset.name = "Cora" + + eixgnn_cfg.run_topological_stat = True + + # ----------------------------------------------------------------------- # + # Model options + # ----------------------------------------------------------------------- # + + eixgnn_cfg.model = CN() + + # Set whether to load the best model for the given dataset or a path + eixgnn_cfg.model.ckpt = "best" + + # ----------------------------------------------------------------------- # + # Explainer options + # ----------------------------------------------------------------------- # + + eixgnn_cfg.explainer = CN() + + # Name of the explaining method + eixgnn_cfg.explainer.name = "EiXGNN" + + # Whether or not to provide specific explaining methods configuration or default configuration + eixgnn_cfg.explainer.cfg = "default" + + # ----------------------------------------------------------------------- # + # Explaining options + # ----------------------------------------------------------------------- # + + # ExplanationType: 'model' or 'phenomenon' + eixgnn_cfg.explanation_type = "model" + + eixgnn_cfg.model_config = CN() + + # Do not modify it, will be handled by dataset, assuming one dataset = one learning task + eixgnn_cfg.model_config.mode = None + + # Do not modify it, will be handled by dataset, assuming one dataset = one learning task + 
eixgnn_cfg.model_config.task_level = None + + # Do not modify it, we always assume here that model output are 'raw' + eixgnn_cfg.model_config.return_type = "raw" + + eixgnn_cfg.threshold_config = CN() + + eixgnn_cfg.threshold_config.threshold_type = None + + eixgnn_cfg.threshold_config.value = 0.5 + + # Set print destination: stdout / file / both + eixgnn_cfg.print = "both" + + # Select device: 'cpu', 'cuda', 'auto' + eixgnn_cfg.accelerator = "auto" + + # Output directory + eixgnn_cfg.out_dir = "results" + + # Config name (in out_dir) + eixgnn_cfg.eixgnn_cfg_dest = "config.yaml" + + # Random seed + eixgnn_cfg.seed = 0 + # ----------------------------------------------------------------------- # + # Globally shared variables: + # These variables will be set dynamically based on the input dataset + # Do not directly set them here or in .yaml files + # ----------------------------------------------------------------------- # + + eixgnn_cfg.share = CN() + + # Size of input dimension + eixgnn_cfg.share.dim_in = 1 + + # Size of out dimension, i.e., number of labels to be predicted + eixgnn_cfg.share.dim_out = 1 + + # Number of dataset splits: train/val/test + eixgnn_cfg.share.num_splits = 1 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + eixgnn_cfg.dataset = CN() + + # Name of the dataset + eixgnn_cfg.dataset.name = "Cora" + + # if PyG: look for it in Pytorch Geometric dataset + # if NetworkX/nx: load data in NetworkX format + + # Dir to load the dataset. 
If the dataset is downloaded, this is the + # cache dir + eixgnn_cfg.dataset.dir = "./datasets" + # ----------------------------------------------------------------------- # + # Memory options + # ----------------------------------------------------------------------- # + eixgnn_cfg.mem = CN() + + # Perform ReLU inplace + eixgnn_cfg.mem.inplace = False + + # Set user customized eixgnn_cfgs + for func in register.config_dict.values(): + func(eixgnn_cfg) + + +def assert_eixgnn_cfg(eixgnn_cfg): + r"""Checks config values, do necessary post processing to the configs""" + if eixgnn_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]: + raise ValueError( + "Task {} not supported, must be one of node, " + "edge, graph, link_pred".format(eixgnn_cfg.dataset.task) + ) + if ( + "classification" in eixgnn_cfg.dataset.task_type + and eixgnn_cfg.model.loss_fun == "mse" + ): + eixgnn_cfg.model.loss_fun = "cross_entropy" + logging.warning("model.loss_fun changed to cross_entropy for classification.") + if ( + eixgnn_cfg.dataset.task_type == "regression" + and eixgnn_cfg.model.loss_fun == "cross_entropy" + ): + eixgnn_cfg.model.loss_fun = "mse" + logging.warning("model.loss_fun changed to mse for regression.") + if eixgnn_cfg.dataset.task == "graph" and eixgnn_cfg.dataset.transductive: + eixgnn_cfg.dataset.transductive = False + logging.warning("dataset.transductive changed " "to False for graph task.") + if eixgnn_cfg.gnn.layers_post_mp < 1: + eixgnn_cfg.gnn.layers_post_mp = 1 + logging.warning("Layers after message passing should be >=1") + if eixgnn_cfg.gnn.head == "default": + eixgnn_cfg.gnn.head = eixgnn_cfg.dataset.task + eixgnn_cfg.run_dir = eixgnn_cfg.out_dir + + +def dump_eixgnn_cfg(eixgnn_cfg): + r""" + Dumps the config to the output directory specified in + :obj:`eixgnn_cfg.out_dir` + Args: + eixgnn_cfg (CfgNode): Configuration node + """ + makedirs(eixgnn_cfg.out_dir) + eixgnn_cfg_file = os.path.join(eixgnn_cfg.out_dir, eixgnn_cfg.eixgnn_cfg_dest) + with 
open(eixgnn_cfg_file, "w") as f: + eixgnn_cfg.dump(stream=f) + + +def load_eixgnn_cfg(eixgnn_cfg, args): + r""" + Load configurations from file system and command line + Args: + eixgnn_cfg (CfgNode): Configuration node + args (ArgumentParser): Command argument parser + """ + eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file) + eixgnn_cfg.merge_from_list(args.opts) + assert_eixgnn_cfg(eixgnn_cfg) + + +def makedirs_rm_exist(dir): + if os.path.isdir(dir): + shutil.rmtree(dir) + os.makedirs(dir, exist_ok=True) + + +def get_fname(fname): + r""" + Extract filename from file name path + Args: + fname (string): Filename for the yaml format configuration file + """ + fname = fname.split("/")[-1] + if fname.endswith(".yaml"): + fname = fname[:-5] + elif fname.endswith(".yml"): + fname = fname[:-4] + return fname + + +def set_out_dir(out_dir, fname): + r""" + Create the directory for full experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + fname = get_fname(fname) + eixgnn_cfg.out_dir = os.path.join(out_dir, fname) + # Make output directory + if eixgnn_cfg.train.auto_resume: + os.makedirs(eixgnn_cfg.out_dir, exist_ok=True) + else: + makedirs_rm_exist(eixgnn_cfg.out_dir) + + +def set_run_dir(out_dir): + r""" + Create the directory for each random seed experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`eixgnn_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + eixgnn_cfg.run_dir = os.path.join(out_dir, str(eixgnn_cfg.seed)) + # Make output directory + if eixgnn_cfg.train.auto_resume: + os.makedirs(eixgnn_cfg.run_dir, exist_ok=True) + else: + makedirs_rm_exist(eixgnn_cfg.run_dir) + + +set_eixgnn_cfg(eixgnn_cfg) + + +def from_config(func): + if inspect.isclass(func): + params = list(inspect.signature(func.__init__).parameters.values())[1:] + else: + params = 
list(inspect.signature(func).parameters.values()) + + arg_names = [p.name for p in params] + has_defaults = [p.default != inspect.Parameter.empty for p in params] + + @functools.wraps(func) + def wrapper(*args, eixgnn_cfg: Any = None, **kwargs): + if eixgnn_cfg is not None: + eixgnn_cfg = ( + dict(eixgnn_cfg) + if isinstance(eixgnn_cfg, Iterable) + else asdict(eixgnn_cfg) + ) + + iterator = zip(arg_names[len(args) :], has_defaults[len(args) :]) + for arg_name, has_default in iterator: + if arg_name in kwargs: + continue + elif arg_name in eixgnn_cfg: + kwargs[arg_name] = eixgnn_cfg[arg_name] + elif not has_default: + raise ValueError(f"'eixgnn_cfg.{arg_name}' undefined") + return func(*args, **kwargs) + + return wrapper diff --git a/explaining_framework/config/explainer_config/scgnn_config.py b/explaining_framework/config/explainer_config/scgnn_config.py new file mode 100644 index 0000000..660b1ea --- /dev/null +++ b/explaining_framework/config/explainer_config/scgnn_config.py @@ -0,0 +1,304 @@ +import functools +import inspect +import logging +import os +import shutil +import warnings +from collections.abc import Iterable +from dataclasses import asdict +from typing import Any + +import torch_geometric.graphgym.register as register +from torch_geometric.data.makedirs import makedirs + +try: # Define global config object + from yacs.config import CfgNode as CN + + explaining_cfg = CN() +except ImportError: + explaining_cfg = None + warnings.warn( + "Could not define global config object. Please install " + "'yacs' for using the GraphGym experiment manager via " + "'pip install yacs'." + ) + + +def set_explaining_cfg(explaining_cfg): + r""" + This function sets the default config value. + 1) Note that for an experiment, only part of the arguments will be used + The remaining unused arguments won't affect anything. 
+ So feel free to register any argument in graphgym.contrib.config + 2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name + :return: configuration used by the experiment. + """ + if explaining_cfg is None: + return explaining_cfg + + # ----------------------------------------------------------------------- # + # Basic options + # ----------------------------------------------------------------------- # + + # Set print destination: stdout / file / both + explaining_cfg.print = "both" + + explaining_cfg.out_dir = "./explanations" + + explaining_cfg.cfg_dest = "explaining_config.yaml" + + explaining_cfg.seed = 0 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + + explaining_cfg.dataset = CN() + + explaining_cfg.dataset.name = "Cora" + + explaining_cfg.run_topological_stat = True + + # ----------------------------------------------------------------------- # + # Model options + # ----------------------------------------------------------------------- # + + explaining_cfg.model = CN() + + # Set whether to load the best model for the given dataset or a path + explaining_cfg.model.ckpt = "best" + + # ----------------------------------------------------------------------- # + # Explainer options + # ----------------------------------------------------------------------- # + + explaining_cfg.explainer = CN() + + # Name of the explaining method + explaining_cfg.explainer.name = "EiXGNN" + + # Whether or not to provide specific explaining methods configuration or default configuration + explaining_cfg.explainer.cfg = "default" + + # ----------------------------------------------------------------------- # + # Explaining options + # ----------------------------------------------------------------------- # + + # ExplanationType: 'model' or 'phenomenon' + explaining_cfg.explanation_type = "model" + + 
explaining_cfg.model_config = CN() + + # Do not modify it, will be handled by dataset , assuming one dataset = one learning task + explaining_cfg.model_config.mode = None + + # Do not modify it, will be handled by dataset , assuming one dataset = one learning task + explaining_cfg.model_config.task_level = None + + # Do not modify it, we always assume here that model output are 'raw' + explaining_cfg.model_config.return_type = "raw" + + explaining_cfg.threshold_config = CN() + + explaining_cfg.threshold_config.threshold_type = None + + explaining_cfg.threshold_config.value = 0.5 + + # Set print destination: stdout / file / both + explaining_cfg.print = "both" + + # Select device: 'cpu', 'cuda', 'auto' + explaining_cfg.accelerator = "auto" + + # Output directory + explaining_cfg.out_dir = "results" + + # Config name (in out_dir) + explaining_cfg.explaining_cfg_dest = "config.yaml" + + # Random seed + explaining_cfg.seed = 0 + # ----------------------------------------------------------------------- # + # Globally shared variables: + # These variables will be set dynamically based on the input dataset + # Do not directly set them here or in .yaml files + # ----------------------------------------------------------------------- # + + explaining_cfg.share = CN() + + # Size of input dimension + explaining_cfg.share.dim_in = 1 + + # Size of out dimension, i.e., number of labels to be predicted + explaining_cfg.share.dim_out = 1 + + # Number of dataset splits: train/val/test + explaining_cfg.share.num_splits = 1 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + explaining_cfg.dataset = CN() + + # Name of the dataset + explaining_cfg.dataset.name = "Cora" + + # if PyG: look for it in Pytorch Geometric dataset + # if NetworkX/nx: load data in NetworkX format + + # Dir to load the dataset. 
If the dataset is downloaded, this is the + # cache dir + explaining_cfg.dataset.dir = "./datasets" + # ----------------------------------------------------------------------- # + # Memory options + # ----------------------------------------------------------------------- # + explaining_cfg.mem = CN() + + # Perform ReLU inplace + explaining_cfg.mem.inplace = False + + # Set user customized explaining_cfgs + for func in register.config_dict.values(): + func(explaining_cfg) + + +def assert_explaining_cfg(explaining_cfg): + r"""Checks config values, do necessary post processing to the configs""" + if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]: + raise ValueError( + "Task {} not supported, must be one of node, " + "edge, graph, link_pred".format(explaining_cfg.dataset.task) + ) + if ( + "classification" in explaining_cfg.dataset.task_type + and explaining_cfg.model.loss_fun == "mse" + ): + explaining_cfg.model.loss_fun = "cross_entropy" + logging.warning("model.loss_fun changed to cross_entropy for classification.") + if ( + explaining_cfg.dataset.task_type == "regression" + and explaining_cfg.model.loss_fun == "cross_entropy" + ): + explaining_cfg.model.loss_fun = "mse" + logging.warning("model.loss_fun changed to mse for regression.") + if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive: + explaining_cfg.dataset.transductive = False + logging.warning("dataset.transductive changed " "to False for graph task.") + if explaining_cfg.gnn.layers_post_mp < 1: + explaining_cfg.gnn.layers_post_mp = 1 + logging.warning("Layers after message passing should be >=1") + if explaining_cfg.gnn.head == "default": + explaining_cfg.gnn.head = explaining_cfg.dataset.task + explaining_cfg.run_dir = explaining_cfg.out_dir + + +def dump_explaining_cfg(explaining_cfg): + r""" + Dumps the config to the output directory specified in + :obj:`explaining_cfg.out_dir` + Args: + explaining_cfg (CfgNode): Configuration node + """ + 
makedirs(explaining_cfg.out_dir) + explaining_cfg_file = os.path.join( + explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest + ) + with open(explaining_cfg_file, "w") as f: + explaining_cfg.dump(stream=f) + + +def load_explaining_cfg(explaining_cfg, args): + r""" + Load configurations from file system and command line + Args: + explaining_cfg (CfgNode): Configuration node + args (ArgumentParser): Command argument parser + """ + explaining_cfg.merge_from_file(args.explaining_cfg_file) + explaining_cfg.merge_from_list(args.opts) + assert_explaining_cfg(explaining_cfg) + + +def makedirs_rm_exist(dir): + if os.path.isdir(dir): + shutil.rmtree(dir) + os.makedirs(dir, exist_ok=True) + + +def get_fname(fname): + r""" + Extract filename from file name path + Args: + fname (string): Filename for the yaml format configuration file + """ + fname = fname.split("/")[-1] + if fname.endswith(".yaml"): + fname = fname[:-5] + elif fname.endswith(".yml"): + fname = fname[:-4] + return fname + + +def set_out_dir(out_dir, fname): + r""" + Create the directory for full experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + fname = get_fname(fname) + explaining_cfg.out_dir = os.path.join(out_dir, fname) + # Make output directory + if explaining_cfg.train.auto_resume: + os.makedirs(explaining_cfg.out_dir, exist_ok=True) + else: + makedirs_rm_exist(explaining_cfg.out_dir) + + +def set_run_dir(out_dir): + r""" + Create the directory for each random seed experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed)) + # Make output directory + if explaining_cfg.train.auto_resume: + os.makedirs(explaining_cfg.run_dir, exist_ok=True) + else: + 
makedirs_rm_exist(explaining_cfg.run_dir) + + +set_explaining_cfg(explaining_cfg) + + +def from_config(func): + if inspect.isclass(func): + params = list(inspect.signature(func.__init__).parameters.values())[1:] + else: + params = list(inspect.signature(func).parameters.values()) + + arg_names = [p.name for p in params] + has_defaults = [p.default != inspect.Parameter.empty for p in params] + + @functools.wraps(func) + def wrapper(*args, explaining_cfg: Any = None, **kwargs): + if explaining_cfg is not None: + explaining_cfg = ( + dict(explaining_cfg) + if isinstance(explaining_cfg, Iterable) + else asdict(explaining_cfg) + ) + + iterator = zip(arg_names[len(args) :], has_defaults[len(args) :]) + for arg_name, has_default in iterator: + if arg_name in kwargs: + continue + elif arg_name in explaining_cfg: + kwargs[arg_name] = explaining_cfg[arg_name] + elif not has_default: + raise ValueError(f"'explaining_cfg.{arg_name}' undefined") + return func(*args, **kwargs) + + return wrapper diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py new file mode 100644 index 0000000..660b1ea --- /dev/null +++ b/explaining_framework/config/explaining_config.py @@ -0,0 +1,304 @@ +import functools +import inspect +import logging +import os +import shutil +import warnings +from collections.abc import Iterable +from dataclasses import asdict +from typing import Any + +import torch_geometric.graphgym.register as register +from torch_geometric.data.makedirs import makedirs + +try: # Define global config object + from yacs.config import CfgNode as CN + + explaining_cfg = CN() +except ImportError: + explaining_cfg = None + warnings.warn( + "Could not define global config object. Please install " + "'yacs' for using the GraphGym experiment manager via " + "'pip install yacs'." + ) + + +def set_explaining_cfg(explaining_cfg): + r""" + This function sets the default config value. 
+ 1) Note that for an experiment, only part of the arguments will be used + The remaining unused arguments won't affect anything. + So feel free to register any argument in graphgym.contrib.config + 2) We support *at most* two levels of configs, e.g., explaining_cfg.dataset.name + :return: configuration used by the experiment. + """ + if explaining_cfg is None: + return explaining_cfg + + # ----------------------------------------------------------------------- # + # Basic options + # ----------------------------------------------------------------------- # + + # Set print destination: stdout / file / both + explaining_cfg.print = "both" + + explaining_cfg.out_dir = "./explanations" + + explaining_cfg.cfg_dest = "explaining_config.yaml" + + explaining_cfg.seed = 0 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + + explaining_cfg.dataset = CN() + + explaining_cfg.dataset.name = "Cora" + + explaining_cfg.run_topological_stat = True + + # ----------------------------------------------------------------------- # + # Model options + # ----------------------------------------------------------------------- # + + explaining_cfg.model = CN() + + # Set whether to load the best model for the given dataset or a path + explaining_cfg.model.ckpt = "best" + + # ----------------------------------------------------------------------- # + # Explainer options + # ----------------------------------------------------------------------- # + + explaining_cfg.explainer = CN() + + # Name of the explaining method + explaining_cfg.explainer.name = "EiXGNN" + + # Whether or not to provide specific explaining methods configuration or default configuration + explaining_cfg.explainer.cfg = "default" + + # ----------------------------------------------------------------------- # + # Explaining options + # 
----------------------------------------------------------------------- # + + # 'ExplanationType : 'model' or 'phenomenon' + explaining_cfg.explanation_type = "model" + + explaining_cfg.model_config = CN() + + # Do not modify it, will be handled by dataset , assuming one dataset = one learning task + explaining_cfg.model_config.mode = None + + # Do not modify it, will be handled by dataset , assuming one dataset = one learning task + explaining_cfg.model_config.task_level = None + + # Do not modify it, we always assume here that model output are 'raw' + explaining_cfg.model_config.return_type = "raw" + + explaining_cfg.threshold_config = CN() + + explaining_cfg.threshold_config.threshold_type = None + + explaining_cfg.threshold_config.value = 0.5 + + # Set print destination: stdout / file / both + explaining_cfg.print = "both" + + # Select device: 'cpu', 'cuda', 'auto' + explaining_cfg.accelerator = "auto" + + # Output directory + explaining_cfg.out_dir = "results" + + # Config name (in out_dir) + explaining_cfg.explaining_cfg_dest = "config.yaml" + + # Random seed + explaining_cfg.seed = 0 + # ----------------------------------------------------------------------- # + # Globally shared variables: + # These variables will be set dynamically based on the input dataset + # Do not directly set them here or in .yaml files + # ----------------------------------------------------------------------- # + + explaining_cfg.share = CN() + + # Size of input dimension + explaining_cfg.share.dim_in = 1 + + # Size of out dimension, i.e., number of labels to be predicted + explaining_cfg.share.dim_out = 1 + + # Number of dataset splits: train/val/test + explaining_cfg.share.num_splits = 1 + + # ----------------------------------------------------------------------- # + # Dataset options + # ----------------------------------------------------------------------- # + explaining_cfg.dataset = CN() + + # Name of the dataset + explaining_cfg.dataset.name = "Cora" + + # if PyG: look for 
it in Pytorch Geometric dataset + # if NetworkX/nx: load data in NetworkX format + + # Dir to load the dataset. If the dataset is downloaded, this is the + # cache dir + explaining_cfg.dataset.dir = "./datasets" + # ----------------------------------------------------------------------- # + # Memory options + # ----------------------------------------------------------------------- # + explaining_cfg.mem = CN() + + # Perform ReLU inplace + explaining_cfg.mem.inplace = False + + # Set user customized explaining_cfgs + for func in register.config_dict.values(): + func(explaining_cfg) + + +def assert_explaining_cfg(explaining_cfg): + r"""Checks config values, do necessary post processing to the configs""" + if explaining_cfg.dataset.task not in ["node", "edge", "graph", "link_pred"]: + raise ValueError( + "Task {} not supported, must be one of node, " + "edge, graph, link_pred".format(explaining_cfg.dataset.task) + ) + if ( + "classification" in explaining_cfg.dataset.task_type + and explaining_cfg.model.loss_fun == "mse" + ): + explaining_cfg.model.loss_fun = "cross_entropy" + logging.warning("model.loss_fun changed to cross_entropy for classification.") + if ( + explaining_cfg.dataset.task_type == "regression" + and explaining_cfg.model.loss_fun == "cross_entropy" + ): + explaining_cfg.model.loss_fun = "mse" + logging.warning("model.loss_fun changed to mse for regression.") + if explaining_cfg.dataset.task == "graph" and explaining_cfg.dataset.transductive: + explaining_cfg.dataset.transductive = False + logging.warning("dataset.transductive changed " "to False for graph task.") + if explaining_cfg.gnn.layers_post_mp < 1: + explaining_cfg.gnn.layers_post_mp = 1 + logging.warning("Layers after message passing should be >=1") + if explaining_cfg.gnn.head == "default": + explaining_cfg.gnn.head = explaining_cfg.dataset.task + explaining_cfg.run_dir = explaining_cfg.out_dir + + +def dump_explaining_cfg(explaining_cfg): + r""" + Dumps the config to the output directory 
specified in + :obj:`explaining_cfg.out_dir` + Args: + explaining_cfg (CfgNode): Configuration node + """ + makedirs(explaining_cfg.out_dir) + explaining_cfg_file = os.path.join( + explaining_cfg.out_dir, explaining_cfg.explaining_cfg_dest + ) + with open(explaining_cfg_file, "w") as f: + explaining_cfg.dump(stream=f) + + +def load_explaining_cfg(explaining_cfg, args): + r""" + Load configurations from file system and command line + Args: + explaining_cfg (CfgNode): Configuration node + args (ArgumentParser): Command argument parser + """ + explaining_cfg.merge_from_file(args.explaining_cfg_file) + explaining_cfg.merge_from_list(args.opts) + assert_explaining_cfg(explaining_cfg) + + +def makedirs_rm_exist(dir): + if os.path.isdir(dir): + shutil.rmtree(dir) + os.makedirs(dir, exist_ok=True) + + +def get_fname(fname): + r""" + Extract filename from file name path + Args: + fname (string): Filename for the yaml format configuration file + """ + fname = fname.split("/")[-1] + if fname.endswith(".yaml"): + fname = fname[:-5] + elif fname.endswith(".yml"): + fname = fname[:-4] + return fname + + +def set_out_dir(out_dir, fname): + r""" + Create the directory for full experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + fname = get_fname(fname) + explaining_cfg.out_dir = os.path.join(out_dir, fname) + # Make output directory + if explaining_cfg.train.auto_resume: + os.makedirs(explaining_cfg.out_dir, exist_ok=True) + else: + makedirs_rm_exist(explaining_cfg.out_dir) + + +def set_run_dir(out_dir): + r""" + Create the directory for each random seed experiment run + Args: + out_dir (string): Directory for output, specified in :obj:`explaining_cfg.out_dir` + fname (string): Filename for the yaml format configuration file + """ + explaining_cfg.run_dir = os.path.join(out_dir, str(explaining_cfg.seed)) + # Make output directory + if 
explaining_cfg.train.auto_resume: + os.makedirs(explaining_cfg.run_dir, exist_ok=True) + else: + makedirs_rm_exist(explaining_cfg.run_dir) + + +set_explaining_cfg(explaining_cfg) + + +def from_config(func): + if inspect.isclass(func): + params = list(inspect.signature(func.__init__).parameters.values())[1:] + else: + params = list(inspect.signature(func).parameters.values()) + + arg_names = [p.name for p in params] + has_defaults = [p.default != inspect.Parameter.empty for p in params] + + @functools.wraps(func) + def wrapper(*args, explaining_cfg: Any = None, **kwargs): + if explaining_cfg is not None: + explaining_cfg = ( + dict(explaining_cfg) + if isinstance(explaining_cfg, Iterable) + else asdict(explaining_cfg) + ) + + iterator = zip(arg_names[len(args) :], has_defaults[len(args) :]) + for arg_name, has_default in iterator: + if arg_name in kwargs: + continue + elif arg_name in explaining_cfg: + kwargs[arg_name] = explaining_cfg[arg_name] + elif not has_default: + raise ValueError(f"'explaining_cfg.{arg_name}' undefined") + return func(*args, **kwargs) + + return wrapper diff --git a/explaining_framework/utils/explanation_adjust.py b/explaining_framework/utils/explanation_adjust.py deleted file mode 100644 index 0e641ee..0000000 --- a/explaining_framework/utils/explanation_adjust.py +++ /dev/null @@ -1,16 +0,0 @@ -import copy - -from torch import FloatTensor -from torch.nn import ReLU - - -def relu_mask(explanation: Explanation) -> Explanation: - relu = ReLU() - explanation_store = explanation._store - raw_data = copy.copy(explanation._store) - for k, v in explanation_store.items(): - if "mask" in k: - explanation_store[k] = relu(v) - explanation.__setattr__("raw_explanation", raw_data) - explanation.__setattr__("raw_explanation_transform", "relu") - return explanation diff --git a/explaining_framework/utils/explanation_threshold.py b/explaining_framework/utils/explanation_threshold.py deleted file mode 100644 index d5f8c39..0000000 --- 
a/explaining_framework/utils/explanation_threshold.py +++ /dev/null @@ -1,7 +0,0 @@ -import copy -from typing import Dict, List, Optional, Union - -import torch -from torch import Tensor -from torch_geometric.explain.config import ThresholdConfig, ThresholdType -from torch_geometric.explain.explanation import Explanation