Adding multiprocessing feature for ckpt fetching
This commit is contained in:
parent
3d0d3ec451
commit
bf559657d8
|
@ -5,6 +5,7 @@ import glob
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
|
@ -86,8 +87,10 @@ class LoadModelInfo(object):
|
||||||
|
|
||||||
def list_xp(self):
|
def list_xp(self):
|
||||||
paths = []
|
paths = []
|
||||||
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
|
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
||||||
file = self.load_cfg(path)
|
with Pool(processes=len(all_file_paths)) as pool:
|
||||||
|
files = pool.map(self.load_cfg, all_file_paths)
|
||||||
|
for file, path in zip(files, all_file_paths):
|
||||||
dataset_name_ = file["dataset"]["name"]
|
dataset_name_ = file["dataset"]["name"]
|
||||||
if self.dataset_name == dataset_name_:
|
if self.dataset_name == dataset_name_:
|
||||||
paths.append(os.path.dirname(path))
|
paths.append(os.path.dirname(path))
|
||||||
|
|
Loading…
Reference in New Issue