diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index 078acdd..cb07ce7 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -5,6 +5,7 @@ import glob import json import logging import os +from multiprocessing import Pool import torch from torch_geometric.graphgym.model_builder import create_model @@ -86,8 +87,10 @@ class LoadModelInfo(object): def list_xp(self): paths = [] - for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")): - file = self.load_cfg(path) + all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml")) + with Pool(processes=len(all_file_paths)) as pool: + files = pool.map(self.load_cfg, all_file_paths) + for file, path in zip(files, all_file_paths): dataset_name_ = file["dataset"]["name"] if self.dataset_name == dataset_name_: paths.append(os.path.dirname(path))