Raising exception when any model exists
This commit is contained in:
parent
dbf34d1679
commit
ec0e3a440a
|
@ -2,9 +2,11 @@ import glob
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from explaining_framework.utils.io import write_yaml
|
||||
from explaining_framework.utils.io import read_yaml, write_yaml
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.graphgym.loader import load_pyg_dataset
|
||||
|
||||
# from torch_geometric.graphgym.loader import load_pyg_dataset
|
||||
|
||||
|
||||
|
||||
def chunkizing_list(l, n):
|
||||
|
@ -135,12 +137,21 @@ if "__main__" == __name__:
|
|||
# continue
|
||||
# else:
|
||||
# write_yaml(explaining_cfg, PATH)
|
||||
configs = [path for path in glob.glob(os.path.join(explaining_folder, "*.yaml"))]
|
||||
for index, config_chunk in enumerate(
|
||||
chunkizing_list(configs, int(len(configs) / 5))
|
||||
):
|
||||
PATH_ = os.path.join(explaining_folder, f"gpu={index}")
|
||||
makedirs(PATH_)
|
||||
for path in config_chunk:
|
||||
filename = os.path.basename(path)
|
||||
shutil.copy2(path, os.path.join(PATH_, filename))
|
||||
configs = [
|
||||
path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml"))
|
||||
]
|
||||
for path in configs:
|
||||
data = read_yaml(path)
|
||||
data["model"][
|
||||
"path"
|
||||
] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||
write_yaml(data, path)
|
||||
|
||||
# for index, config_chunk in enumerate(
|
||||
# chunkizing_list(configs, int(len(configs) / 5))
|
||||
# ):
|
||||
# PATH_ = os.path.join(explaining_folder, f"gpu={index}")
|
||||
# makedirs(PATH_)
|
||||
# for path in config_chunk:
|
||||
# filename = os.path.basename(path)
|
||||
# shutil.copy2(path, os.path.join(PATH_, filename))
|
||||
|
|
|
@ -7,11 +7,12 @@ import logging
|
|||
import os
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
|
Loading…
Reference in New Issue