Fixing parallel issue

This commit is contained in:
araison 2023-01-16 01:59:22 +01:00
parent 45ea7e4255
commit 81c46d21c6
1 changed files with 10 additions and 4 deletions

View File

@ -8,16 +8,17 @@ import os
from multiprocessing import Pool from multiprocessing import Pool
import torch import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state" MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state" OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state" SCHEDULER_STATE = "scheduler_state"
PARALEL = False
def _load_ckpt( def _load_ckpt(
model: torch.nn.Module, model: torch.nn.Module,
@ -88,8 +89,13 @@ class LoadModelInfo(object):
def list_xp(self): def list_xp(self):
paths = [] paths = []
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml")) all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
with Pool(processes=len(all_file_paths)) as pool: if PARALEL:
files = pool.map(self.load_cfg, all_file_paths) with Pool(processes=len(all_file_paths)) as pool:
files = pool.map(self.load_cfg, all_file_paths)
else:
files = []
for path in all_file_paths:
file.append(self.load_cfg(path))
for file, path in zip(files, all_file_paths): for file, path in zip(files, all_file_paths):
dataset_name_ = file["dataset"]["name"] dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_: if self.dataset_name == dataset_name_: