Fixing parallel issue
This commit is contained in:
parent
45ea7e4255
commit
81c46d21c6
|
@ -8,16 +8,17 @@ import os
|
|||
from multiprocessing import Pool
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
||||
PARALEL = False
|
||||
|
||||
|
||||
def _load_ckpt(
|
||||
model: torch.nn.Module,
|
||||
|
@ -88,8 +89,13 @@ class LoadModelInfo(object):
|
|||
def list_xp(self):
|
||||
paths = []
|
||||
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
||||
with Pool(processes=len(all_file_paths)) as pool:
|
||||
files = pool.map(self.load_cfg, all_file_paths)
|
||||
if PARALEL:
|
||||
with Pool(processes=len(all_file_paths)) as pool:
|
||||
files = pool.map(self.load_cfg, all_file_paths)
|
||||
else:
|
||||
files = []
|
||||
for path in all_file_paths:
|
||||
file.append(self.load_cfg(path))
|
||||
for file, path in zip(files, all_file_paths):
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if self.dataset_name == dataset_name_:
|
||||
|
|
Loading…
Reference in New Issue