Fixing some minor bugs
parent 1779da4757
commit 56a62df848
@@ -5,20 +5,19 @@ import glob
 import json
 import logging
 import os
-from multiprocessing import Pool
+from multiprocessing import Pool, cpu_count

 import torch
-from explaining_framework.utils.io import read_yaml
 from torch_geometric.graphgym.model_builder import create_model
 from torch_geometric.graphgym.train import GraphGymDataModule
 from torch_geometric.graphgym.utils.io import json_to_dict_list

+from explaining_framework.utils.io import read_yaml
+
 MODEL_STATE = "model_state"
 OPTIMIZER_STATE = "optimizer_state"
 SCHEDULER_STATE = "scheduler_state"

-PARALEL = False
-

 def _load_ckpt(
     model: torch.nn.Module,
@@ -89,13 +88,9 @@ class LoadModelInfo(object):
     def list_xp(self):
         paths = []
         all_file_paths = glob.glob(os.path.join(self.model_dir, "**", "config.yaml"))
-        if PARALEL:
-            with Pool(processes=len(all_file_paths)) as pool:
-                files = pool.map(self.load_cfg, all_file_paths)
-        else:
-            files = []
-            for path in all_file_paths:
-                files.append(self.load_cfg(path))
+        with Pool(cpu_count()) as pool:
+            files = pool.map(self.load_cfg, all_file_paths)
         for file, path in zip(files, all_file_paths):
             dataset_name_ = file["dataset"]["name"]
             if self.dataset_name == dataset_name_:
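For context, the reworked list_xp drops the PARALEL switch and always fans self.load_cfg out over every discovered config.yaml with a pool sized to the machine's CPU count. A minimal standalone sketch of the same pattern, assuming PyYAML is available and using a hypothetical read_yaml helper in place of the framework's loader:

import glob
import os
from multiprocessing import Pool, cpu_count

import yaml


def read_yaml(path):
    # Hypothetical stand-in for explaining_framework.utils.io.read_yaml.
    with open(path, "r") as f:
        return yaml.safe_load(f)


def load_all_configs(model_dir):
    # Collect every config.yaml under model_dir, then parse the files in parallel.
    paths = glob.glob(os.path.join(model_dir, "**", "config.yaml"), recursive=True)
    with Pool(cpu_count()) as pool:
        configs = pool.map(read_yaml, paths)
    return dict(zip(paths, configs))

Sizing the pool by cpu_count() rather than by len(all_file_paths) keeps the number of worker processes bounded even when many experiment directories are present.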
@@ -85,6 +85,11 @@ def obj_config_to_log(obj) -> str:
         logging.info(f"{k}={v}")


+def notify_done(path):
+    with open(os.path.join(path, "done"), "w") as f:
+        f.write("")
+
+
 def set_printing(logger_path):
     """
     Set up printing options
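The new notify_done helper simply drops an empty "done" marker file into a result directory, so a later run can tell that the directory was fully processed. A hedged sketch of how such a marker is typically consumed; the is_done counterpart below is illustrative and not part of this commit:

import os


def notify_done(path):
    # Write an empty "done" marker so a finished output directory is recognisable.
    with open(os.path.join(path, "done"), "w") as f:
        f.write("")


def is_done(path):
    # Illustrative counterpart: treat a directory as finished once the marker exists.
    return os.path.exists(os.path.join(path, "done"))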
main.py (19 changed lines)
@@ -19,7 +19,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
 from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
-from explaining_framework.utils.io import (dump_cfg, is_exists,
+from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
                                            obj_config_to_str, read_json,
                                            set_printing, write_json)

@@ -27,6 +27,7 @@ if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
     for attack in outline.attacks:
+        done_attack_raw = True
         logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
         for item, index in tqdm(
             zip(outline.dataset, outline.indexes), total=len(outline.dataset)
@@ -40,6 +41,11 @@ if __name__ == "__main__":
             data_attack = outline.get_attack(
                 attack=attack, item=item, path=data_attack_path
             )
+            if data_attack is None:
+                done = False
+                continue
+        if done_attack_raw:
+            notify_done(attack_path)

     for attack in outline.attacks:
         logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
@@ -47,6 +53,8 @@ if __name__ == "__main__":
             "Running %s: %s"
             % (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
         )
+        done_attack_exp = True
+        done_exp = True
         for item, index in tqdm(
             zip(outline.dataset, outline.indexes), total=len(outline.dataset)
         ):
@@ -62,9 +70,13 @@ if __name__ == "__main__":
                 attack=attack, item=item, path=data_attack_path_
             )
             if attack_data is None:
+                done_attack_exp = False
                 continue

             exp = outline.get_explanation(item=attack_data, path=data_attack_path_)

             if exp is None:
+                done_exp = False
                 continue
             else:
                 for adjust in outline.adjusts:
@@ -91,6 +103,7 @@ if __name__ == "__main__":
                         item=exp_adjust, path=exp_masked_path
                     )
                     if exp_masked is None:
+                        done = False
                         continue
                     else:
                         for metric in outline.metrics:
@@ -104,5 +117,5 @@ if __name__ == "__main__":
                             out_metric = outline.get_metric(
                                 metric=metric, item=exp_masked, path=metric_path
                             )
-    with open(os.path.join(outline.out_dir, "done"), "w") as f:
-        f.write("")
+    if done_exp and done_exp:
+        notify_done(attack_path_)
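Taken together, the main.py changes replace the single end-of-run "done" file with per-pass completion flags: each pass starts with its flag set to True, any item that yields no output flips the flag, and notify_done is only called once the whole pass went through. A condensed sketch of that pattern, with process_item and result_dir as illustrative placeholders rather than names from the repository:

import os


def notify_done(path):
    # Same marker-file helper introduced in this commit.
    with open(os.path.join(path, "done"), "w") as f:
        f.write("")


def run_pass(items, result_dir, process_item):
    # Start optimistic; any item that produces no result clears the flag.
    done = True
    for item in items:
        if process_item(item) is None:
            done = False
            continue
    # Only mark the output directory as finished when every item succeeded.
    if done:
        notify_done(result_dir)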