From bf559657d80c51d51c5f933bae85954ad3b1b13f Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 11 Jan 2023 16:30:13 +0100 Subject: [PATCH] Adding multiprocessing feature for ckpt fetching --- explaining_framework/utils/explaining/load_ckpt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index 078acdd..cb07ce7 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -5,6 +5,7 @@ import glob import json import logging import os +from multiprocessing import Pool import torch from torch_geometric.graphgym.model_builder import create_model @@ -86,8 +87,10 @@ class LoadModelInfo(object): def list_xp(self): paths = [] - for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")): - file = self.load_cfg(path) + all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml")) + with Pool(processes=len(all_file_paths)) as pool: + files = pool.map(self.load_cfg, all_file_paths) + for file, path in zip(files, all_file_paths): dataset_name_ = file["dataset"]["name"] if self.dataset_name == dataset_name_: paths.append(os.path.dirname(path))