Raising exception when any model exists

This commit is contained in:
araison 2023-01-09 20:43:14 +01:00
parent dbf34d1679
commit ec0e3a440a
2 changed files with 24 additions and 12 deletions

View File

@ -2,9 +2,11 @@ import glob
import os
import shutil
from explaining_framework.utils.io import write_yaml
from explaining_framework.utils.io import read_yaml, write_yaml
from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import load_pyg_dataset
# from torch_geometric.graphgym.loader import load_pyg_dataset
def chunkizing_list(l, n):
@ -135,12 +137,21 @@ if "__main__" == __name__:
# continue
# else:
# write_yaml(explaining_cfg, PATH)
configs = [path for path in glob.glob(os.path.join(explaining_folder, "*.yaml"))]
for index, config_chunk in enumerate(
chunkizing_list(configs, int(len(configs) / 5))
):
PATH_ = os.path.join(explaining_folder, f"gpu={index}")
makedirs(PATH_)
for path in config_chunk:
filename = os.path.basename(path)
shutil.copy2(path, os.path.join(PATH_, filename))
configs = [
path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml"))
]
for path in configs:
data = read_yaml(path)
data["model"][
"path"
] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
write_yaml(data, path)
# for index, config_chunk in enumerate(
# chunkizing_list(configs, int(len(configs) / 5))
# ):
# PATH_ = os.path.join(explaining_folder, f"gpu={index}")
# makedirs(PATH_)
# for path in config_chunk:
# filename = os.path.basename(path)
# shutil.copy2(path, os.path.join(PATH_, filename))

View File

@ -7,11 +7,12 @@ import logging
import os
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"