Adding multiprocessing feature for ckpt fetching
This commit is contained in:
parent
3d0d3ec451
commit
bf559657d8
|
@ -5,6 +5,7 @@ import glob
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from multiprocessing import Pool
|
||||
|
||||
import torch
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
|
@ -86,8 +87,10 @@ class LoadModelInfo(object):
|
|||
|
||||
def list_xp(self):
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
|
||||
file = self.load_cfg(path)
|
||||
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
||||
with Pool(processes=len(all_file_paths)) as pool:
|
||||
files = pool.map(self.load_cfg, all_file_paths)
|
||||
for file, path in zip(files, all_file_paths):
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if self.dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
|
|
Loading…
Reference in New Issue