Raising exception when any model exists
This commit is contained in:
parent
dbf34d1679
commit
ec0e3a440a
|
@ -2,9 +2,11 @@ import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from explaining_framework.utils.io import write_yaml
|
from explaining_framework.utils.io import read_yaml, write_yaml
|
||||||
from torch_geometric.data.makedirs import makedirs
|
from torch_geometric.data.makedirs import makedirs
|
||||||
from torch_geometric.graphgym.loader import load_pyg_dataset
|
|
||||||
|
# from torch_geometric.graphgym.loader import load_pyg_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def chunkizing_list(l, n):
|
def chunkizing_list(l, n):
|
||||||
|
@ -135,12 +137,21 @@ if "__main__" == __name__:
|
||||||
# continue
|
# continue
|
||||||
# else:
|
# else:
|
||||||
# write_yaml(explaining_cfg, PATH)
|
# write_yaml(explaining_cfg, PATH)
|
||||||
configs = [path for path in glob.glob(os.path.join(explaining_folder, "*.yaml"))]
|
configs = [
|
||||||
for index, config_chunk in enumerate(
|
path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml"))
|
||||||
chunkizing_list(configs, int(len(configs) / 5))
|
]
|
||||||
):
|
for path in configs:
|
||||||
PATH_ = os.path.join(explaining_folder, f"gpu={index}")
|
data = read_yaml(path)
|
||||||
makedirs(PATH_)
|
data["model"][
|
||||||
for path in config_chunk:
|
"path"
|
||||||
filename = os.path.basename(path)
|
] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||||
shutil.copy2(path, os.path.join(PATH_, filename))
|
write_yaml(data, path)
|
||||||
|
|
||||||
|
# for index, config_chunk in enumerate(
|
||||||
|
# chunkizing_list(configs, int(len(configs) / 5))
|
||||||
|
# ):
|
||||||
|
# PATH_ = os.path.join(explaining_folder, f"gpu={index}")
|
||||||
|
# makedirs(PATH_)
|
||||||
|
# for path in config_chunk:
|
||||||
|
# filename = os.path.basename(path)
|
||||||
|
# shutil.copy2(path, os.path.join(PATH_, filename))
|
||||||
|
|
|
@ -7,11 +7,12 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from explaining_framework.utils.io import read_yaml
|
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
|
|
||||||
MODEL_STATE = "model_state"
|
MODEL_STATE = "model_state"
|
||||||
OPTIMIZER_STATE = "optimizer_state"
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
SCHEDULER_STATE = "scheduler_state"
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
|
Loading…
Reference in New Issue