From 10baa1d443e55feb902217f33bc55704b01b4acb Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 4 Jan 2023 12:56:37 +0100 Subject: [PATCH] Reformating, fixing --- .../config/explaining_config.py | 2 +- explaining_framework/metric/robust.py | 10 ++-- .../stats/graph/graph_stat.py | 4 +- .../utils/explaining/outline.py | 60 ++++++++++--------- .../utils/explanation/adjust.py | 4 +- explaining_framework/utils/explanation/io.py | 9 ++- main.py | 5 +- 7 files changed, 49 insertions(+), 45 deletions(-) diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index 67c1343..ce95265 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -98,7 +98,7 @@ def set_cfg(explaining_cfg): explaining_cfg.model_config = CN() # Do not modify it, will be handled by dataset , assuming one dataset = one learning task - explaining_cfg.model_config.mode = None + explaining_cfg.model_config.mode = "regression" # Do not modify it, will be handled by dataset , assuming one dataset = one learning task explaining_cfg.model_config.task_level = None diff --git a/explaining_framework/metric/robust.py b/explaining_framework/metric/robust.py index 5c7aaee..abb99b4 100644 --- a/explaining_framework/metric/robust.py +++ b/explaining_framework/metric/robust.py @@ -39,7 +39,7 @@ class FGSM(Metric): self.zero_thresh = 10**-6 def forward(self, input, target, epsilon: float) -> Explanation: - input_ = copy.copy(input) + input_ = input.clone() grad = compute_gradient( model=self.model, inp=input_, target=target, loss=self.loss ) @@ -168,24 +168,24 @@ class Attack(Metric): def _gaussian_noise(self, exp) -> Explanation: x = torch.clone(exp.x) x = x + torch.randn(*x.shape) - exp_ = copy.copy(exp) + exp_ = exp.clone() exp_.x = x return exp_ def _add_edge(self, exp, p: float) -> Explanation: - exp_ = copy.copy(exp) + exp_ = exp.clone() exp_.edge_index, _ = add_random_edge( exp_.edge_index, p=p, num_nodes=exp_.x.shape[0] ) return exp_ def _remove_edge(self, exp, p: float) -> Explanation: - exp_ = copy.copy(exp) + exp_ = exp.clone() exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p) return exp_ def _remove_node(self, exp, p: float) -> Explanation: - exp_ = copy.copy(exp) + exp_ = exp.clone() exp_.edge_index, _, _ = dropout_node( exp_.edge_index, p=p, num_nodes=exp_.x.shape[0] ) diff --git a/explaining_framework/stats/graph/graph_stat.py b/explaining_framework/stats/graph/graph_stat.py index 41cc3ca..f7aa775 100644 --- a/explaining_framework/stats/graph/graph_stat.py +++ b/explaining_framework/stats/graph/graph_stat.py @@ -153,7 +153,7 @@ class GraphStat(object): return maps def __call__(self, data): - data_ = copy.copy(data) + data_ = data.clone() datahash = hash(data.__repr__) stats = {} for k, v in self.maps.items(): @@ -161,7 +161,7 @@ class GraphStat(object): _data_ = to_networkx(data) _data_ = _data_.to_undirected() elif k == "torch_geometric": - _data_ = copy.copy(data) + _data_ = data.clone() for name, func in v.items(): try: val = func(_data_) diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index bcf93b4..99707ac 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -3,6 +3,18 @@ import itertools from typing import Any from eixgnn.eixgnn import EiXGNN +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader + from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -24,17 +36,6 @@ from explaining_framework.utils.explanation.io import ( explanation_verification, get_pred) from explaining_framework.utils.io import (is_exists, obj_config_to_str, read_json, write_json, write_yaml) -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.explain.explanation import Explanation -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader all__captum = [ "LRP", @@ -133,6 +134,7 @@ class ExplainingOutline(object): self.load_explaining_cfg() self.load_model_info() self.load_cfg() + self.load_cfg_to_explaining_cfg() self.load_dataset() self.load_model() self.load_model_to_hardware() @@ -165,8 +167,10 @@ class ExplainingOutline(object): return None def load_indexes(self): - if not self.explaining_cfg.dataset.specific_items is None: - indexes = explaining_cfg.dataset.specific_items + + items = self.explaining_cfg.dataset.items + if isinstance(items, (list, int)): + indexes = items else: indexes = list(range(len(self.dataset))) self.indexes = iter(indexes) @@ -195,15 +199,20 @@ class ExplainingOutline(object): self.model_signature = info.get_model_signature() def load_cfg(self): - cfg.set_new_allowed(True) cfg.merge_from_file(self.model_info["cfg_path"]) self.cfg = cfg def load_explaining_cfg(self): - explaining_cfg.set_new_allowed(True) explaining_cfg.merge_from_file(self.explaining_cfg_path) self.explaining_cfg = explaining_cfg + def load_cfg_to_explaining_cfg(self): + if self.cfg is None: + self.load_cfg() + if self.explaining_cfg is None: + self.load_explaining_cfg() + self.explaining_cfg.model_config.task_level = self.cfg.dataset.task + def load_explainer_cfg(self): if self.explaining_cfg is None: self.load_explaining_cfg() @@ -217,11 +226,9 @@ class ExplainingOutline(object): self.explainer_cfg = None else: if self.explaining_cfg.explainer.name == "EIXGNN": - eixgnn_cfg.set_new_allowed(True) eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) self.explainer_cfg = eixgnn_cfg elif self.explaining_cfg.explainer.name == "SCGNN": - scgnn_cfg.set_new_allowed(True) scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg) self.explainer_cfg = scgnn_cfg @@ -245,12 +252,13 @@ class ExplainingOutline(object): f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." ) self.dataset = create_dataset() - if isinstance(self.explaining_cfg.dataset.specific_items, int): - ind = self.explaining_cfg.dataset.specific_items - self.dataset = self.dataset[ind : ind + 1] - elif isinstance(self.explaining_cfg.dataset.specific_items, list): - ind = self.explaining_cfg.dataset.specific_items - self.dataset = self.dataset[ind] + items = self.explaining_cfg.dataset.items + print(items) + print(type(items)) + if isinstance(items, int): + self.dataset = self.dataset[items : items + 1] + elif isinstance(items, list): + self.dataset = self.dataset[items] def load_dataset_to_dataloader(self, to_iter=True): self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) @@ -308,7 +316,7 @@ class ExplainingOutline(object): ), model_config=dict( mode="regression", - task_level=self.cfg.dataset.task, + task_level=self.explaining_cfg.model_config.task_level, return_type=self.explaining_cfg.model_config.return_type, ), ) @@ -535,8 +543,6 @@ class ExplainingOutline(object): explanation = _get_explanation(self.explainer, item) get_pred(self.explainer, explanation) _save_explanation(explanation, path) - explanation = explanation.to(cfg.accelerator) - return explanation def get_adjust(self, adjust: Adjust, item: Explanation, path: str): @@ -549,7 +555,6 @@ class ExplainingOutline(object): exp_adjust = adjust.forward(item) get_pred(self.explainer, exp_adjust) _save_explanation(exp_adjust, path) - exp_adjust = exp_adjust.to(cfg.accelerator) return exp_adjust def get_threshold(self, item: Explanation, path: str): @@ -562,7 +567,6 @@ class ExplainingOutline(object): exp_threshold = self.explainer._post_process(item) get_pred(self.explainer, exp_threshold) _save_explanation(exp_threshold, path) - exp_threshold = exp_threshold.to(cfg.accelerator) return exp_threshold def get_metric(self, metric: Metric, item: Explanation, path: str): diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py index b012811..0cdbe98 100644 --- a/explaining_framework/utils/explanation/adjust.py +++ b/explaining_framework/utils/explanation/adjust.py @@ -14,7 +14,7 @@ class Adjust(object): self.strategy = strategy def forward(self, exp: Explanation) -> Explanation: - exp_ = copy.copy(exp) + exp_ = exp.clone() _store = exp_.to_dict() for k, v in _store.items(): if "mask" in k: @@ -41,7 +41,7 @@ class Adjust(object): return mask_ def normalize(self, mask: FloatTensor) -> FloatTensor: - norm = torch.norm(mask, p="inf") + norm = torch.norm(mask, p=float("inf")) if norm.item() > 0: mask_ = mask / norm.item() return mask_ diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index d532493..807214f 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -5,6 +5,7 @@ import os import torch from torch_geometric.data import Data from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg def _get_explanation(explainer, item): @@ -18,6 +19,7 @@ def _get_explanation(explainer, item): # WARNING + LOG return None else: + explanation = explanation.to(cfg.accelerator) return explanation @@ -47,9 +49,8 @@ def explanation_verification(exp: Explanation) -> bool: for mask in masks: is_nan = mask.isnan().any().item() is_inf = mask.isinf().any().item() - is_const = mask.max() == mask.min() is_ok = exp.validate() - if is_nan or is_inf or not is_ok or is_const: + if is_nan or is_inf or not is_ok: is_good = False return is_good else: @@ -58,7 +59,7 @@ def explanation_verification(exp: Explanation) -> bool: def _save_explanation(exp: Explanation, path: str) -> None: - data = copy.copy(exp).to_dict() + data = exp.clone().to_dict() for k, v in data.items(): if isinstance(v, torch.Tensor): data[k] = v.detach().cpu().tolist() @@ -77,5 +78,3 @@ def _load_explanation(path: str) -> Explanation: else: data[k] = torch.FloatTensor(v) return Explanation.from_dict(data) - - diff --git a/main.py b/main.py index 5db2248..8e25d9b 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str, if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file) + print(outline.explaining_cfg) out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature) makedirs(out_dir) @@ -64,14 +65,14 @@ if __name__ == "__main__": for adjust in outline.adjusts: adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}") makedirs(adjust_path) - exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json") + exp_adjust_path = os.path.join(adjust_path, f"{index}.json") exp_adjust = outline.get_adjust( adjust=adjust, item=raw_exp, path=exp_adjust_path ) for threshold_conf in outline.thresholds_configs: outline.set_explainer_threshold_config(threshold_conf) masking_path = os.path.join( - save_raw_path_, + adjust_path, "-".join([f"{k}={v}" for k, v in threshold_conf.items()]), ) makedirs(masking_path)