Fixing some minor bugs

This commit is contained in:
araison 2023-01-31 09:47:57 +01:00
parent 1779da4757
commit 56a62df848
3 changed files with 27 additions and 14 deletions

View File

@ -5,20 +5,19 @@ import glob
import json
import logging
import os
from multiprocessing import Pool
from multiprocessing import Pool, cpu_count
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"
PARALEL = False
def _load_ckpt(
model: torch.nn.Module,
@ -89,13 +88,9 @@ class LoadModelInfo(object):
def list_xp(self):
paths = []
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
if PARALEL:
with Pool(processes=len(all_file_paths)) as pool:
files = pool.map(self.load_cfg, all_file_paths)
else:
files = []
for path in all_file_paths:
files.append(self.load_cfg(path))
with Pool(cpu_count()) as pool:
files = pool.map(self.load_cfg, all_file_paths)
for file, path in zip(files, all_file_paths):
dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_:

View File

@ -85,6 +85,11 @@ def obj_config_to_log(obj) -> str:
logging.info(f"{k}={v}")
def notify_done(path):
with open(os.path.join(path, "done"), "w") as f:
f.write("")
def set_printing(logger_path):
"""
Set up printing options

19
main.py
View File

@ -19,7 +19,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (dump_cfg, is_exists,
from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
obj_config_to_str, read_json,
set_printing, write_json)
@ -27,6 +27,7 @@ if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
for attack in outline.attacks:
done_attack_raw = True
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
@ -40,6 +41,11 @@ if __name__ == "__main__":
data_attack = outline.get_attack(
attack=attack, item=item, path=data_attack_path
)
if data_attack is None:
done = False
continue
if done_attack_raw:
notify_done(attack_path)
for attack in outline.attacks:
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
@ -47,6 +53,8 @@ if __name__ == "__main__":
"Running %s: %s"
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
)
done_attack_exp = True
done_exp = True
for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
):
@ -62,9 +70,13 @@ if __name__ == "__main__":
attack=attack, item=item, path=data_attack_path_
)
if attack_data is None:
done_attack_exp = False
continue
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
if exp is None:
done_exp = False
continue
else:
for adjust in outline.adjusts:
@ -91,6 +103,7 @@ if __name__ == "__main__":
item=exp_adjust, path=exp_masked_path
)
if exp_masked is None:
done = False
continue
else:
for metric in outline.metrics:
@ -104,5 +117,5 @@ if __name__ == "__main__":
out_metric = outline.get_metric(
metric=metric, item=exp_masked, path=metric_path
)
with open(os.path.join(outline.out_dir, "done"), "w") as f:
f.write("")
if done_exp and done_exp:
notify_done(attack_path_)