diff --git a/explaining_framework/utils/config_gen.py b/explaining_framework/utils/config_gen.py index c6a1891..2615351 100644 --- a/explaining_framework/utils/config_gen.py +++ b/explaining_framework/utils/config_gen.py @@ -2,12 +2,11 @@ import glob import os import shutil +from explaining_framework.utils.io import read_yaml, write_yaml from torch_geometric.data.makedirs import makedirs from torch_geometric.graphgym.loader import create_dataset from torch_geometric.graphgym.utils.io import string_to_python -from explaining_framework.utils.io import read_yaml, write_yaml - if "__main__" == __name__: config_folder = os.path.abspath( os.path.join(os.path.dirname(__name__), "../../", "configs") @@ -96,7 +95,7 @@ if "__main__" == __name__: # explaining_cfg['adjust']['strategy']= 'rpns' # explaining_cfg['attack']['name']= 'all' explaining_cfg["cfg_dest"] = string_to_python( - f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}" + f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml" ) # = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]" @@ -122,7 +121,7 @@ if "__main__" == __name__: # explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50] PATH = os.path.join( - explaining_folder + "/" + explaining_cfg["cfg_dest"] + ".yaml", + explaining_folder + "/" + explaining_cfg["cfg_dest"], ) write_yaml(explaining_cfg, PATH) # if os.path.exists(PATH):