Raising exception when any model exists

This commit is contained in:
araison 2023-01-09 20:43:14 +01:00
parent dbf34d1679
commit ec0e3a440a
2 changed files with 24 additions and 12 deletions

View File

@ -2,9 +2,11 @@ import glob
import os import os
import shutil import shutil
from explaining_framework.utils.io import write_yaml from explaining_framework.utils.io import read_yaml, write_yaml
from torch_geometric.data.makedirs import makedirs from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import load_pyg_dataset
# from torch_geometric.graphgym.loader import load_pyg_dataset
def chunkizing_list(l, n): def chunkizing_list(l, n):
@ -135,12 +137,21 @@ if "__main__" == __name__:
# continue # continue
# else: # else:
# write_yaml(explaining_cfg, PATH) # write_yaml(explaining_cfg, PATH)
configs = [path for path in glob.glob(os.path.join(explaining_folder, "*.yaml"))] configs = [
for index, config_chunk in enumerate( path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml"))
chunkizing_list(configs, int(len(configs) / 5)) ]
): for path in configs:
PATH_ = os.path.join(explaining_folder, f"gpu={index}") data = read_yaml(path)
makedirs(PATH_) data["model"][
for path in config_chunk: "path"
filename = os.path.basename(path) ] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
shutil.copy2(path, os.path.join(PATH_, filename)) write_yaml(data, path)
# for index, config_chunk in enumerate(
# chunkizing_list(configs, int(len(configs) / 5))
# ):
# PATH_ = os.path.join(explaining_folder, f"gpu={index}")
# makedirs(PATH_)
# for path in config_chunk:
# filename = os.path.basename(path)
# shutil.copy2(path, os.path.join(PATH_, filename))

View File

@ -7,11 +7,12 @@ import logging
import os import os
import torch import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state" MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state" OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state" SCHEDULER_STATE = "scheduler_state"