From 02bdbdc6ca169ab9ff4c77ab04c0c7ffd738ac86 Mon Sep 17 00:00:00 2001 From: araison Date: Sun, 8 Jan 2023 23:19:31 +0100 Subject: [PATCH] Reformating, adding logging and progress bar --- .../utils/explaining/outline.py | 143 +++++++++++++++++- explaining_framework/utils/explanation/io.py | 3 +- explaining_framework/utils/io.py | 38 +++++ main.py | 118 ++++++--------- 4 files changed, 227 insertions(+), 75 deletions(-) diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 5f47eea..1594dd2 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -1,11 +1,15 @@ import copy +import datetime import itertools +import logging +import os from typing import Any from eixgnn.eixgnn import EiXGNN from scgnn.scgnn import SCGNN from torch_geometric import seed_everything from torch_geometric.data import Batch, Data +from torch_geometric.data.makedirs import makedirs from torch_geometric.explain import Explainer from torch_geometric.explain.config import ThresholdConfig from torch_geometric.explain.explanation import Explanation @@ -35,9 +39,12 @@ from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo, from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.explanation.io import ( _get_explanation, _load_explanation, _save_explanation, - explanation_verification, get_pred) -from explaining_framework.utils.io import (is_exists, obj_config_to_str, - read_json, write_json, write_yaml) + explanation_verification, get_pred, is_empty_graph) +from explaining_framework.utils.io import (dump_cfg, is_exists, + obj_config_to_log, + obj_config_to_str, read_json, + set_printing, write_json, + write_yaml) all__captum = [ "LRP", @@ -155,6 +162,7 @@ class ExplainingOutline(object): self.load_adjust() self.load_threshold() self.load_graphstat() + self.setup_experiment() seed_everything(self.explaining_cfg.seed) @@ -168,7 +176,8 @@ class ExplainingOutline(object): self.load_dataset() try: item = next(self.dataset) - item = item.to(cfg.accelerator) + device = self.cfg.accelerator + item = item.to(device) return item except StopIteration: return None @@ -555,49 +564,118 @@ class ExplainingOutline(object): if is_exists(path): if self.explaining_cfg.explainer.force: explanation = _get_explanation(self.explainer, item) + if explanation is None: + logging.warning( + " EXP || Generated; Path %s; FAILED", + (path), + ) + else: + logging.debug( + "EXP || Generated; Path %s; SUCCEEDED", + (path), + ) else: explanation = _load_explanation(path) + logging.debug( + "EXP || Loaded; Path %s; SUCCEEDED", + (path), + ) explanation = explanation.to(self.cfg.accelerator) else: explanation = _get_explanation(self.explainer, item) get_pred(self.explainer, explanation) _save_explanation(explanation, path) + logging.debug( + "EXP || Generated; Path %s; SUCCEEDED", + (path), + ) + return explanation def get_adjust(self, adjust: Adjust, item: Explanation, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: exp_adjust = adjust.forward(item) + logging.debug( + "ADJUST || Generated; Path %s; SUCCEEDED", + (path), + ) + else: exp_adjust = _load_explanation(path) + logging.debug( + "ADJUST || Loaded; Path %s; SUCCEEDED", + (path), + ) + else: exp_adjust = adjust.forward(item) get_pred(self.explainer, exp_adjust) _save_explanation(exp_adjust, path) + logging.debug( + "ADJUST || Generated; Path %s; SUCCEEDED", + (path), + ) return exp_adjust def get_threshold(self, item: Explanation, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: exp_threshold = self.explainer._post_process(item) + logging.debug( + "THRESHOLD || Generated; Path %s; SUCCEEDED", + (path), + ) else: exp_threshold = _load_explanation(path) + logging.debug( + "THRESHOLD || Loaded; Path %s; SUCCEEDED", + (path), + ) else: exp_threshold = self.explainer._post_process(item) get_pred(self.explainer, exp_threshold) _save_explanation(exp_threshold, path) + logging.debug( + "THRESHOLD || Generated; Path %s; SUCCEEDED", + (path), + ) + if is_empty_graph(exp_threshold): + logging.warning( + "THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED", + (path), + ) + return None return exp_threshold def get_metric(self, metric: Metric, item: Explanation, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: out_metric = metric.forward(item) + logging.debug( + "METRIC || Generated; Path %s; SUCCEEDED", + (path), + ) else: out_metric = read_json(path) + logging.debug( + "METRIC || Loaded; Path %s; SUCCEEDED", + (path), + ) else: out_metric = metric.forward(item) data = {f"{metric.name}": out_metric} write_json(data, path) + if out_metric is None: + logging.debug( + "METRIC || Generated; Path %s; FAILED", + (path), + ) + else: + logging.debug( + "METRIC || Generated; Path %s; SUCCEEDED", + (path), + ) return out_metric def get_stat(self, item: Data, path: str): @@ -614,9 +692,66 @@ class ExplainingOutline(object): if is_exists(path): if self.explaining_cfg.explainer.force: data_attack = attack.get_attacked_prediction(item) + logging.debug( + "ATTACK || Generated %s; Path %s; SUCCEEDED", + (path), + ) + else: data_attack = _load_explanation(path) + logging.debug( + "ATTACK || Generated %s; Path %s; SUCCEEDED", + (path), + ) else: data_attack = attack.get_attacked_prediction(item) _save_explanation(data_attack, path) + logging.debug( + "ATTACK || Generated %s; Path %s; SUCCEEDED", + (path), + ) return data_attack + + def setup_experiment(self): + now = datetime.datetime.now() + self.out_dir = os.path.join( + self.explaining_cfg.out_dir, + self.cfg.dataset.name, + self.model_signature, + ) + makedirs(self.out_dir) + + now_str = now.strftime("month=%m-day=%d-year=%Y-hour=%H-minute=%M-second=%S") + set_printing(f"{self.out_dir}/logging-{now_str}.log") + + dump_cfg(self.cfg, os.path.join(self.out_dir, "config.yaml")) + write_json(self.model_info, os.path.join(self.out_dir, "info.json")) + + self.explainer_path = os.path.join( + self.out_dir, + self.explaining_cfg.explainer.name, + obj_config_to_str(self.explaining_algorithm), + ) + makedirs(self.explainer_path) + dump_cfg( + self.explainer_cfg, + os.path.join(self.explainer_path, "explainer_cfg.yaml"), + ) + dump_cfg( + self.explaining_cfg, + os.path.join(self.explainer_path, self.explaining_cfg.cfg_dest), + ) + + logging.info("Setting up experiment") + logging.info("Date and Time: %s", now) + logging.info("Save experiment to %s", self.out_dir) + logging.info(self.cfg) + logging.info(self.explaining_cfg) + logging.info(self.explainer_cfg) + logging.info(self.model) + logging.info(obj_config_to_log(self.model_info)) + for metric in self.metrics + self.attacks: + logging.info(obj_config_to_str(metric)) + for threshold_conf in self.thresholds_configs: + logging.info(obj_config_to_str(threshold_conf)) + logging.info("Proceeding to explanations..") diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index 807214f..dd318d4 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -16,7 +16,6 @@ def _get_explanation(explainer, item): target=item.y, ) if not explanation_verification(explanation): - # WARNING + LOG return None else: explanation = explanation.to(cfg.accelerator) @@ -50,6 +49,8 @@ def explanation_verification(exp: Explanation) -> bool: is_nan = mask.isnan().any().item() is_inf = mask.isinf().any().item() is_ok = exp.validate() + is_const = mask.max() == mask.min() + if is_nan or is_inf or not is_ok: is_good = False return is_good diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py index 146b3dc..28c1a0e 100644 --- a/explaining_framework/utils/io.py +++ b/explaining_framework/utils/io.py @@ -1,8 +1,12 @@ import json +import logging import os +import sys import yaml +from explaining_framework.config.explaining_config import explaining_cfg + def read_json(path: str) -> dict: with open(path, "r") as f: @@ -68,3 +72,37 @@ def obj_config_to_str(obj) -> str: else: config = get_dict_config(obj.__dict__) return "-".join([f"{k}={v}" for k, v in config.items()]) + + +def obj_config_to_log(obj) -> str: + if isinstance(obj, dict): + config = get_dict_config(obj) + for k, v in config.items(): + logging.info(f"{k} : {v}") + else: + config = get_dict_config(obj.__dict__) + for k, v in config.items(): + logging.info(f"{k} : {v}") + + +def set_printing(logger_path): + """ + Set up printing options + + """ + logging.root.handlers = [] + logging_cfg = { + "level": logging.INFO, + "format": "%(asctime)s:%(levelname)s:%(message)s", + } + h_file = logging.FileHandler(logger_path) + h_stdout = logging.StreamHandler(sys.stdout) + if explaining_cfg.print == "file": + logging_cfg["handlers"] = [h_file] + elif explaining_cfg.print == "stdout": + logging_cfg["handlers"] = [h_stdout] + elif explaining_cfg.print == "both": + logging_cfg["handlers"] = [h_file, h_stdout] + else: + raise ValueError("Print option not supported") + logging.basicConfig(**logging_cfg) diff --git a/main.py b/main.py index 77f6896..828e462 100644 --- a/main.py +++ b/main.py @@ -3,8 +3,8 @@ # import copy +import logging import os -import time import torch from torch_geometric import seed_everything @@ -13,6 +13,7 @@ from torch_geometric.explain import Explainer from torch_geometric.explain.config import ThresholdConfig from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.utils.device import auto_select_device +from tqdm import tqdm from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.utils.explaining.cmd_args import parse_args @@ -20,60 +21,36 @@ from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.io import (dump_cfg, is_exists, obj_config_to_str, read_json, - write_json) - -# inference, time, force, - + set_printing, write_json) if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file) - out_dir = os.path.join( - outline.explaining_cfg.out_dir, - outline.cfg.dataset.name, - outline.model_signature, - ) - makedirs(out_dir) - dump_cfg(outline.cfg, os.path.join(out_dir, "config.yaml")) - write_json(outline.model_info, os.path.join(out_dir, "info.json")) - - explainer_path = os.path.join( - out_dir, - outline.explaining_cfg.explainer.name, - obj_config_to_str(outline.explaining_algorithm), - ) - makedirs(explainer_path) - dump_cfg( - outline.explainer_cfg, - os.path.join(explainer_path, "explainer_cfg.yaml"), - ) - dump_cfg( - outline.explaining_cfg, - os.path.join(explainer_path, explaining_cfg.cfg_dest), - ) + pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) item, index = outline.get_item() while not (item is None or index is None): for attack in outline.attacks: attack_path = os.path.join( - out_dir, attack.__class__.__name__, obj_config_to_str(attack) + outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack) ) makedirs(attack_path) data_attack_path = os.path.join(attack_path, f"{index}.json") data_attack = outline.get_attack( attack=attack, item=item, path=data_attack_path ) + item, index = outline.get_item() outline.reload_dataloader() - makedirs(explainer_path) - item, index = outline.get_item() while not (item is None or index is None): for attack in outline.attacks: attack_path_ = os.path.join( - explainer_path, attack.__class__.__name__, obj_config_to_str(attack) + outline.explainer_path, + attack.__class__.__name__, + obj_config_to_str(attack), ) makedirs(attack_path_) data_attack_path_ = os.path.join(attack_path_, f"{index}.json") @@ -81,47 +58,48 @@ if __name__ == "__main__": attack=attack, item=item, path=data_attack_path_ ) exp = outline.get_explanation(item=attack_data, path=data_attack_path_) - for adjust in outline.adjusts: - adjust_path = os.path.join( - attack_path_, adjust.__class__.__name__, obj_config_to_str(adjust) - ) - makedirs(adjust_path) - exp_adjust_path = os.path.join(adjust_path, f"{index}.json") - exp_adjust = outline.get_adjust( - adjust=adjust, item=exp, path=exp_adjust_path - ) - for threshold_conf in outline.thresholds_configs: - outline.set_explainer_threshold_config(threshold_conf) - masking_path = os.path.join( - adjust_path, - "ThresholdConfig", - obj_config_to_str(threshold_conf), + pbar.update(1) + if exp is None: + continue + else: + for adjust in outline.adjusts: + adjust_path = os.path.join( + attack_path_, + adjust.__class__.__name__, + obj_config_to_str(adjust), ) - makedirs(masking_path) - exp_masked_path = os.path.join(masking_path, f"{index}.json") - exp_masked = outline.get_threshold( - item=exp_adjust, path=exp_masked_path + makedirs(adjust_path) + exp_adjust_path = os.path.join(adjust_path, f"{index}.json") + exp_adjust = outline.get_adjust( + adjust=adjust, item=exp, path=exp_adjust_path ) - for metric in outline.metrics: - metric_path = os.path.join( - masking_path, - metric.__class__.__name__, - obj_config_to_str(metric), - ) - makedirs(metric_path) - metric_path = os.path.join(metric_path, f"{index}.json") - out_metric = outline.get_metric( - metric=metric, item=exp_masked, path=metric_path - ) - print("#################################") - print("Attack", attack.name) - print( + for threshold_conf in outline.thresholds_configs: + outline.set_explainer_threshold_config(threshold_conf) + masking_path = os.path.join( + adjust_path, "ThresholdConfig", - "-".join([f"{k}={v}" for k, v in threshold_conf.items()]), + obj_config_to_str(threshold_conf), ) - print("Metric", metric.name) - print("Val", out_metric) - print("Index", index) - print("#################################") + makedirs(masking_path) + exp_masked_path = os.path.join(masking_path, f"{index}.json") + exp_masked = outline.get_threshold( + item=exp_adjust, path=exp_masked_path + ) + if exp_masked is None: + continue + else: + for metric in outline.metrics: + metric_path = os.path.join( + masking_path, + metric.__class__.__name__, + obj_config_to_str(metric), + ) + makedirs(metric_path) + metric_path = os.path.join(metric_path, f"{index}.json") + out_metric = outline.get_metric( + metric=metric, item=exp_masked, path=metric_path + ) item, index = outline.get_item() + with open(os.path.join(outline.out_dir, "done"), "w") as f: + f.write("")