diff --git a/explaining_framework/metric/sparsity.py b/explaining_framework/metric/sparsity.py index ea0ee77..b3f7a61 100644 --- a/explaining_framework/metric/sparsity.py +++ b/explaining_framework/metric/sparsity.py @@ -1,24 +1,24 @@ import torch + from explaining_framework.metric.base import Metric class Sparsity(Metric): def __init__(self, name): super().__init__(name=name) - self.authorized_metric = ['l0'] + self.authorized_metric = ["l0"] self.metric = self.load_metric(name) - def load_metric(self,name): + def load_metric(self, name): if name in self.authorized_metric: - if name == 'l0': - metric = lambda x : torch.mean(mask.float()).item() + if name == "l0": + metric = lambda x: torch.mean(mask.float()).item() else: - raise ValueError(f'{name} is not supported yet') + raise ValueError(f"{name} is not supported yet") - - def forward(self, exp:Explanation) -> float: + def forward(self, exp: Explanation) -> float: out = {} - for k,v in exp.to_dict(): - if 'mask' in - - return torch.mean(mask.float()).item() + for k, v in exp.to_dict(): + if "mask" in k and v.dtype == torch.bool: + out[k] = torch.mean(mask.float()).item() + return out