Reformatting, adding logging and progress bar

araison 2023-01-08 23:19:31 +01:00
parent fbc685503c
commit 02bdbdc6ca
4 changed files with 227 additions and 75 deletions


@@ -1,11 +1,15 @@
 import copy
+import datetime
 import itertools
+import logging
+import os
 from typing import Any

 from eixgnn.eixgnn import EiXGNN
 from scgnn.scgnn import SCGNN
 from torch_geometric import seed_everything
 from torch_geometric.data import Batch, Data
+from torch_geometric.data.makedirs import makedirs
 from torch_geometric.explain import Explainer
 from torch_geometric.explain.config import ThresholdConfig
 from torch_geometric.explain.explanation import Explanation
@@ -35,9 +39,12 @@ from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
 from explaining_framework.utils.explanation.adjust import Adjust
 from explaining_framework.utils.explanation.io import (
     _get_explanation, _load_explanation, _save_explanation,
-    explanation_verification, get_pred)
+    explanation_verification, get_pred, is_empty_graph)
-from explaining_framework.utils.io import (is_exists, obj_config_to_str,
-                                           read_json, write_json, write_yaml)
+from explaining_framework.utils.io import (dump_cfg, is_exists,
+                                           obj_config_to_log,
+                                           obj_config_to_str, read_json,
+                                           set_printing, write_json,
+                                           write_yaml)

 all__captum = [
     "LRP",
@@ -155,6 +162,7 @@ class ExplainingOutline(object):
         self.load_adjust()
         self.load_threshold()
         self.load_graphstat()
+        self.setup_experiment()

         seed_everything(self.explaining_cfg.seed)
@@ -168,7 +176,8 @@ class ExplainingOutline(object):
         self.load_dataset()
         try:
             item = next(self.dataset)
-            item = item.to(cfg.accelerator)
+            device = self.cfg.accelerator
+            item = item.to(device)
             return item
         except StopIteration:
             return None
@@ -555,49 +564,118 @@ class ExplainingOutline(object):
         if is_exists(path):
             if self.explaining_cfg.explainer.force:
                 explanation = _get_explanation(self.explainer, item)
+                if explanation is None:
+                    logging.warning(
+                        "EXP || Generated; Path %s; FAILED",
+                        path,
+                    )
+                else:
+                    logging.debug(
+                        "EXP || Generated; Path %s; SUCCEEDED",
+                        path,
+                    )
             else:
                 explanation = _load_explanation(path)
+                logging.debug(
+                    "EXP || Loaded; Path %s; SUCCEEDED",
+                    path,
+                )
                 explanation = explanation.to(self.cfg.accelerator)
         else:
             explanation = _get_explanation(self.explainer, item)
             get_pred(self.explainer, explanation)
             _save_explanation(explanation, path)
+            logging.debug(
+                "EXP || Generated; Path %s; SUCCEEDED",
+                path,
+            )
         return explanation

     def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
         if is_exists(path):
             if self.explaining_cfg.explainer.force:
                 exp_adjust = adjust.forward(item)
+                logging.debug(
+                    "ADJUST || Generated; Path %s; SUCCEEDED",
+                    path,
+                )
             else:
                 exp_adjust = _load_explanation(path)
+                logging.debug(
+                    "ADJUST || Loaded; Path %s; SUCCEEDED",
+                    path,
+                )
         else:
             exp_adjust = adjust.forward(item)
             get_pred(self.explainer, exp_adjust)
             _save_explanation(exp_adjust, path)
+            logging.debug(
+                "ADJUST || Generated; Path %s; SUCCEEDED",
+                path,
+            )
         return exp_adjust

     def get_threshold(self, item: Explanation, path: str):
         if is_exists(path):
             if self.explaining_cfg.explainer.force:
                 exp_threshold = self.explainer._post_process(item)
+                logging.debug(
+                    "THRESHOLD || Generated; Path %s; SUCCEEDED",
+                    path,
+                )
             else:
                 exp_threshold = _load_explanation(path)
+                logging.debug(
+                    "THRESHOLD || Loaded; Path %s; SUCCEEDED",
+                    path,
+                )
         else:
             exp_threshold = self.explainer._post_process(item)
             get_pred(self.explainer, exp_threshold)
             _save_explanation(exp_threshold, path)
+            logging.debug(
+                "THRESHOLD || Generated; Path %s; SUCCEEDED",
+                path,
+            )
+        if is_empty_graph(exp_threshold):
+            logging.warning(
+                "THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED",
+                path,
+            )
+            return None
         return exp_threshold

     def get_metric(self, metric: Metric, item: Explanation, path: str):
         if is_exists(path):
             if self.explaining_cfg.explainer.force:
                 out_metric = metric.forward(item)
+                logging.debug(
+                    "METRIC || Generated; Path %s; SUCCEEDED",
+                    path,
+                )
             else:
                 out_metric = read_json(path)
+                logging.debug(
+                    "METRIC || Loaded; Path %s; SUCCEEDED",
+                    path,
+                )
         else:
             out_metric = metric.forward(item)
             data = {f"{metric.name}": out_metric}
             write_json(data, path)
+        if out_metric is None:
+            logging.debug(
+                "METRIC || Generated; Path %s; FAILED",
+                path,
+            )
+        else:
+            logging.debug(
+                "METRIC || Generated; Path %s; SUCCEEDED",
+                path,
+            )
         return out_metric

     def get_stat(self, item: Data, path: str):
@@ -614,9 +692,66 @@ class ExplainingOutline(object):
         if is_exists(path):
             if self.explaining_cfg.explainer.force:
                 data_attack = attack.get_attacked_prediction(item)
+                logging.debug(
+                    "ATTACK || Generated; Path %s; SUCCEEDED",
+                    path,
+                )
             else:
                 data_attack = _load_explanation(path)
+                logging.debug(
+                    "ATTACK || Loaded; Path %s; SUCCEEDED",
+                    path,
+                )
         else:
             data_attack = attack.get_attacked_prediction(item)
             _save_explanation(data_attack, path)
+            logging.debug(
+                "ATTACK || Generated; Path %s; SUCCEEDED",
+                path,
+            )
         return data_attack

+    def setup_experiment(self):
+        now = datetime.datetime.now()
+        self.out_dir = os.path.join(
+            self.explaining_cfg.out_dir,
+            self.cfg.dataset.name,
+            self.model_signature,
+        )
+        makedirs(self.out_dir)
+        now_str = now.strftime("month=%m-day=%d-year=%Y-hour=%H-minute=%M-second=%S")
+        set_printing(f"{self.out_dir}/logging-{now_str}.log")
+        dump_cfg(self.cfg, os.path.join(self.out_dir, "config.yaml"))
+        write_json(self.model_info, os.path.join(self.out_dir, "info.json"))
+
+        self.explainer_path = os.path.join(
+            self.out_dir,
+            self.explaining_cfg.explainer.name,
+            obj_config_to_str(self.explaining_algorithm),
+        )
+        makedirs(self.explainer_path)
+        dump_cfg(
+            self.explainer_cfg,
+            os.path.join(self.explainer_path, "explainer_cfg.yaml"),
+        )
+        dump_cfg(
+            self.explaining_cfg,
+            os.path.join(self.explainer_path, self.explaining_cfg.cfg_dest),
+        )
+
+        logging.info("Setting up experiment")
+        logging.info("Date and time: %s", now)
+        logging.info("Saving experiment to %s", self.out_dir)
+        logging.info(self.cfg)
+        logging.info(self.explaining_cfg)
+        logging.info(self.explainer_cfg)
+        logging.info(self.model)
+        obj_config_to_log(self.model_info)
+        for metric in self.metrics + self.attacks:
+            logging.info(obj_config_to_str(metric))
+        for threshold_conf in self.thresholds_configs:
+            logging.info(obj_config_to_str(threshold_conf))
+        logging.info("Proceeding to explanations..")


@@ -16,7 +16,6 @@ def _get_explanation(explainer, item):
         target=item.y,
     )
     if not explanation_verification(explanation):
-        # WARNING + LOG
         return None
     else:
         explanation = explanation.to(cfg.accelerator)
@@ -50,6 +49,8 @@ def explanation_verification(exp: Explanation) -> bool:
         is_nan = mask.isnan().any().item()
         is_inf = mask.isinf().any().item()
         is_ok = exp.validate()
+        is_const = mask.max() == mask.min()
+
         if is_nan or is_inf or not is_ok:
             is_good = False
             return is_good
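
Note that is_const is computed in this hunk but not yet folded into the returned verdict. A hedged sketch of how a constant-mask check could be combined with the existing NaN/Inf checks; mask_is_informative is an illustrative helper, not part of this commit:

import torch


def mask_is_informative(mask: torch.Tensor) -> bool:
    """Reject masks containing NaN/Inf values as well as constant masks,
    which carry no ranking information. Sketch only; the committed check
    computes is_const but does not yet act on it."""
    is_nan = mask.isnan().any().item()
    is_inf = mask.isinf().any().item()
    is_const = bool((mask.max() == mask.min()).item())
    return not (is_nan or is_inf or is_const)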


@@ -1,8 +1,12 @@
 import json
+import logging
 import os
+import sys

 import yaml

+from explaining_framework.config.explaining_config import explaining_cfg
+

 def read_json(path: str) -> dict:
     with open(path, "r") as f:
@@ -68,3 +72,37 @@ def obj_config_to_str(obj) -> str:
     else:
         config = get_dict_config(obj.__dict__)
     return "-".join([f"{k}={v}" for k, v in config.items()])
+
+
+def obj_config_to_log(obj) -> None:
+    if isinstance(obj, dict):
+        config = get_dict_config(obj)
+    else:
+        config = get_dict_config(obj.__dict__)
+    for k, v in config.items():
+        logging.info("%s : %s", k, v)
+
+
+def set_printing(logger_path):
+    """
+    Set up printing options
+    """
+    logging.root.handlers = []
+    logging_cfg = {
+        "level": logging.INFO,
+        "format": "%(asctime)s:%(levelname)s:%(message)s",
+    }
+    h_file = logging.FileHandler(logger_path)
+    h_stdout = logging.StreamHandler(sys.stdout)
+    if explaining_cfg.print == "file":
+        logging_cfg["handlers"] = [h_file]
+    elif explaining_cfg.print == "stdout":
+        logging_cfg["handlers"] = [h_stdout]
+    elif explaining_cfg.print == "both":
+        logging_cfg["handlers"] = [h_file, h_stdout]
+    else:
+        raise ValueError("Print option not supported")
+    logging.basicConfig(**logging_cfg)
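
The new set_printing routes log records to a file, stdout, or both, depending on explaining_cfg.print. A self-contained sketch of the same handler wiring without the framework-wide config object; configure_logging and its destination parameter are illustrative names:

import logging
import sys


def configure_logging(log_path: str, destination: str = "both") -> None:
    """Route log records to a file, stdout, or both, mirroring set_printing
    without depending on the framework-wide explaining_cfg object."""
    logging.root.handlers = []  # drop handlers installed by earlier calls
    handlers = []
    if destination in ("file", "both"):
        handlers.append(logging.FileHandler(log_path))
    if destination in ("stdout", "both"):
        handlers.append(logging.StreamHandler(sys.stdout))
    if not handlers:
        raise ValueError(f"Print option not supported: {destination}")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s:%(levelname)s:%(message)s",
        handlers=handlers,
    )


# Example: configure_logging("experiment.log", destination="stdout")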

main.py

@@ -3,8 +3,8 @@
 #
 import copy
+import logging
 import os
-import time

 import torch

 from torch_geometric import seed_everything
@@ -13,6 +13,7 @@ from torch_geometric.explain import Explainer
 from torch_geometric.explain.config import ThresholdConfig
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.utils.device import auto_select_device
+from tqdm import tqdm

 from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
@@ -20,60 +21,36 @@ from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
 from explaining_framework.utils.io import (dump_cfg, is_exists,
                                            obj_config_to_str, read_json,
-                                           write_json)
+                                           set_printing, write_json)

-# inference, time, force,
 if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file)
-    out_dir = os.path.join(
-        outline.explaining_cfg.out_dir,
-        outline.cfg.dataset.name,
-        outline.model_signature,
-    )
-    makedirs(out_dir)
-    dump_cfg(outline.cfg, os.path.join(out_dir, "config.yaml"))
-    write_json(outline.model_info, os.path.join(out_dir, "info.json"))
-    explainer_path = os.path.join(
-        out_dir,
-        outline.explaining_cfg.explainer.name,
-        obj_config_to_str(outline.explaining_algorithm),
-    )
-    makedirs(explainer_path)
-    dump_cfg(
-        outline.explainer_cfg,
-        os.path.join(explainer_path, "explainer_cfg.yaml"),
-    )
-    dump_cfg(
-        outline.explaining_cfg,
-        os.path.join(explainer_path, explaining_cfg.cfg_dest),
-    )
+    pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))

     item, index = outline.get_item()
     while not (item is None or index is None):
         for attack in outline.attacks:
             attack_path = os.path.join(
-                out_dir, attack.__class__.__name__, obj_config_to_str(attack)
+                outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack)
             )
             makedirs(attack_path)
             data_attack_path = os.path.join(attack_path, f"{index}.json")
             data_attack = outline.get_attack(
                 attack=attack, item=item, path=data_attack_path
             )
         item, index = outline.get_item()

     outline.reload_dataloader()
-    makedirs(explainer_path)

     item, index = outline.get_item()
     while not (item is None or index is None):
         for attack in outline.attacks:
             attack_path_ = os.path.join(
-                explainer_path, attack.__class__.__name__, obj_config_to_str(attack)
+                outline.explainer_path,
+                attack.__class__.__name__,
+                obj_config_to_str(attack),
             )
             makedirs(attack_path_)
             data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
@@ -81,9 +58,15 @@ if __name__ == "__main__":
                 attack=attack, item=item, path=data_attack_path_
             )
             exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
+            pbar.update(1)
+            if exp is None:
+                continue
+            else:
                 for adjust in outline.adjusts:
                     adjust_path = os.path.join(
-                        attack_path_, adjust.__class__.__name__, obj_config_to_str(adjust)
+                        attack_path_,
+                        adjust.__class__.__name__,
+                        obj_config_to_str(adjust),
                     )
                     makedirs(adjust_path)
                     exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
@@ -102,6 +85,9 @@ if __name__ == "__main__":
                         exp_masked = outline.get_threshold(
                             item=exp_adjust, path=exp_masked_path
                         )
+                        if exp_masked is None:
+                            continue
+                        else:
                             for metric in outline.metrics:
                                 metric_path = os.path.join(
                                     masking_path,
@@ -113,15 +99,7 @@ if __name__ == "__main__":
                                 out_metric = outline.get_metric(
                                     metric=metric, item=exp_masked, path=metric_path
                                 )
-                                print("#################################")
-                                print("Attack", attack.name)
-                                print(
-                                    "ThresholdConfig",
-                                    "-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
-                                )
-                                print("Metric", metric.name)
-                                print("Val", out_metric)
-                                print("Index", index)
-                                print("#################################")
         item, index = outline.get_item()
+
+    with open(os.path.join(outline.out_dir, "done"), "w") as f:
+        f.write("")