New features
This commit is contained in:
parent
5224207466
commit
9397934825
8 changed files with 72 additions and 77 deletions
|
|
@ -57,6 +57,8 @@ def set_cfg(explaining_cfg):
|
|||
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
explaining_cfg.dataset.specific_items = None
|
||||
|
||||
explaining_cfg.run_topological_stat = True
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
|
@ -80,6 +82,9 @@ def set_cfg(explaining_cfg):
|
|||
# Whether or not to provide specific explaining methods configuration or default configuration
|
||||
explaining_cfg.explainer.cfg = "default"
|
||||
|
||||
# Whether or not recomputing explanation if they already exist
|
||||
explaining_cfg.explainer.force = False
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explaining options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
|
@ -98,27 +103,26 @@ def set_cfg(explaining_cfg):
|
|||
# Do not modify it, we always assume here that model output are 'raw'
|
||||
explaining_cfg.model_config.return_type = "raw"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Thresholding options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.threshold_config = CN()
|
||||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
|
||||
explaining_cfg.threshold_config.value = 0.5
|
||||
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# Config name (in out_dir)
|
||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
||||
|
||||
explaining_cfg.seed = 0
|
||||
|
||||
explaining_cfg.dataset.dir = "./datasets"
|
||||
|
||||
explaining_cfg.relu_and_normalize = True
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
explaining_cfg.metrics = "all"
|
||||
|
||||
# Whether or not recomputing metrics if they already exist
|
||||
explaining_cfg.metrics.force = False
|
||||
|
||||
|
||||
def assert_cfg(explaining_cfg):
|
||||
|
|
@ -185,8 +189,6 @@ def set_out_dir(out_dir, fname):
|
|||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
|
|
@ -200,8 +202,6 @@ def set_run_dir(out_dir):
|
|||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
||||
|
||||
|
||||
set_cfg(explaining_cfg)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue