Adding parsing files
This commit is contained in:
parent
7fe935dbad
commit
7f540f53d7
|
@ -0,0 +1,16 @@
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from torch import FloatTensor
|
||||||
|
from torch.nn import ReLU
|
||||||
|
|
||||||
|
|
||||||
|
def relu_mask(explanation: Explanation) -> Explanation:
|
||||||
|
relu = ReLU()
|
||||||
|
explanation_store = explanation._store
|
||||||
|
raw_data = copy.copy(explanation._store)
|
||||||
|
for k, v in explanation_store.items():
|
||||||
|
if "mask" in k:
|
||||||
|
explanation_store[k] = relu(v)
|
||||||
|
explanation.__setattr__("raw_explanation", raw_data)
|
||||||
|
explanation.__setattr__("raw_explanation_transform", "relu")
|
||||||
|
return explanation
|
|
@ -0,0 +1,50 @@
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
|
||||||
|
|
||||||
|
def explanation_verification(exp: Explanation) -> bool:
|
||||||
|
is_good = True
|
||||||
|
masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
|
||||||
|
for mask in masks:
|
||||||
|
is_nan = mask.isnan().any().item()
|
||||||
|
is_inf = mask.isinf().any().item()
|
||||||
|
is_ok = exp.validate()
|
||||||
|
if is_nan or is_inf or not is_ok:
|
||||||
|
is_good = False
|
||||||
|
return is_good
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
return is_good
|
||||||
|
|
||||||
|
|
||||||
|
def save_explanation(exp: Explanation, path: str) -> None:
|
||||||
|
data = copy.copy(exp).to_dict()
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
data[k] = v.detach().cpu().tolist()
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
def load_explanation(path: str) -> Explanation:
|
||||||
|
with open(path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
if k == "edge_index" or k == "y":
|
||||||
|
data[k] = torch.LongTensor(v)
|
||||||
|
else:
|
||||||
|
data[k] = torch.FloatTensor(v)
|
||||||
|
return Explanation.from_dict(data)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
|
||||||
|
data = exp.to_dict()
|
||||||
|
for k, v in data.items():
|
||||||
|
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||||
|
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
|
||||||
|
return exp
|
Loading…
Reference in New Issue