From 7e32d6fd3a72f9478c132775343b6c6b30ea49a3 Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 4 Jan 2023 10:41:34 +0100 Subject: [PATCH] Reformating, fixing --- .../config/explaining_config.py | 27 ++-- .../utils/explaining/outline.py | 97 +++++++++---- main.py | 131 ++++++++---------- 3 files changed, 145 insertions(+), 110 deletions(-) diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index 6fdd3d8..67c1343 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -57,7 +57,7 @@ def set_cfg(explaining_cfg): explaining_cfg.dataset.name = "Cora" - explaining_cfg.dataset.specific_items = None + explaining_cfg.dataset.items = None explaining_cfg.run_topological_stat = True @@ -110,26 +110,33 @@ def set_cfg(explaining_cfg): # Thresholding options # ----------------------------------------------------------------------- # - explaining_cfg.threshold_config = CN() + explaining_cfg.threshold = CN() - explaining_cfg.threshold_config.threshold_type = None + explaining_cfg.threshold.config = CN() + explaining_cfg.threshold.config.type = "all" - explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)] - - explaining_cfg.threshold_config.relu_and_normalize = True - - # Select device: 'cpu', 'cuda', 'auto' + explaining_cfg.threshold.value = CN() + explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)] + explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50] # which objectives metrics to computes, either all or one in particular if implemented explaining_cfg.metrics = CN() - explaining_cfg.metrics.name = "all" + explaining_cfg.metrics.sparsity = CN() + explaining_cfg.metrics.sparsity.name = "all" + explaining_cfg.metrics.fidelity = CN() + explaining_cfg.metrics.fidelity.name = "all" + explaining_cfg.metrics.accuracy = CN() + explaining_cfg.metrics.accuracy.name = "all" # Whether or not recomputing metrics if they already exist - explaining_cfg.metrics.force = False + + explaining_cfg.adjust = CN() + explaining_cfg.adjust.strategy = "rpn" explaining_cfg.attack = CN() explaining_cfg.attack.name = "all" + # Select device: 'cpu', 'cuda', 'auto' explaining_cfg.accelerator = "auto" diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index b493b85..bcf93b4 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -3,17 +3,6 @@ import itertools from typing import Any from eixgnn.eixgnn import EiXGNN -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader - from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -22,6 +11,7 @@ from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper from explaining_framework.explainers.wrappers.from_graphxai import \ GraphXAIWrapper from explaining_framework.metric.accuracy import Accuracy +from explaining_framework.metric.base import Metric from explaining_framework.metric.fidelity import Fidelity from explaining_framework.metric.robust import Attack from explaining_framework.metric.sparsity import Sparsity @@ -34,6 +24,17 @@ from explaining_framework.utils.explanation.io import ( explanation_verification, get_pred) from explaining_framework.utils.io import (is_exists, obj_config_to_str, read_json, write_json, write_yaml) +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader all__captum = [ "LRP", @@ -120,6 +121,9 @@ class ExplainingOutline(object): self.attacks = None self.model_signature = None self.indexes = None + self.sparsities = None + self.fidelities = None + self.accuracies = None self.explaining_algorithm = None self.explainer = None self.adjusts = None @@ -248,8 +252,10 @@ class ExplainingOutline(object): ind = self.explaining_cfg.dataset.specific_items self.dataset = self.dataset[ind] - def load_dataset_to_dataloader(self): + def load_dataset_to_dataloader(self, to_iter=True): self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) + if to_iter: + self.dataset = iter(self.dataset) def load_explaining_algorithm(self): self.load_explainer_cfg() @@ -381,25 +387,27 @@ class ExplainingOutline(object): ] elif name is None: all_metrics = [] - self.accuraties = all_metrics + else: + raise ValueError( + f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset" + ) else: - raise ValueError( - f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset" - ) + all_metrics = [] + self.accuracies = all_metrics def load_metric(self): if self.cfg is None: self.load_cfg() if self.explaining_cfg is None: self.load_explaining_cfg() - if self.accuraties is None: + if self.accuracies is None: self.load_accuracy() if self.sparsities is None: self.load_sparsity() if self.fidelities is None: self.load_fidelity() - self.metrics = self.fidelities + self.accuraties + self.sparsities + self.metrics = self.fidelities + self.accuracies + self.sparsities def load_attack(self): if self.cfg is None: @@ -432,16 +440,18 @@ class ExplainingOutline(object): strategy = self.explaining_cfg.adjust.strategy if strategy == "all": self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters] - elif isinstance(name, str): - if name in all_adjusts_filters: - all_metrics = [Adjust(strategy=name)] + elif isinstance(strategy, str): + if strategy in all_adjusts_filters: + all_metrics = [Adjust(strategy=strategy)] else: raise ValueError( f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}" ) - elif isinstance(name, list): + elif isinstance(strategy, list): all_metrics = [ - Adjust(strategy=name_) for name_ in name if name_ in all_robust + Adjust(strategy=name_) + for name_ in strategy + if name_ in all_adjusts_filters ] elif name is None: all_metrics = [] @@ -450,7 +460,7 @@ class ExplainingOutline(object): def load_threshold(self): if self.explaining_cfg is None: self.load_explaining_cfg() - threshold_type = self.explaining_cfg.threshold_config.type + threshold_type = self.explaining_cfg.threshold.config.type if threshold_type == "all": th_hard = [ {"threshold_type": "hard", "value": th_value} @@ -523,11 +533,50 @@ class ExplainingOutline(object): explanation = _load_explanation(path) else: explanation = _get_explanation(self.explainer, item) + get_pred(self.explainer, explanation) _save_explanation(explanation, path) explanation = explanation.to(cfg.accelerator) return explanation + def get_adjust(self, adjust: Adjust, item: Explanation, path: str): + if is_exists(path): + if self.explaining_cfg.explainer.force: + exp_adjust = adjust.forward(item) + else: + exp_adjust = _load_explanation(path) + else: + exp_adjust = adjust.forward(item) + get_pred(self.explainer, exp_adjust) + _save_explanation(exp_adjust, path) + exp_adjust = exp_adjust.to(cfg.accelerator) + return exp_adjust + + def get_threshold(self, item: Explanation, path: str): + if is_exists(path): + if self.explaining_cfg.explainer.force: + exp_threshold = self.explainer._post_process(item) + else: + exp_threshold = _load_explanation(path) + else: + exp_threshold = self.explainer._post_process(item) + get_pred(self.explainer, exp_threshold) + _save_explanation(exp_threshold, path) + exp_threshold = exp_threshold.to(cfg.accelerator) + return exp_threshold + + def get_metric(self, metric: Metric, item: Explanation, path: str): + if is_exists(path): + if self.explaining_cfg.explainer.force: + out_metric = metric.forward(item) + else: + out_metric = read_json(path) + else: + out_metric = metric.forward(item) + data = {f"{metric.name}": out_metric} + write_json(data, path) + return out_metric + def get_stat(self, item: Data, path: str): if self.graphstat is None: self.load_graphstat() diff --git a/main.py b/main.py index ee5f984..5db2248 100644 --- a/main.py +++ b/main.py @@ -18,9 +18,6 @@ from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explanation.adjust import Adjust -from explaining_framework.utils.explanation.io import ( - explanation_verification, get_explanation, get_pred, load_explanation, - save_explanation) from explaining_framework.utils.io import (is_exists, obj_config_to_str, read_json, write_json, write_yaml) @@ -31,81 +28,63 @@ if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file) - # Load components - # RAJOUTER INDEXES + out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature) + makedirs(out_dir) - # Global path - global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm)) - makedirs(global_path) - write_yaml(cfg, os.path.join(global_path, "config.yaml")) - write_json(model_info, os.path.join(global_path, "info.json")) + write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml")) + write_json(outline.model_info, os.path.join(out_dir, "info.json")) - makedirs(global_path) - write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) - write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) + explainer_path = os.path.join( + out_dir, + outline.explaining_cfg.explainer.name + + "_" + + obj_config_to_str(outline.explaining_algorithm), + ) - global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm)) - makedirs(global_path) - # SET UP EXPLAINER - # Save explaining configuration - item,index = outline.get_item() - while not(item is None or index is None): - raw_path = os.path.join(global_path, "raw") - makedirs(raw_path) - explanation_path = os.path.join(save_raw_path, f"{index}.json") + makedirs(explainer_path) + write_yaml( + outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest) + ) + write_yaml( + outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml") + ) - for apply_relu in [True, False]: - for apply_absolute in [True, False]: - adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) - save_raw_path_ = os.path.join( - global_path, f"adjust-{obj_config_to_str(adjust)}" + specific_explainer_path = os.path.join( + explainer_path, obj_config_to_str(outline.explaining_algorithm) + ) + makedirs(specific_explainer_path) + + raw_path = os.path.join(specific_explainer_path, "raw") + makedirs(raw_path) + + item, index = outline.get_item() + while not (item is None or index is None): + explanation_path = os.path.join(raw_path, f"{index}.json") + raw_exp = outline.get_explanation(item=item, path=explanation_path) + for adjust in outline.adjusts: + adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}") + makedirs(adjust_path) + exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json") + exp_adjust = outline.get_adjust( + adjust=adjust, item=raw_exp, path=exp_adjust_path + ) + for threshold_conf in outline.thresholds_configs: + outline.set_explainer_threshold_config(threshold_conf) + masking_path = os.path.join( + save_raw_path_, + "-".join([f"{k}={v}" for k, v in threshold_conf.items()]), ) - explanation__ = copy.copy(explanation).to(cfg.accelerator) - makedirs(save_raw_path_) - explanation = adjust.forward(explanation__) - explanation_path = os.path.join(save_raw_path_, f"{index}.json") - get_pred(explainer, explanation__) - save_explanation(explanation__, explanation_path) - - for threshold_approach in ["hard", "topk", "topk_hard"]: - if threshold_approach == "hard": - threshold_values = explaining_cfg.threshold_config.value - elif "topk" in threshold_approach: - threshold_values = [3, 5, 10, 20] - for threshold_value in threshold_values: - - masking_path = os.path.join( - save_raw_path_, - f"threshold={threshold_approach}-value={threshold_value}", - ) - makedirs(masking_path) - exp_threshold_path = os.path.join(masking_path, f"{index}.json") - if is_exists(exp_threshold_path): - exp_threshold = load_explanation(exp_threshold_path) - else: - threshold_conf = { - "threshold_type": threshold_approach, - "value": threshold_value, - } - explainer.threshold_config = ThresholdConfig.cast( - threshold_conf - ) - - expl = copy.copy(explanation__).to(cfg.accelerator) - exp_threshold = explainer._post_process(expl) - exp_threshold = exp_threshold.to(cfg.accelerator) - get_pred(explainer, exp_threshold) - save_explanation(exp_threshold, exp_threshold_path) - for metric in metrics: - metric_path = os.path.join( - masking_path, f"{obj_config_to_str(metric)}" - ) - makedirs(metric_path) - if is_exists(os.path.join(metric_path, f"{index}.json")): - continue - else: - out = metric.forward(exp_threshold) - write_json( - {f"{metric.name}": out}, - os.path.join(metric_path, f"{index}.json"), - ) + makedirs(masking_path) + exp_masked_path = os.path.join(masking_path, f"{index}.json") + exp_masked = outline.get_threshold( + item=exp_adjust, path=exp_masked_path + ) + for metric in outline.metrics: + metric_path = os.path.join( + masking_path, f"{obj_config_to_str(metric)}" + ) + makedirs(metric_path) + metric_path = os.path.join(metric_path, f"{index}.json") + out_metric = outline.get_metric( + metric=metric, item=exp_masked, path=metric_path + )