Fixing bug

This commit is contained in:
araison 2023-02-12 14:51:15 +01:00
parent 408bab4bc4
commit 3cf4187164

View File

@ -4,12 +4,11 @@ import os
import numpy as np
import torch
from explaining_framework.utils.io import read_json, write_json
from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from explaining_framework.utils.io import read_json, write_json
def _get_explanation(explainer, item):
explanation = explainer(
@ -69,7 +68,7 @@ def _save_explanation(exp: Explanation, path: str) -> None:
def _load_explanation(path: str) -> Explanation:
data = read_json(data, path)
data = read_json(path)
for k, v in data.items():
if isinstance(v, list):
if k == "edge_index" or k == "y":