Fixing
This commit is contained in:
parent
02d994b68b
commit
a97ae8101c
|
@ -2,12 +2,11 @@ import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
from explaining_framework.utils.io import read_yaml, write_yaml
|
||||||
from torch_geometric.data.makedirs import makedirs
|
from torch_geometric.data.makedirs import makedirs
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
from torch_geometric.graphgym.utils.io import string_to_python
|
from torch_geometric.graphgym.utils.io import string_to_python
|
||||||
|
|
||||||
from explaining_framework.utils.io import read_yaml, write_yaml
|
|
||||||
|
|
||||||
if "__main__" == __name__:
|
if "__main__" == __name__:
|
||||||
config_folder = os.path.abspath(
|
config_folder = os.path.abspath(
|
||||||
os.path.join(os.path.dirname(__name__), "../../", "configs")
|
os.path.join(os.path.dirname(__name__), "../../", "configs")
|
||||||
|
@ -96,7 +95,7 @@ if "__main__" == __name__:
|
||||||
# explaining_cfg['adjust']['strategy']= 'rpns'
|
# explaining_cfg['adjust']['strategy']= 'rpns'
|
||||||
# explaining_cfg['attack']['name']= 'all'
|
# explaining_cfg['attack']['name']= 'all'
|
||||||
explaining_cfg["cfg_dest"] = string_to_python(
|
explaining_cfg["cfg_dest"] = string_to_python(
|
||||||
f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}"
|
f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml"
|
||||||
)
|
)
|
||||||
# = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]"
|
# = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]"
|
||||||
|
|
||||||
|
@ -122,7 +121,7 @@ if "__main__" == __name__:
|
||||||
# explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
# explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
|
||||||
# explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50]
|
# explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50]
|
||||||
PATH = os.path.join(
|
PATH = os.path.join(
|
||||||
explaining_folder + "/" + explaining_cfg["cfg_dest"] + ".yaml",
|
explaining_folder + "/" + explaining_cfg["cfg_dest"],
|
||||||
)
|
)
|
||||||
write_yaml(explaining_cfg, PATH)
|
write_yaml(explaining_cfg, PATH)
|
||||||
# if os.path.exists(PATH):
|
# if os.path.exists(PATH):
|
||||||
|
|
Loading…
Reference in New Issue