New fixes and new features

This commit is contained in:
araison 2022-12-16 13:52:24 +01:00
parent 45730328a7
commit eca200fe88
2 changed files with 71 additions and 14 deletions

View File

@ -1,8 +1,16 @@
import torch
import torch.nn.functional as F
from torch.nn import KLDivLoss, Softmax
from base import Metric
print("NUM CLASSES cfg dataset")
NUM_CLASS = 5
def softmax(data):
return Softmax(dim=1)(data)
class Fidelity(Metric):
def __init__(name: str, model: torch.nn.Module, mask_type: str):
@ -12,10 +20,6 @@ class Fidelity(Metric):
"fidelity_minus",
"fidelity_plus_prob",
"fidelity_minus_prob",
"fidelity_plus_model",
"fidelity_minus_model",
"fidelity_plus_prob_model",
"fidelity_minus_prob_model",
"infidelity_KL",
]
@ -26,7 +30,7 @@ class Fidelity(Metric):
self.s_exp_sub_c = None
self.s_initial_data = None
def _fidelity_plus(self, exp) -> float:
def _score_check(self):
if any(
[
attr is None
@ -40,8 +44,51 @@ class Fidelity(Metric):
]
):
self.score(exp)
else:
fid = self.s_initial_data - self.s_exp_sub_c
def _fidelity_plus(self, exp: Explanation) -> float:
self._score_check()
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
return torch.mean(
(exp.y == inferred_class_initial).long()
- (exp.y == inferred_class_exp).long()
).item()
def _fidelity_minus(self, exp: Explanation) -> float:
self._score_check()
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
return torch.mean(
(exp.y == inferred_class_initial).long()
- (exp.y == inferred_class_exp).long()
).item()
def _fidelity_plus_prob(self, exp: Explanation) -> float:
self._score_check()
# one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS)
prob_initial = softmax(self.s_initial_data)
prob_exp = softmax(self.s_exp_sub_c)
size = int(exp.y.size(0))
prob_initial = prob_initial[torch.arange(size), exp.y]
prob_exp = prob_exp[torch.arange(size), exp.y]
return torch.mean(
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
).item()
def _fidelity_minus_prob(self, exp: Explanation) -> float:
self._score_check()
prob_initial = softmax(self.s_initial_data)
prob_exp = softmax(self.s_exp_sub)
size = int(exp.y.size(0))
prob_initial = prob_initial[torch.arange(size), exp.y]
prob_exp = prob_exp[torch.arange(size), exp.y]
return torch.mean(
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
).item()
def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph()
@ -53,14 +100,16 @@ class Fidelity(Metric):
def load_metric(name):
if name in self.authorized_metric:
if name == "fidelity_plus":
self.metric = eval("sklearn.metric.{name}")
self.metric = lambda exp: self._fidelity_plus(exp)
if name == "fidelity_minus":
self.metric = lambda exp: self._fidelity_minus(exp)
if name == "fidelity_plus_prob":
self.metric = lambda exp: self._fidelity_plus_prob(exp)
if name == "fidelity_minus_prob":
self.metric = lambda exp: self._fidelity_minus_prob(exp)
else:
raise ValueError(f"{name} is not supported")
def compute(self, mask, target: Tensor) -> float:
if mask.type() == torch.bool and target.type() == torch.bool:
return self.metric(y_pred=mask, y_true=target)
return self.metric
def __call__(self, exp: Explanation):
pass
return self.metric(exp)

View File

@ -0,0 +1,8 @@
import torch
class Sparsity(Metric):
def __init__(self,name):
super().__init__(name=name,model=None)
def