Fixing parallel issue
This commit is contained in:
parent
45ea7e4255
commit
81c46d21c6
|
@ -8,16 +8,17 @@ import os
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
from explaining_framework.utils.io import read_yaml
|
|
||||||
|
|
||||||
MODEL_STATE = "model_state"
|
MODEL_STATE = "model_state"
|
||||||
OPTIMIZER_STATE = "optimizer_state"
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
SCHEDULER_STATE = "scheduler_state"
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
|
||||||
|
PARALEL = False
|
||||||
|
|
||||||
|
|
||||||
def _load_ckpt(
|
def _load_ckpt(
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
@ -88,8 +89,13 @@ class LoadModelInfo(object):
|
||||||
def list_xp(self):
|
def list_xp(self):
|
||||||
paths = []
|
paths = []
|
||||||
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
||||||
with Pool(processes=len(all_file_paths)) as pool:
|
if PARALEL:
|
||||||
files = pool.map(self.load_cfg, all_file_paths)
|
with Pool(processes=len(all_file_paths)) as pool:
|
||||||
|
files = pool.map(self.load_cfg, all_file_paths)
|
||||||
|
else:
|
||||||
|
files = []
|
||||||
|
for path in all_file_paths:
|
||||||
|
file.append(self.load_cfg(path))
|
||||||
for file, path in zip(files, all_file_paths):
|
for file, path in zip(files, all_file_paths):
|
||||||
dataset_name_ = file["dataset"]["name"]
|
dataset_name_ = file["dataset"]["name"]
|
||||||
if self.dataset_name == dataset_name_:
|
if self.dataset_name == dataset_name_:
|
||||||
|
|
Loading…
Reference in New Issue