Fixing bug

This commit is contained in:
araison 2023-02-12 14:51:15 +01:00
parent 408bab4bc4
commit 3cf4187164
1 changed files with 2 additions and 3 deletions

View File

@ -4,12 +4,11 @@ import os
import numpy as np import numpy as np
import torch import torch
from explaining_framework.utils.io import read_json, write_json
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.config import cfg
from explaining_framework.utils.io import read_json, write_json
def _get_explanation(explainer, item): def _get_explanation(explainer, item):
explanation = explainer( explanation = explainer(
@ -69,7 +68,7 @@ def _save_explanation(exp: Explanation, path: str) -> None:
def _load_explanation(path: str) -> Explanation: def _load_explanation(path: str) -> Explanation:
data = read_json(data, path) data = read_json(path)
for k, v in data.items(): for k, v in data.items():
if isinstance(v, list): if isinstance(v, list):
if k == "edge_index" or k == "y": if k == "edge_index" or k == "y":