# File: explaining_framework/explaining_framework/metric/base.py
# (52 lines, 1.4 KiB, Python)

from abc import ABC, abstractmethod
import torch
from torch_geometric.explain.explanation import Explanation
from explaining_framework.utils.io import write_json
class Metric(ABC):
    """Abstract base class for explanation metrics.

    A metric is identified by ``name``. Metrics whose name contains
    ``"fidelity"`` need a model to query, and constructing them without
    one raises :class:`ValueError`.

    Args:
        name: Identifier of the metric.
        model: Model used by metrics that must run predictions. Optional
            for metrics that do not need one.
        **kwargs: Accepted for subclass compatibility; ignored here.

    Raises:
        ValueError: If the metric needs a model (see
            :meth:`is_model_needed`) and ``model`` is ``None``.
    """

    def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
        self.name = name
        self.model = model
        if self.is_model_needed() and model is None:
            raise ValueError(f"{self.name} needs model to perform measurements")
        # NOTE(review): never set in this base class; presumably filled in by
        # subclasses (e.g. in load_metric) — confirm against implementations.
        self.authorized_metric = None

    def is_model_needed(self) -> bool:
        """Return True if this metric must query a model.

        Currently any metric whose name contains ``"fidelity"`` requires one.
        """
        return "fidelity" in self.name

    @abstractmethod
    def load_metric(self, name: str):
        """Subclass hook for setting up the metric identified by ``name``.

        Presumably resolves ``name`` to a concrete metric implementation;
        the exact contract is defined by subclasses.
        """

    @abstractmethod
    def forward(self, exp: "Explanation"):
        """Compute the metric for the explanation ``exp``.

        Must be implemented by subclasses; the return type is
        subclass-defined.
        """

    def get_prediction(self, *args, **kwargs) -> torch.Tensor:
        r"""Return the raw output of ``self.model`` on the given inputs.

        The model is temporarily switched to evaluation mode and called under
        :func:`torch.no_grad`, so the result carries no autograd history.
        The model's previous training/eval mode is restored afterwards, even
        if the forward pass raises.

        No post-processing is applied. The output is returned exactly as the
        model produced it; for example, it is *not* reduced to a class label.

        Args:
            *args: Positional arguments passed to the model.
            **kwargs: Keyword arguments passed to the model.

        Returns:
            Whatever ``self.model(*args, **kwargs)`` returns.

        Raises:
            ValueError: If this metric was constructed without a model.
        """
        if self.model is None:
            raise ValueError(f"{self.name} needs model to perform measurements")
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(*args, **kwargs)
        finally:
            # Restore the caller's mode so this helper leaves no lasting side
            # effect on the model, even when the forward pass fails.
            self.model.train(was_training)