Adding parsing files
This commit is contained in:
parent
7fe935dbad
commit
7f540f53d7
|
@ -0,0 +1,16 @@
|
|||
import copy
|
||||
|
||||
from torch import FloatTensor
|
||||
from torch.nn import ReLU
|
||||
|
||||
|
||||
def relu_mask(explanation: Explanation) -> Explanation:
|
||||
relu = ReLU()
|
||||
explanation_store = explanation._store
|
||||
raw_data = copy.copy(explanation._store)
|
||||
for k, v in explanation_store.items():
|
||||
if "mask" in k:
|
||||
explanation_store[k] = relu(v)
|
||||
explanation.__setattr__("raw_explanation", raw_data)
|
||||
explanation.__setattr__("raw_explanation_transform", "relu")
|
||||
return explanation
|
|
@ -0,0 +1,50 @@
|
|||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
|
||||
def explanation_verification(exp: Explanation) -> bool:
|
||||
is_good = True
|
||||
masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
|
||||
for mask in masks:
|
||||
is_nan = mask.isnan().any().item()
|
||||
is_inf = mask.isinf().any().item()
|
||||
is_ok = exp.validate()
|
||||
if is_nan or is_inf or not is_ok:
|
||||
is_good = False
|
||||
return is_good
|
||||
else:
|
||||
continue
|
||||
return is_good
|
||||
|
||||
|
||||
def save_explanation(exp: Explanation, path: str) -> None:
|
||||
data = copy.copy(exp).to_dict()
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = v.detach().cpu().tolist()
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
|
||||
def load_explanation(path: str) -> Explanation:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
for k, v in data.items():
|
||||
if isinstance(v, list):
|
||||
if k == "edge_index" or k == "y":
|
||||
data[k] = torch.LongTensor(v)
|
||||
else:
|
||||
data[k] = torch.FloatTensor(v)
|
||||
return Explanation.from_dict(data)
|
||||
|
||||
|
||||
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
|
||||
data = exp.to_dict()
|
||||
for k, v in data.items():
|
||||
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||
data[k] = data[k] / torch.norm(input=data[k], p=p, dim=None).item()
|
||||
return exp
|
Loading…
Reference in New Issue