New fixes and new features
This commit is contained in:
parent
45730328a7
commit
eca200fe88
@ -1,8 +1,16 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import KLDivLoss, Softmax
|
||||
|
||||
from base import Metric
|
||||
|
||||
print("NUM CLASSES cfg dataset")
|
||||
NUM_CLASS = 5
|
||||
|
||||
|
||||
def softmax(data):
|
||||
return Softmax(dim=1)(data)
|
||||
|
||||
|
||||
class Fidelity(Metric):
|
||||
def __init__(name: str, model: torch.nn.Module, mask_type: str):
|
||||
@ -12,10 +20,6 @@ class Fidelity(Metric):
|
||||
"fidelity_minus",
|
||||
"fidelity_plus_prob",
|
||||
"fidelity_minus_prob",
|
||||
"fidelity_plus_model",
|
||||
"fidelity_minus_model",
|
||||
"fidelity_plus_prob_model",
|
||||
"fidelity_minus_prob_model",
|
||||
"infidelity_KL",
|
||||
]
|
||||
|
||||
@ -26,7 +30,7 @@ class Fidelity(Metric):
|
||||
self.s_exp_sub_c = None
|
||||
self.s_initial_data = None
|
||||
|
||||
def _fidelity_plus(self, exp) -> float:
|
||||
def _score_check(self):
|
||||
if any(
|
||||
[
|
||||
attr is None
|
||||
@ -40,8 +44,51 @@ class Fidelity(Metric):
|
||||
]
|
||||
):
|
||||
self.score(exp)
|
||||
else:
|
||||
fid = self.s_initial_data - self.s_exp_sub_c
|
||||
|
||||
def _fidelity_plus(self, exp: Explanation) -> float:
|
||||
self._score_check()
|
||||
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||
inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
|
||||
return torch.mean(
|
||||
(exp.y == inferred_class_initial).long()
|
||||
- (exp.y == inferred_class_exp).long()
|
||||
).item()
|
||||
|
||||
def _fidelity_minus(self, exp: Explanation) -> float:
|
||||
self._score_check()
|
||||
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||
inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
|
||||
return torch.mean(
|
||||
(exp.y == inferred_class_initial).long()
|
||||
- (exp.y == inferred_class_exp).long()
|
||||
).item()
|
||||
|
||||
def _fidelity_plus_prob(self, exp: Explanation) -> float:
|
||||
self._score_check()
|
||||
# one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS)
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = softmax(self.s_exp_sub_c)
|
||||
|
||||
size = int(exp.y.size(0))
|
||||
prob_initial = prob_initial[torch.arange(size), exp.y]
|
||||
prob_exp = prob_exp[torch.arange(size), exp.y]
|
||||
|
||||
return torch.mean(
|
||||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||
).item()
|
||||
|
||||
def _fidelity_minus_prob(self, exp: Explanation) -> float:
|
||||
self._score_check()
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = softmax(self.s_exp_sub)
|
||||
|
||||
size = int(exp.y.size(0))
|
||||
prob_initial = prob_initial[torch.arange(size), exp.y]
|
||||
prob_exp = prob_exp[torch.arange(size), exp.y]
|
||||
|
||||
return torch.mean(
|
||||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||
).item()
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
@ -53,14 +100,16 @@ class Fidelity(Metric):
|
||||
def load_metric(name):
|
||||
if name in self.authorized_metric:
|
||||
if name == "fidelity_plus":
|
||||
|
||||
self.metric = eval("sklearn.metric.{name}")
|
||||
self.metric = lambda exp: self._fidelity_plus(exp)
|
||||
if name == "fidelity_minus":
|
||||
self.metric = lambda exp: self._fidelity_minus(exp)
|
||||
if name == "fidelity_plus_prob":
|
||||
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||
if name == "fidelity_minus_prob":
|
||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
|
||||
def compute(self, mask, target: Tensor) -> float:
|
||||
if mask.type() == torch.bool and target.type() == torch.bool:
|
||||
return self.metric(y_pred=mask, y_true=target)
|
||||
return self.metric
|
||||
|
||||
def __call__(self, exp: Explanation):
|
||||
pass
|
||||
return self.metric(exp)
|
||||
|
8
explaining_framework/metric/sparsity.py
Normal file
8
explaining_framework/metric/sparsity.py
Normal file
@ -0,0 +1,8 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Sparsity(Metric):
|
||||
def __init__(self,name):
|
||||
super().__init__(name=name,model=None)
|
||||
|
||||
def
|
Loading…
Reference in New Issue
Block a user