Fixing bugs and adding new features

This commit is contained in:
araison 2022-12-30 19:50:42 +01:00
parent 1ab84c807c
commit 95a30f2c23

View File

@ -1,24 +1,24 @@
import torch
from explaining_framework.metric.base import Metric
class Sparsity(Metric):
def __init__(self, name):
super().__init__(name=name)
self.authorized_metric = ['l0']
self.authorized_metric = ["l0"]
self.metric = self.load_metric(name)
def load_metric(self,name):
def load_metric(self, name):
if name in self.authorized_metric:
if name == 'l0':
metric = lambda x : torch.mean(mask.float()).item()
if name == "l0":
metric = lambda x: torch.mean(mask.float()).item()
else:
raise ValueError(f'{name} is not supported yet')
raise ValueError(f"{name} is not supported yet")
def forward(self, exp:Explanation) -> float:
def forward(self, exp: Explanation) -> float:
out = {}
for k,v in exp.to_dict():
if 'mask' in
return torch.mean(mask.float()).item()
for k, v in exp.to_dict():
if "mask" in k and v.dtype == torch.bool:
out[k] = torch.mean(mask.float()).item()
return out