New features

This commit is contained in:
araison 2022-12-26 15:01:18 +01:00
parent 5224207466
commit 9397934825
8 changed files with 72 additions and 77 deletions

View file

@ -57,6 +57,8 @@ def set_cfg(explaining_cfg):
explaining_cfg.dataset.name = "Cora"
explaining_cfg.dataset.specific_items = None
explaining_cfg.run_topological_stat = True
# ----------------------------------------------------------------------- #
@ -80,6 +82,9 @@ def set_cfg(explaining_cfg):
# Whether or not to provide specific explaining methods configuration or default configuration
explaining_cfg.explainer.cfg = "default"
# Whether or not recomputing explanation if they already exist
explaining_cfg.explainer.force = False
# ----------------------------------------------------------------------- #
# Explaining options
# ----------------------------------------------------------------------- #
@ -98,27 +103,26 @@ def set_cfg(explaining_cfg):
# Do not modify it, we always assume here that model output are 'raw'
explaining_cfg.model_config.return_type = "raw"
# ----------------------------------------------------------------------- #
# Thresholding options
# ----------------------------------------------------------------------- #
explaining_cfg.threshold_config = CN()
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = 0.5
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
explaining_cfg.threshold_config.relu_and_normalize = True
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"
# Config name (in out_dir)
explaining_cfg.explaining_cfg_dest = "config.yaml"
explaining_cfg.seed = 0
explaining_cfg.dataset.dir = "./datasets"
explaining_cfg.relu_and_normalize = True
# which objectives metrics to computes, either all or one in particular if implemented
explaining_cfg.metrics = "all"
# Whether or not recomputing metrics if they already exist
explaining_cfg.metrics.force = False
def assert_cfg(explaining_cfg):
@ -185,8 +189,6 @@ def set_out_dir(out_dir, fname):
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.out_dir)
def set_run_dir(out_dir):
@ -200,8 +202,6 @@ def set_run_dir(out_dir):
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.run_dir)
set_cfg(explaining_cfg)