explaining_framework/explaining_framework/metric/sparsity.py
2023-01-08 20:12:38 +01:00

25 lines
810 B
Python

import torch
from torch_geometric.explain.explanation import Explanation
from explaining_framework.metric.base import Metric
class Sparsity(Metric):
    """Sparsity metric evaluated on the binary masks of an explanation.

    The only supported metric name is ``"l0"``: the mean of the mask after
    casting it to float. For a mask that holds only 0s and 1s this equals
    the fraction of entries set to 1.
    """

    def __init__(self, name):
        """Build the metric identified by ``name``.

        Args:
            name: Metric identifier; must be listed in
                ``self.authorized_metric``.

        Raises:
            ValueError: If ``name`` is not a supported metric.
        """
        super().__init__(name=name)
        self.authorized_metric = ["l0"]
        self.metric = self.load_metric(name)

    def load_metric(self, name):
        """Return the callable implementing the metric ``name``.

        Args:
            name: Metric identifier.

        Returns:
            A callable mapping a mask tensor to a Python float.

        Raises:
            ValueError: If ``name`` is not a supported metric.
        """
        if name not in self.authorized_metric:
            raise ValueError(f"{name} is not supported yet")
        if name == "l0":
            # Cast to float first: torch.mean rejects bool/integer tensors.
            return lambda x: torch.mean(x.float()).item()
        # Reached only if a name is authorized but has no implementation yet.
        raise ValueError(f"{name} is not supported yet")

    def forward(self, exp: Explanation) -> dict:
        """Compute the metric for every binary mask stored in ``exp``.

        Args:
            exp: Explanation whose ``to_dict()`` entries are scanned.

        Returns:
            A dict mapping each key that contains ``"mask"`` and whose tensor
            holds only 0s and 1s to the metric value of that tensor. Entries
            that are not tensors, or are not binary, are left out.
        """
        out = {}
        for k, v in exp.to_dict().items():
            # Non-tensor entries (e.g. None) cannot be tested element-wise.
            if "mask" not in k or not isinstance(v, torch.Tensor):
                continue
            # Only hard (0/1) masks are scored; soft masks are skipped.
            if torch.all(torch.logical_or(v == 0, v == 1)).item():
                out[k] = self.metric(v)
        return out