Fixing some minor bugs

This commit is contained in:
araison 2023-01-31 09:47:57 +01:00
parent 1779da4757
commit 56a62df848
3 changed files with 27 additions and 14 deletions

View File

@ -5,20 +5,19 @@ import glob
import json import json
import logging import logging
import os import os
from multiprocessing import Pool from multiprocessing import Pool, cpu_count
import torch import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state" MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state" OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state" SCHEDULER_STATE = "scheduler_state"
PARALEL = False
def _load_ckpt( def _load_ckpt(
model: torch.nn.Module, model: torch.nn.Module,
@ -89,13 +88,9 @@ class LoadModelInfo(object):
def list_xp(self): def list_xp(self):
paths = [] paths = []
all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml")) all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
if PARALEL: with Pool(cpu_count()) as pool:
with Pool(processes=len(all_file_paths)) as pool: files = pool.map(self.load_cfg, all_file_paths)
files = pool.map(self.load_cfg, all_file_paths)
else:
files = []
for path in all_file_paths:
files.append(self.load_cfg(path))
for file, path in zip(files, all_file_paths): for file, path in zip(files, all_file_paths):
dataset_name_ = file["dataset"]["name"] dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_: if self.dataset_name == dataset_name_:

View File

@ -85,6 +85,11 @@ def obj_config_to_log(obj) -> str:
logging.info(f"{k}={v}") logging.info(f"{k}={v}")
def notify_done(path):
with open(os.path.join(path, "done"), "w") as f:
f.write("")
def set_printing(logger_path): def set_printing(logger_path):
""" """
Set up printing options Set up printing options

19
main.py
View File

@ -19,7 +19,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (dump_cfg, is_exists, from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
obj_config_to_str, read_json, obj_config_to_str, read_json,
set_printing, write_json) set_printing, write_json)
@ -27,6 +27,7 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
for attack in outline.attacks: for attack in outline.attacks:
done_attack_raw = True
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name)) logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
@ -40,6 +41,11 @@ if __name__ == "__main__":
data_attack = outline.get_attack( data_attack = outline.get_attack(
attack=attack, item=item, path=data_attack_path attack=attack, item=item, path=data_attack_path
) )
if data_attack is None:
done = False
continue
if done_attack_raw:
notify_done(attack_path)
for attack in outline.attacks: for attack in outline.attacks:
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name)) logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
@ -47,6 +53,8 @@ if __name__ == "__main__":
"Running %s: %s" "Running %s: %s"
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name), % (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
) )
done_attack_exp = True
done_exp = True
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
): ):
@ -62,9 +70,13 @@ if __name__ == "__main__":
attack=attack, item=item, path=data_attack_path_ attack=attack, item=item, path=data_attack_path_
) )
if attack_data is None: if attack_data is None:
done_attack_exp = False
continue continue
exp = outline.get_explanation(item=attack_data, path=data_attack_path_) exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
if exp is None: if exp is None:
done_exp = False
continue continue
else: else:
for adjust in outline.adjusts: for adjust in outline.adjusts:
@ -91,6 +103,7 @@ if __name__ == "__main__":
item=exp_adjust, path=exp_masked_path item=exp_adjust, path=exp_masked_path
) )
if exp_masked is None: if exp_masked is None:
done = False
continue continue
else: else:
for metric in outline.metrics: for metric in outline.metrics:
@ -104,5 +117,5 @@ if __name__ == "__main__":
out_metric = outline.get_metric( out_metric = outline.get_metric(
metric=metric, item=exp_masked, path=metric_path metric=metric, item=exp_masked, path=metric_path
) )
with open(os.path.join(outline.out_dir, "done"), "w") as f: if done_exp and done_exp:
f.write("") notify_done(attack_path_)