This commit is contained in:
araison 2023-01-10 12:24:58 +01:00
parent 02d994b68b
commit a97ae8101c
1 changed files with 3 additions and 4 deletions

View File

@ -2,12 +2,11 @@ import glob
import os import os
import shutil import shutil
from explaining_framework.utils.io import read_yaml, write_yaml
from torch_geometric.data.makedirs import makedirs from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import create_dataset from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.utils.io import string_to_python from torch_geometric.graphgym.utils.io import string_to_python
from explaining_framework.utils.io import read_yaml, write_yaml
if "__main__" == __name__: if "__main__" == __name__:
config_folder = os.path.abspath( config_folder = os.path.abspath(
os.path.join(os.path.dirname(__name__), "../../", "configs") os.path.join(os.path.dirname(__name__), "../../", "configs")
@ -96,7 +95,7 @@ if "__main__" == __name__:
# explaining_cfg['adjust']['strategy']= 'rpns' # explaining_cfg['adjust']['strategy']= 'rpns'
# explaining_cfg['attack']['name']= 'all' # explaining_cfg['attack']['name']= 'all'
explaining_cfg["cfg_dest"] = string_to_python( explaining_cfg["cfg_dest"] = string_to_python(
f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}" f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml"
) )
# = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]" # = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]"
@ -122,7 +121,7 @@ if "__main__" == __name__:
# explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] # explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50] # explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50]
PATH = os.path.join( PATH = os.path.join(
explaining_folder + "/" + explaining_cfg["cfg_dest"] + ".yaml", explaining_folder + "/" + explaining_cfg["cfg_dest"],
) )
write_yaml(explaining_cfg, PATH) write_yaml(explaining_cfg, PATH)
# if os.path.exists(PATH): # if os.path.exists(PATH):