Fixing bugs and adding new features

This commit is contained in:
araison 2022-12-30 19:50:42 +01:00
parent 1ab84c807c
commit 95a30f2c23
1 changed files with 11 additions and 11 deletions

View File

@ -1,24 +1,24 @@
import torch import torch
from explaining_framework.metric.base import Metric from explaining_framework.metric.base import Metric
class Sparsity(Metric): class Sparsity(Metric):
def __init__(self, name): def __init__(self, name):
super().__init__(name=name) super().__init__(name=name)
self.authorized_metric = ['l0'] self.authorized_metric = ["l0"]
self.metric = self.load_metric(name) self.metric = self.load_metric(name)
def load_metric(self,name): def load_metric(self, name):
if name in self.authorized_metric: if name in self.authorized_metric:
if name == 'l0': if name == "l0":
metric = lambda x : torch.mean(mask.float()).item() metric = lambda x: torch.mean(mask.float()).item()
else: else:
raise ValueError(f'{name} is not supported yet') raise ValueError(f"{name} is not supported yet")
def forward(self, exp: Explanation) -> float:
def forward(self, exp:Explanation) -> float:
out = {} out = {}
for k,v in exp.to_dict(): for k, v in exp.to_dict():
if 'mask' in if "mask" in k and v.dtype == torch.bool:
out[k] = torch.mean(mask.float()).item()
return torch.mean(mask.float()).item() return out