Adding multiprocessing feature for ckpt fetching

This commit is contained in:
araison 2023-01-11 16:30:13 +01:00
parent 3d0d3ec451
commit bf559657d8

View File

@ -5,6 +5,7 @@ import glob
import json
import logging
import os
from multiprocessing import Pool
import torch
from torch_geometric.graphgym.model_builder import create_model
@ -86,8 +87,10 @@ class LoadModelInfo(object):
def list_xp(self):
paths = []
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
file = self.load_cfg(path)
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
with Pool(processes=len(all_file_paths)) as pool:
files = pool.map(self.load_cfg, all_file_paths)
for file, path in zip(files, all_file_paths):
dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_:
paths.append(os.path.dirname(path))