Fixing some minor bugs
This commit is contained in:
parent
1779da4757
commit
56a62df848
|
@ -5,20 +5,19 @@ import glob
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing import Pool, cpu_count
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
||||
PARALEL = False
|
||||
|
||||
|
||||
def _load_ckpt(
|
||||
model: torch.nn.Module,
|
||||
|
@ -89,13 +88,9 @@ class LoadModelInfo(object):
|
|||
def list_xp(self):
|
||||
paths = []
|
||||
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
|
||||
if PARALEL:
|
||||
with Pool(processes=len(all_file_paths)) as pool:
|
||||
with Pool(cpu_count()) as pool:
|
||||
files = pool.map(self.load_cfg, all_file_paths)
|
||||
else:
|
||||
files = []
|
||||
for path in all_file_paths:
|
||||
files.append(self.load_cfg(path))
|
||||
|
||||
for file, path in zip(files, all_file_paths):
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if self.dataset_name == dataset_name_:
|
||||
|
|
|
@ -85,6 +85,11 @@ def obj_config_to_log(obj) -> str:
|
|||
logging.info(f"{k}={v}")
|
||||
|
||||
|
||||
def notify_done(path):
|
||||
with open(os.path.join(path, "done"), "w") as f:
|
||||
f.write("")
|
||||
|
||||
|
||||
def set_printing(logger_path):
|
||||
"""
|
||||
Set up printing options
|
||||
|
|
19
main.py
19
main.py
|
@ -19,7 +19,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
|
|||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||
from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
|
||||
obj_config_to_str, read_json,
|
||||
set_printing, write_json)
|
||||
|
||||
|
@ -27,6 +27,7 @@ if __name__ == "__main__":
|
|||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
for attack in outline.attacks:
|
||||
done_attack_raw = True
|
||||
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
|
@ -40,6 +41,11 @@ if __name__ == "__main__":
|
|||
data_attack = outline.get_attack(
|
||||
attack=attack, item=item, path=data_attack_path
|
||||
)
|
||||
if data_attack is None:
|
||||
done = False
|
||||
continue
|
||||
if done_attack_raw:
|
||||
notify_done(attack_path)
|
||||
|
||||
for attack in outline.attacks:
|
||||
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||
|
@ -47,6 +53,8 @@ if __name__ == "__main__":
|
|||
"Running %s: %s"
|
||||
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
|
||||
)
|
||||
done_attack_exp = True
|
||||
done_exp = True
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
):
|
||||
|
@ -62,9 +70,13 @@ if __name__ == "__main__":
|
|||
attack=attack, item=item, path=data_attack_path_
|
||||
)
|
||||
if attack_data is None:
|
||||
done_attack_exp = False
|
||||
continue
|
||||
|
||||
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
||||
|
||||
if exp is None:
|
||||
done_exp = False
|
||||
continue
|
||||
else:
|
||||
for adjust in outline.adjusts:
|
||||
|
@ -91,6 +103,7 @@ if __name__ == "__main__":
|
|||
item=exp_adjust, path=exp_masked_path
|
||||
)
|
||||
if exp_masked is None:
|
||||
done = False
|
||||
continue
|
||||
else:
|
||||
for metric in outline.metrics:
|
||||
|
@ -104,5 +117,5 @@ if __name__ == "__main__":
|
|||
out_metric = outline.get_metric(
|
||||
metric=metric, item=exp_masked, path=metric_path
|
||||
)
|
||||
with open(os.path.join(outline.out_dir, "done"), "w") as f:
|
||||
f.write("")
|
||||
if done_exp and done_exp:
|
||||
notify_done(attack_path_)
|
||||
|
|
Loading…
Reference in New Issue