from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch_geometric.explain.explanation import Explanation

from explaining_framework.utils.io import write_json


class Metric(ABC):
    """Abstract base class for explanation metrics.

    Subclasses implement :meth:`load_metric` and :meth:`forward`. A metric
    whose ``name`` contains ``"fidelity"`` requires a model, which
    :meth:`get_prediction` queries.

    Args:
        name: Name of the metric. Also used by :meth:`is_model_needed` to
            decide whether a model must be supplied.
        model: Model queried by :meth:`get_prediction`. Mandatory when
            :meth:`is_model_needed` returns ``True``.
        **kwargs: Accepted so subclasses can forward extra options; ignored
            by this base class.

    Raises:
        ValueError: If the metric needs a model and ``model`` is ``None``.
    """

    def __init__(
        self, name: str, model: Optional[torch.nn.Module] = None, **kwargs
    ):
        self.name = name
        self.model = model
        if self.is_model_needed() and model is None:
            raise ValueError(f"{self.name} needs model to perform measurements")
        # Initialised to None here; presumably populated by subclasses
        # (e.g. in load_metric) -- NOTE(review): confirm intended use.
        self.authorized_metric = None

    def is_model_needed(self) -> bool:
        """Return ``True`` if this metric needs a model.

        Currently a model is needed exactly when ``"fidelity"`` appears in
        the metric name.
        """
        return "fidelity" in self.name

    @abstractmethod
    def load_metric(self, name: str):
        """Subclass hook that sets up the metric identified by ``name``."""

    @abstractmethod
    def forward(self, exp: Explanation):
        """Subclass hook that evaluates the metric on explanation ``exp``."""

    def get_prediction(self, *args, **kwargs) -> torch.Tensor:
        r"""Return the raw output of ``self.model`` on the given inputs.

        The model is temporarily switched to evaluation mode and run under
        :func:`torch.no_grad`. Its previous training/eval mode is restored
        afterwards, even if the forward pass raises.

        No post-processing is applied: the value returned is exactly what
        the model returns. For example, no ``argmax`` is taken for
        classification.

        Args:
            *args: Positional arguments passed to the model.
            **kwargs (optional): Keyword arguments passed to the model.

        Returns:
            The model's output for the given inputs.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(*args, **kwargs)
        finally:
            # Restore the caller's mode even when the forward pass fails, so
            # a failed measurement cannot leave the model stuck in eval mode.
            self.model.train(was_training)