Adding parsing files

This commit is contained in:
araison 2022-12-20 16:09:03 +01:00
parent 7fe935dbad
commit 7f540f53d7
3 changed files with 66 additions and 0 deletions

View File

@ -0,0 +1,16 @@
import copy
from torch import FloatTensor
from torch.nn import ReLU
def relu_mask(explanation: Explanation) -> Explanation:
relu = ReLU()
explanation_store = explanation._store
raw_data = copy.copy(explanation._store)
for k, v in explanation_store.items():
if "mask" in k:
explanation_store[k] = relu(v)
explanation.__setattr__("raw_explanation", raw_data)
explanation.__setattr__("raw_explanation_transform", "relu")
return explanation

View File

@ -0,0 +1,50 @@
import copy
import json
import os
from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation
def explanation_verification(exp: Explanation) -> bool:
is_good = True
masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
for mask in masks:
is_nan = mask.isnan().any().item()
is_inf = mask.isinf().any().item()
is_ok = exp.validate()
if is_nan or is_inf or not is_ok:
is_good = False
return is_good
else:
continue
return is_good
def save_explanation(exp: Explanation, path: str) -> None:
data = copy.copy(exp).to_dict()
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.detach().cpu().tolist()
with open(path, "w") as f:
json.dump(data, f)
def load_explanation(path: str) -> Explanation:
with open(path, "r") as f:
data = json.load(f)
for k, v in data.items():
if isinstance(v, list):
if k == "edge_index" or k == "y":
data[k] = torch.LongTensor(v)
else:
data[k] = torch.FloatTensor(v)
return Explanation.from_dict(data)
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
data = exp.to_dict()
for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor):
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
return exp