25 lines
810 B
Python
25 lines
810 B
Python
import torch
|
|
from torch_geometric.explain.explanation import Explanation
|
|
|
|
from explaining_framework.metric.base import Metric
|
|
|
|
|
|
class Sparsity(Metric):
    """Sparsity metric computed over the masks of an ``Explanation``.

    Only the ``"l0"`` variant is supported: the mean of a mask's values,
    which for a binary (0/1) mask is the fraction of entries equal to 1.
    """

    def __init__(self, name):
        """Create the metric and resolve the callable registered for ``name``.

        Args:
            name: Identifier of the sparsity variant; must be listed in
                ``self.authorized_metric``.

        Raises:
            ValueError: If ``name`` is not a supported variant.
        """
        super().__init__(name=name)
        self.authorized_metric = ["l0"]
        self.metric = self.load_metric(name)

    @staticmethod
    def _l0(mask):
        """Return the mean of ``mask`` as a Python float."""
        # Cast first: torch.mean rejects bool and integer tensors, which is
        # what a hard 0/1 mask is usually stored as.
        return torch.mean(mask.float()).item()

    def load_metric(self, name):
        """Return the callable implementing the sparsity variant ``name``.

        Args:
            name: Identifier of the sparsity variant.

        Returns:
            A callable mapping a mask tensor to a float.

        Raises:
            ValueError: If ``name`` is not in ``self.authorized_metric`` or
                has no implementation.
        """
        if name not in self.authorized_metric:
            raise ValueError(f"{name} is not supported yet")
        if name == "l0":
            return self._l0
        # Authorized but not implemented: fail loudly instead of leaving
        # ``self.metric`` set to None.
        raise ValueError(f"{name} is not supported yet")

    def forward(self, exp: Explanation) -> dict:
        """Compute the sparsity of every binary mask stored in ``exp``.

        Args:
            exp: Explanation whose ``to_dict()`` entries are scanned.

        Returns:
            A dict mapping each key that contains ``"mask"`` and holds a
            tensor made only of 0s and 1s to its sparsity value (a float).
            Non-tensor entries and non-binary masks are skipped.
        """
        out = {}
        for k, v in exp.to_dict().items():
            # Skip non-mask entries and masks that are unset / not tensors,
            # which the element-wise comparisons below cannot handle.
            if "mask" not in k or not isinstance(v, torch.Tensor):
                continue
            # Only hard (0/1) masks are scored.
            if torch.all(torch.logical_or(v == 0, v == 1)).item():
                out[k] = self.metric(v)
        return out
|