From 1c62a5a56148846ad9165c989b837a51809e446a Mon Sep 17 00:00:00 2001 From: araison Date: Fri, 13 Jan 2023 11:22:21 +0100 Subject: [PATCH] Fixing some minors bugs --- .../config/explaining_config.py | 2 +- .../utils/explaining/outline.py | 102 +++++++++--------- explaining_framework/utils/explanation/io.py | 18 ++-- explaining_framework/utils/io.py | 23 ++-- main.py | 8 +- 5 files changed, 82 insertions(+), 71 deletions(-) diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index cd68c42..c01397b 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -114,7 +114,7 @@ def set_cfg(explaining_cfg): explaining_cfg.threshold.config.type = "all" explaining_cfg.threshold.value = CN() - explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)] + explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)] explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50] # which objectives metrics to computes, either all or one in particular if implemented diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index ea364d3..eef12ab 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -534,29 +534,30 @@ class ExplainingOutline(object): self.graphstat = GraphStat() def get_explanation(self, item: Data, path: str): - if is_exists(path): + if is_exists( + path, + ): if self.explaining_cfg.explainer.force: try: explanation = _get_explanation(self.explainer, item) if explanation is None: logging.error( - " EXP::Generated; Path %s; FAILED", - (path), + " EXP::Generated; Path %s; FAILED" % (path,), ) else: logging.debug( - "EXP::Generated; Path %s; SUCCEEDED", - (path), + "EXP::Generated; Path %s; SUCCEEDED" % (path,), ) except Exception as e: logging.error(str(e)) return None else: - explanation = _load_explanation(path) + explanation = _load_explanation( + path, + ) logging.debug( - "EXP::Loaded; Path %s; SUCCEEDED", - (path), + "EXP::Loaded; Path %s; SUCCEEDED" % (path,), ) explanation = explanation.to(self.cfg.accelerator) else: @@ -565,8 +566,7 @@ class ExplainingOutline(object): get_pred(self.explainer, explanation) _save_explanation(explanation, path) logging.debug( - "EXP::Generated; Path %s; SUCCEEDED", - (path), + "EXP::Generated; Path %s; SUCCEEDED" % (path,), ) except Exception as e: logging.error(str(e)) @@ -575,19 +575,21 @@ class ExplainingOutline(object): return explanation def get_adjust(self, adjust: Adjust, item: Explanation, path: str): - if is_exists(path): + if is_exists( + path, + ): if self.explaining_cfg.explainer.force: exp_adjust = adjust.forward(item) logging.debug( - "ADJUST::Generated; Path %s; SUCCEEDED", - (path), + "ADJUST::Generated; Path %s; SUCCEEDED" % (path,), ) else: - exp_adjust = _load_explanation(path) + exp_adjust = _load_explanation( + path, + ) logging.debug( - "ADJUST::Loaded; Path %s; SUCCEEDED", - (path), + "ADJUST::Loaded; Path %s; SUCCEEDED" % (path,), ) else: @@ -595,75 +597,76 @@ class ExplainingOutline(object): get_pred(self.explainer, exp_adjust) _save_explanation(exp_adjust, path) logging.debug( - "ADJUST::Generated; Path %s; SUCCEEDED", - (path), + "ADJUST::Generated; Path %s; SUCCEEDED" % (path,), ) return exp_adjust def get_threshold(self, item: Explanation, path: str): - if is_exists(path): + if is_exists( + path, + ): if self.explaining_cfg.explainer.force: exp_threshold = self.explainer._post_process(item) logging.debug( - "THRESHOLD::Generated; Path %s; SUCCEEDED", - (path), + "THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,), ) else: - exp_threshold = _load_explanation(path) + exp_threshold = _load_explanation( + path, + ) logging.debug( - "THRESHOLD::Loaded; Path %s; SUCCEEDED", - (path), + "THRESHOLD::Loaded; Path %s; SUCCEEDED" % (path,), ) else: exp_threshold = self.explainer._post_process(item) get_pred(self.explainer, exp_threshold) _save_explanation(exp_threshold, path) logging.debug( - "THRESHOLD::Generated; Path %s; SUCCEEDED", - (path), + "THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,), ) if is_empty_graph(exp_threshold): logging.warning( - "THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED", - (path), + "THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED" % (path,), ) return None return exp_threshold def get_metric(self, metric: Metric, item: Explanation, path: str): - if is_exists(path): + if is_exists( + path, + ): if self.explaining_cfg.explainer.force: out_metric = metric.forward(item) logging.debug( - "METRIC::Generated; Path %s; SUCCEEDED", - (path), + "METRIC::Generated; Path %s; SUCCEEDED" % (path,), ) else: - out_metric = read_json(path) + out_metric = read_json( + path, + ) logging.debug( - "METRIC::Loaded; Path %s; SUCCEEDED", - (path), + "METRIC::Loaded; Path %s; SUCCEEDED" % (path,), ) else: out_metric = metric.forward(item) - data = {f"{metric.name}": out_metric} - write_json(data, path) if out_metric is None: logging.debug( - "METRIC::Generated; Path %s; FAILED", - (path), + "METRIC::Generated; Path %s; FAILED" % (path,), ) else: logging.debug( - "METRIC::Generated; Path %s; SUCCEEDED", - (path), + "METRIC::Generated; Path %s; SUCCEEDED" % (path,), ) + data = {f"{metric.name}": out_metric} + write_json(data, path) return out_metric def get_stat(self, item: Data, path: str): if self.graphstat is None: self.load_graphstat() - if is_exists(path): + if is_exists( + path, + ): pass else: if item.num_nodes <= 500: @@ -671,30 +674,31 @@ class ExplainingOutline(object): write_json(stat, path) def get_attack(self, attack: Attack, item: Data, path: str): - if is_exists(path): + if is_exists( + path, + ): if self.explaining_cfg.explainer.force: try: data_attack = attack.get_attacked_prediction(item) logging.debug( - "ATTACK::Generated %s; Path %s; SUCCEEDED", - (path), + "ATTACK::Generated; Path %s; SUCCEEDED" % (path,), ) except Exception as e: logging.error(str(e)) return None else: - data_attack = _load_explanation(path) + data_attack = _load_explanation( + path, + ) logging.debug( - "ATTACK::Generated %s; Path %s; SUCCEEDED", - (path), + "ATTACK::Generated; Path %s; SUCCEEDED" % (path,), ) else: try: data_attack = attack.get_attacked_prediction(item) _save_explanation(data_attack, path) logging.debug( - "ATTACK::Generated %s; Path %s; SUCCEEDED", - (path), + "ATTACK::Generated; Path %s; SUCCEEDED" % (path,), ) except Exception as e: logging.error(str(e)) diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index dd318d4..54c7625 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -2,11 +2,14 @@ import copy import json import os +import numpy as np import torch from torch_geometric.data import Data from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg +from explaining_framework.utils.io import read_json, write_json + def _get_explanation(explainer, item): explanation = explainer( @@ -27,9 +30,7 @@ def is_empty_graph(data: Data) -> bool: def get_pred(explainer, explanation): - pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[ - 0 - ] + pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index) setattr(explanation, "pred", pred) data = explanation.to_dict() if not data.get("node_mask") is None or not data.get("edge_mask") is None: @@ -38,7 +39,7 @@ def get_pred(explainer, explanation): edge_index=explanation.edge_index, node_mask=data.get("node_mask"), edge_mask=data.get("edge_mask"), - )[0] + ) setattr(explanation, "pred_exp", pred_masked) @@ -63,15 +64,12 @@ def _save_explanation(exp: Explanation, path: str) -> None: data = exp.clone().to_dict() for k, v in data.items(): if isinstance(v, torch.Tensor): - data[k] = v.detach().cpu().tolist() - - with open(path, "w") as f: - json.dump(data, f) + data[k] = v.clone().detach().cpu().tolist() + write_json(data, path) def _load_explanation(path: str) -> Explanation: - with open(path, "r") as f: - data = json.load(f) + data = read_json(data, path) for k, v in data.items(): if isinstance(v, list): if k == "edge_index" or k == "y": diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py index 5356bf7..e1d0a69 100644 --- a/explaining_framework/utils/io.py +++ b/explaining_framework/utils/io.py @@ -4,6 +4,7 @@ import os import sys import yaml + from explaining_framework.config.explaining_config import explaining_cfg @@ -89,19 +90,23 @@ def set_printing(logger_path): Set up printing options """ - logging.root.handlers = [] - logging_cfg = { - "level": logging.INFO, - "format": "%(asctime)s::%(levelname)s::%(message)s", - } + logging.getLogger().setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s") + h_file = logging.FileHandler(logger_path) + h_file.setLevel(logging.DEBUG) + h_file.setFormatter(formatter) + h_stdout = logging.StreamHandler(sys.stdout) + h_stdout.setLevel(logging.INFO) + h_stdout.setFormatter(formatter) + if explaining_cfg.print == "file": - logging_cfg["handlers"] = [h_file] + logging.getLogger().addHandler(h_file) elif explaining_cfg.print == "stdout": - logging_cfg["handlers"] = [h_stdout] + logging.getLogger().addHandler(h_stdout) elif explaining_cfg.print == "both": - logging_cfg["handlers"] = [h_file, h_stdout] + logging.getLogger().addHandler(h_file) + logging.getLogger().addHandler(h_stdout) else: raise ValueError("Print option not supported") - logging.basicConfig(**logging_cfg) diff --git a/main.py b/main.py index 6e40c1e..04776db 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,7 @@ if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) for attack in outline.attacks: - logging.info(f"Running {attack.__class__.__name__}: {attack.name}") + logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name)) for item, index in tqdm( zip(outline.dataset, outline.indexes), total=len(outline.dataset) ): @@ -42,7 +42,11 @@ if __name__ == "__main__": ) for attack in outline.attacks: - logging.info(f"Running {attack.__class__.__name__}: {attack.name}") + logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name)) + logging.info( + "Running %s: %s" + % (outline.explainer.__class__.__name__, outline.explaining_algorithm.name), + ) for item, index in tqdm( zip(outline.dataset, outline.indexes), total=len(outline.dataset) ):