From ec0e3a440aefcafa3229ceca29c0c8fc91c2ff70 Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 9 Jan 2023 20:43:14 +0100 Subject: [PATCH] Raising exception when any model exists --- explaining_framework/utils/config_gen.py | 33 ++++++++++++------- .../utils/explaining/load_ckpt.py | 3 +- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/explaining_framework/utils/config_gen.py b/explaining_framework/utils/config_gen.py index 1180383..ddf4fd0 100644 --- a/explaining_framework/utils/config_gen.py +++ b/explaining_framework/utils/config_gen.py @@ -2,9 +2,11 @@ import glob import os import shutil -from explaining_framework.utils.io import write_yaml +from explaining_framework.utils.io import read_yaml, write_yaml from torch_geometric.data.makedirs import makedirs -from torch_geometric.graphgym.loader import load_pyg_dataset + +# from torch_geometric.graphgym.loader import load_pyg_dataset + def chunkizing_list(l, n): @@ -135,12 +137,21 @@ if "__main__" == __name__: # continue # else: # write_yaml(explaining_cfg, PATH) - configs = [path for path in glob.glob(os.path.join(explaining_folder, "*.yaml"))] - for index, config_chunk in enumerate( - chunkizing_list(configs, int(len(configs) / 5)) - ): - PATH_ = os.path.join(explaining_folder, f"gpu={index}") - makedirs(PATH_) - for path in config_chunk: - filename = os.path.basename(path) - shutil.copy2(path, os.path.join(PATH_, filename)) + configs = [ + path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml")) + ] + for path in configs: + data = read_yaml(path) + data["model"][ + "path" + ] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" + write_yaml(data, path) + + # for index, config_chunk in enumerate( + # chunkizing_list(configs, int(len(configs) / 5)) + # ): + # PATH_ = os.path.join(explaining_folder, f"gpu={index}") + # makedirs(PATH_) + # for path in config_chunk: + # filename = os.path.basename(path) + # shutil.copy2(path, os.path.join(PATH_, filename)) diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index 3975819..078acdd 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -7,11 +7,12 @@ import logging import os import torch -from explaining_framework.utils.io import read_yaml from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.utils.io import json_to_dict_list +from explaining_framework.utils.io import read_yaml + MODEL_STATE = "model_state" OPTIMIZER_STATE = "optimizer_state" SCHEDULER_STATE = "scheduler_state"