52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
import torch
|
|
from torch_geometric.explain.explanation import Explanation
|
|
|
|
from explaining_framework.utils.io import write_json
|
|
|
|
|
|
class Metric(ABC):
    """Abstract base class for metrics evaluated on explanations.

    Concrete subclasses must implement :meth:`load_metric` and
    :meth:`forward`. Metrics whose name contains ``"fidelity"`` are treated
    as model-dependent and cannot be built without a model.
    """

    def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
        """Store the metric name and the (optional) model.

        Args:
            name: Identifier of the metric. It also decides whether a model
                is mandatory (see :meth:`is_model_needed`).
            model: Model used for measurements; may be ``None`` for metrics
                that do not need one.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If the metric needs a model and ``model`` is ``None``.
        """
        self.name = name
        self.model = model
        if self.is_model_needed() and model is None:
            raise ValueError(f"{self.name} needs model to perform measurements")
        # Initialised empty here; presumably filled in by subclasses
        # (NOTE(review): confirm against the concrete metrics).
        self.authorized_metric = None

    def is_model_needed(self):
        """Return ``True`` when the metric name contains ``"fidelity"``."""
        return "fidelity" in self.name

    @abstractmethod
    def load_metric(self, name: str):
        """Set up the metric identified by ``name`` (subclass hook)."""

    @abstractmethod
    def forward(self, exp: "Explanation"):
        """Evaluate the metric on the explanation ``exp`` (subclass hook)."""

    def get_prediction(self, *args, **kwargs) -> torch.Tensor:
        r"""Returns the output of the model on the given input.

        The model is temporarily switched to evaluation mode and called
        with gradient tracking disabled. Whatever the model returns is
        handed back unchanged; no post-processing (e.g. an arg-max for
        classification) is applied here.

        The previous training/evaluation mode of the model is restored
        afterwards, even if the forward pass raises.

        Args:
            *args: Arguments passed to the model.
            **kwargs (optional): Additional keyword arguments passed to the
                model.

        Raises:
            AttributeError: If the metric was created without a model
                (``self.model`` is ``None``).
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                out = self.model(*args, **kwargs)
        finally:
            # Always put the model back in the mode the caller left it in,
            # including when the forward pass fails.
            self.model.train(was_training)
        return out
|
|
|
|
|