diff --git a/explaining_framework/metric/fidelity.py b/explaining_framework/metric/fidelity.py index c80cccb..96cf944 100644 --- a/explaining_framework/metric/fidelity.py +++ b/explaining_framework/metric/fidelity.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from torch.nn import KLDivLoss, Softmax from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg +from torch import Tensor from explaining_framework.metric.base import Metric