diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index 54c7625..3e7fcb1 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -4,12 +4,11 @@ import os import numpy as np import torch +from explaining_framework.utils.io import read_json, write_json from torch_geometric.data import Data from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg -from explaining_framework.utils.io import read_json, write_json - def _get_explanation(explainer, item): explanation = explainer( @@ -69,7 +68,7 @@ def _save_explanation(exp: Explanation, path: str) -> None: def _load_explanation(path: str) -> Explanation: - data = read_json(data, path) + data = read_json(path) for k, v in data.items(): if isinstance(v, list): if k == "edge_index" or k == "y":