diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py
index b26fc39..e7164ff 100644
--- a/explaining_framework/utils/explaining/load_ckpt.py
+++ b/explaining_framework/utils/explaining/load_ckpt.py
@@ -5,20 +5,19 @@ import glob
 import json
 import logging
 import os
-from multiprocessing import Pool
+from multiprocessing import Pool, cpu_count
 
 import torch
-from explaining_framework.utils.io import read_yaml
 from torch_geometric.graphgym.model_builder import create_model
 from torch_geometric.graphgym.train import GraphGymDataModule
 from torch_geometric.graphgym.utils.io import json_to_dict_list
 
+from explaining_framework.utils.io import read_yaml
+
 MODEL_STATE = "model_state"
 OPTIMIZER_STATE = "optimizer_state"
 SCHEDULER_STATE = "scheduler_state"
 
-PARALEL = False
-
 
 def _load_ckpt(
     model: torch.nn.Module,
@@ -89,13 +88,9 @@ class LoadModelInfo(object):
     def list_xp(self):
         paths = []
         all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
-        if PARALEL:
-            with Pool(processes=len(all_file_paths)) as pool:
-                files = pool.map(self.load_cfg, all_file_paths)
-        else:
-            files = []
-            for path in all_file_paths:
-                files.append(self.load_cfg(path))
+        with Pool(cpu_count()) as pool:
+            files = pool.map(self.load_cfg, all_file_paths)
+
         for file, path in zip(files, all_file_paths):
             dataset_name_ = file["dataset"]["name"]
             if self.dataset_name == dataset_name_:
diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py
index e1d0a69..bfb156c 100644
--- a/explaining_framework/utils/io.py
+++ b/explaining_framework/utils/io.py
@@ -85,6 +85,12 @@ def obj_config_to_log(obj) -> str:
         logging.info(f"{k}={v}")
 
 
+def notify_done(path):
+    """Create (or truncate) an empty ``done`` marker file inside ``path``."""
+    with open(os.path.join(path, "done"), "w") as f:
+        f.write("")
+
+
 def set_printing(logger_path):
     """
     Set up printing options
diff --git a/main.py b/main.py
index 04776db..fce377d 100644
--- a/main.py
+++ b/main.py
@@ -19,7 +19,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
 from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
-from explaining_framework.utils.io import (dump_cfg, is_exists,
+from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
                                            obj_config_to_str, read_json,
                                            set_printing, write_json)
 
@@ -27,6 +27,7 @@ if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
     for attack in outline.attacks:
+        done_attack_raw = True
         logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
         for item, index in tqdm(
             zip(outline.dataset, outline.indexes), total=len(outline.dataset)
@@ -40,6 +41,11 @@ if __name__ == "__main__":
             data_attack = outline.get_attack(
                 attack=attack, item=item, path=data_attack_path
             )
+            if data_attack is None:
+                done_attack_raw = False
+                continue
+        if done_attack_raw:
+            notify_done(attack_path)
 
     for attack in outline.attacks:
         logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
@@ -47,6 +53,8 @@ if __name__ == "__main__":
             "Running %s: %s"
             % (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
         )
+        done_attack_exp = True
+        done_exp = True
         for item, index in tqdm(
             zip(outline.dataset, outline.indexes), total=len(outline.dataset)
         ):
@@ -62,9 +70,13 @@ if __name__ == "__main__":
                 attack=attack, item=item, path=data_attack_path_
             )
             if attack_data is None:
+                done_attack_exp = False
                 continue
+
             exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
+
             if exp is None:
+                done_exp = False
                 continue
             else:
                 for adjust in outline.adjusts:
@@ -91,6 +103,7 @@ if __name__ == "__main__":
                                 item=exp_adjust, path=exp_masked_path
                             )
                             if exp_masked is None:
+                                done_exp = False
                                 continue
                             else:
                                 for metric in outline.metrics:
@@ -104,5 +117,5 @@ if __name__ == "__main__":
                                     out_metric = outline.get_metric(
                                         metric=metric, item=exp_masked, path=metric_path
                                     )
-    with open(os.path.join(outline.out_dir, "done"), "w") as f:
-        f.write("")
+        if done_attack_exp and done_exp:
+            notify_done(attack_path_)