diff --git a/explaining_framework/utils/explanation/__init__.py b/explaining_framework/utils/explanation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py
new file mode 100644
index 0000000..0e641ee
--- /dev/null
+++ b/explaining_framework/utils/explanation/adjust.py
@@ -0,0 +1,17 @@
+import copy
+
+from torch.nn import ReLU
+from torch_geometric.explain.explanation import Explanation
+
+
+def relu_mask(explanation: Explanation) -> Explanation:
+    """Apply a ReLU to every mask, keeping the raw masks on the explanation."""
+    relu = ReLU()
+    explanation_store = explanation._store
+    raw_data = copy.copy(explanation._store)
+    for k, v in explanation_store.items():
+        if "mask" in k:
+            explanation_store[k] = relu(v)
+    explanation.raw_explanation = raw_data
+    explanation.raw_explanation_transform = "relu"
+    return explanation
diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py
new file mode 100644
index 0000000..7c56227
--- /dev/null
+++ b/explaining_framework/utils/explanation/io.py
@@ -0,0 +1,49 @@
+import copy
+import json
+
+import torch
+from torch_geometric.explain.explanation import Explanation
+
+
+def explanation_verification(exp: Explanation) -> bool:
+    """Return True if no mask contains NaN or Inf values and the explanation validates."""
+    masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
+    for mask in masks:
+        is_nan = mask.isnan().any().item()
+        is_inf = mask.isinf().any().item()
+        is_ok = exp.validate()
+        if is_nan or is_inf or not is_ok:
+            return False
+    return True
+
+
+def save_explanation(exp: Explanation, path: str) -> None:
+    """Serialize an explanation to JSON, converting tensors to nested lists."""
+    data = copy.copy(exp).to_dict()
+    for k, v in data.items():
+        if isinstance(v, torch.Tensor):
+            data[k] = v.detach().cpu().tolist()
+    with open(path, "w") as f:
+        json.dump(data, f)
+
+
+def load_explanation(path: str) -> Explanation:
+    """Load a JSON explanation, restoring integer tensors for indices and labels."""
+    with open(path, "r") as f:
+        data = json.load(f)
+    for k, v in data.items():
+        if isinstance(v, list):
+            if k in ("edge_index", "y"):
+                data[k] = torch.LongTensor(v)
+            else:
+                data[k] = torch.FloatTensor(v)
+    return Explanation.from_dict(data)
+
+
+def normalize_explanation_masks(exp: Explanation, p: float = float("inf")) -> Explanation:
+    """Normalize every floating-point mask by its p-norm (infinity norm by default)."""
+    for k, v in exp.to_dict().items():
+        if "_mask" in k and isinstance(v, torch.Tensor) and v.is_floating_point():
+            # Write the result back into the explanation: to_dict() returns a copy.
+            exp[k] = v / torch.norm(input=v, p=p, dim=None).item()
+    return exp
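
Not part of the diff: a minimal usage sketch of the new helpers, assuming the package is installed so that explaining_framework.utils.explanation is importable. The Explanation below is built by hand purely for illustration; in practice it would come from a torch_geometric.explain.Explainer run.

import torch
from torch_geometric.explain.explanation import Explanation

from explaining_framework.utils.explanation.adjust import relu_mask
from explaining_framework.utils.explanation.io import (
    explanation_verification,
    load_explanation,
    normalize_explanation_masks,
    save_explanation,
)

# Hand-built toy explanation: 3 nodes with 4 features, 4 directed edges.
exp = Explanation(
    x=torch.randn(3, 4),
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    node_mask=torch.randn(3, 4),                      # attributions may be negative
    edge_mask=torch.tensor([0.9, -0.1, 0.3, 0.5]),    # one negative edge attribution
)

if explanation_verification(exp):
    # Round-trip through JSON.
    save_explanation(exp, "explanation.json")
    restored = load_explanation("explanation.json")
    # Clamp negative attributions to zero, then rescale masks into [0, 1].
    restored = relu_mask(restored)
    restored = normalize_explanation_masks(restored)
    print(restored.node_mask.shape)  # torch.Size([3, 4])

Note that relu_mask also attaches the untransformed masks under raw_explanation, so the original attributions remain available after the transform.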