Fixing bug
This commit is contained in:
parent
408bab4bc4
commit
3cf4187164
|
@ -4,12 +4,11 @@ import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from explaining_framework.utils.io import read_json, write_json
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
from torch_geometric.graphgym.config import cfg
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
|
||||||
from explaining_framework.utils.io import read_json, write_json
|
|
||||||
|
|
||||||
|
|
||||||
def _get_explanation(explainer, item):
|
def _get_explanation(explainer, item):
|
||||||
explanation = explainer(
|
explanation = explainer(
|
||||||
|
@ -69,7 +68,7 @@ def _save_explanation(exp: Explanation, path: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
def _load_explanation(path: str) -> Explanation:
|
def _load_explanation(path: str) -> Explanation:
|
||||||
data = read_json(data, path)
|
data = read_json(path)
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
if k == "edge_index" or k == "y":
|
if k == "edge_index" or k == "y":
|
||||||
|
|
Loading…
Reference in New Issue