50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import copy
import json
import os

import torch
from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation
|
|
|
|
|
|
def explanation_verification(exp: Explanation) -> bool:
    """Check that an explanation is structurally valid and has finite masks.

    Args:
        exp: The explanation to check. It is not modified.

    Returns:
        ``True`` if ``exp.validate`` reports no problem and every tensor
        attribute whose key contains ``"_mask"`` is free of NaN and +/-inf
        values; ``False`` otherwise.
    """
    # Structural validation does not depend on any individual mask, so run it
    # exactly once -- including when the explanation carries no masks at all
    # (previously such an explanation was never validated).
    # ``raise_on_error=False`` makes ``validate`` report problems through its
    # return value instead of raising, matching this function's bool contract.
    if not exp.validate(raise_on_error=False):
        return False

    for key, value in exp.items():
        if "_mask" in key and isinstance(value, torch.Tensor):
            # ``isfinite`` is False for NaN as well as for +/-inf.
            if not torch.isfinite(value).all():
                return False
    return True
|
|
|
|
|
|
def save_explanation(exp: Explanation, path: str) -> None:
    """Serialize an explanation to a JSON file.

    Every tensor attribute is detached, moved to the CPU and stored via
    ``tolist()``; all other attributes are written as-is and must therefore
    be JSON-serializable (``json.dump`` raises ``TypeError`` otherwise).

    Args:
        exp: The explanation to save. It is not modified.
        path: Destination file path; an existing file is overwritten.

    Note:
        Tensor dtypes are not written explicitly; only the int/float/bool
        kind of the elements survives through ``tolist()``. A 0-dimensional
        tensor is written as a plain scalar rather than as a list.
    """
    # Build a brand-new dict instead of rewriting the entries returned by
    # ``to_dict()`` in place: the explanation and its storage stay untouched
    # without having to shallow-copy the whole explanation first.
    data = {
        key: (value.detach().cpu().tolist() if isinstance(value, torch.Tensor) else value)
        for key, value in exp.to_dict().items()
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f)
|
|
|
|
|
|
def load_explanation(path: str) -> Explanation:
    """Load an explanation from a JSON file written by ``save_explanation``.

    Every list-valued entry of the JSON object is converted back into a
    tensor; all other values are passed to ``Explanation.from_dict`` as-is.

    The element type is inferred from the stored values. Integer, floating
    point and boolean tensors therefore come back with their kind intact
    (floats use PyTorch's default float dtype). ``edge_index`` is always
    returned as ``torch.long``. ``y`` is returned as ``torch.long`` unless it
    holds floating point values.

    Args:
        path: Path to the JSON file.

    Returns:
        The reconstructed explanation.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for key, value in data.items():
        if not isinstance(value, list):
            continue
        # ``torch.tensor`` picks int64 for ints, bool for bools and the
        # default float dtype for floats, so integer entries such as
        # ``target``/``index`` or boolean masks are no longer forced to float.
        tensor = torch.tensor(value)
        if key == "edge_index":
            # Connectivity must be integral; this also fixes the dtype of an
            # empty edge list, for which nothing can be inferred.
            tensor = tensor.long()
        elif key == "y" and (tensor.numel() == 0 or not tensor.is_floating_point()):
            # Labels stay ``torch.long`` as before, but floating point
            # targets are no longer truncated to integers.
            tensor = tensor.long()
        # Replacing the value of an existing key does not resize the dict,
        # so assigning during iteration is safe.
        data[key] = tensor
    return Explanation.from_dict(data)
|
|
|
|
|
|
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
    """Return a new explanation whose floating point masks have unit norm.

    Each tensor attribute whose key contains ``"_mask"`` and whose dtype is
    floating point is divided by its ``p``-norm, taken over all elements.
    Masks that are empty, or whose norm is zero or not finite, are left
    unchanged so that normalization never introduces NaN/inf values. All
    other attributes are carried over unchanged, and ``exp`` is not modified.

    Args:
        exp: The explanation whose masks should be normalized.
        p: Norm order for ``torch.norm``. ``"fro"`` and ``"nuc"`` are passed
            through. Any other value, including the default ``"inf"``
            (max-abs normalization), a numeric string such as ``"2"`` or a
            plain number, is converted with ``float``.

    Returns:
        A new ``Explanation`` holding the normalized masks.

    Raises:
        ValueError: If ``p`` is a string that is neither ``"fro"``/``"nuc"``
            nor a number.
    """
    # ``torch.norm`` only accepts the strings "fro" and "nuc"; any other
    # string (including the previous default "inf") made it raise.
    order = p if p in ("fro", "nuc") else float(p)

    # Collect results in a fresh dict and build a new Explanation from it.
    # The previous version wrote the normalized masks into a throwaway dict
    # and then returned ``exp`` unchanged.
    normalized = {}
    for key, value in exp.to_dict().items():
        # ``torch.FloatTensor`` matched only float32 CPU tensors; accept any
        # floating point tensor on any device.
        if (
            "_mask" in key
            and isinstance(value, torch.Tensor)
            and value.is_floating_point()
            and value.numel() > 0
        ):
            norm = torch.norm(value, p=order).item()
            # Skip zero / NaN / inf norms instead of dividing by them.
            if 0.0 < norm < float("inf"):
                value = value / norm
        normalized[key] = value
    return Explanation.from_dict(normalized)
|