diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index cb07ce7..5fd739e 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -8,16 +8,17 @@ import os from multiprocessing import Pool import torch +from explaining_framework.utils.io import read_yaml from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.utils.io import json_to_dict_list -from explaining_framework.utils.io import read_yaml - MODEL_STATE = "model_state" OPTIMIZER_STATE = "optimizer_state" SCHEDULER_STATE = "scheduler_state" +PARALEL = False + def _load_ckpt( model: torch.nn.Module, @@ -88,8 +89,13 @@ class LoadModelInfo(object): def list_xp(self): paths = [] all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml")) - with Pool(processes=len(all_file_paths)) as pool: - files = pool.map(self.load_cfg, all_file_paths) + if PARALEL: + with Pool(processes=len(all_file_paths)) as pool: + files = pool.map(self.load_cfg, all_file_paths) + else: + files = [] + for path in all_file_paths: + file.append(self.load_cfg(path)) for file, path in zip(files, all_file_paths): dataset_name_ = file["dataset"]["name"] if self.dataset_name == dataset_name_: