New fixes and new features
This commit is contained in:
parent
45730328a7
commit
eca200fe88
2 changed files with 71 additions and 14 deletions
|
@ -1,8 +1,16 @@
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.nn import KLDivLoss, Softmax
|
from torch.nn import KLDivLoss, Softmax
|
||||||
|
|
||||||
from base import Metric
|
from base import Metric
|
||||||
|
|
||||||
|
print("NUM CLASSES cfg dataset")
|
||||||
|
NUM_CLASS = 5
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(data):
|
||||||
|
return Softmax(dim=1)(data)
|
||||||
|
|
||||||
|
|
||||||
class Fidelity(Metric):
|
class Fidelity(Metric):
|
||||||
def __init__(name: str, model: torch.nn.Module, mask_type: str):
|
def __init__(name: str, model: torch.nn.Module, mask_type: str):
|
||||||
|
@ -12,10 +20,6 @@ class Fidelity(Metric):
|
||||||
"fidelity_minus",
|
"fidelity_minus",
|
||||||
"fidelity_plus_prob",
|
"fidelity_plus_prob",
|
||||||
"fidelity_minus_prob",
|
"fidelity_minus_prob",
|
||||||
"fidelity_plus_model",
|
|
||||||
"fidelity_minus_model",
|
|
||||||
"fidelity_plus_prob_model",
|
|
||||||
"fidelity_minus_prob_model",
|
|
||||||
"infidelity_KL",
|
"infidelity_KL",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -26,7 +30,7 @@ class Fidelity(Metric):
|
||||||
self.s_exp_sub_c = None
|
self.s_exp_sub_c = None
|
||||||
self.s_initial_data = None
|
self.s_initial_data = None
|
||||||
|
|
||||||
def _fidelity_plus(self, exp) -> float:
|
def _score_check(self):
|
||||||
if any(
|
if any(
|
||||||
[
|
[
|
||||||
attr is None
|
attr is None
|
||||||
|
@ -40,8 +44,51 @@ class Fidelity(Metric):
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.score(exp)
|
self.score(exp)
|
||||||
else:
|
|
||||||
fid = self.s_initial_data - self.s_exp_sub_c
|
def _fidelity_plus(self, exp: Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||||
|
inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
|
||||||
|
return torch.mean(
|
||||||
|
(exp.y == inferred_class_initial).long()
|
||||||
|
- (exp.y == inferred_class_exp).long()
|
||||||
|
).item()
|
||||||
|
|
||||||
|
def _fidelity_minus(self, exp: Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
|
||||||
|
inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
|
||||||
|
return torch.mean(
|
||||||
|
(exp.y == inferred_class_initial).long()
|
||||||
|
- (exp.y == inferred_class_exp).long()
|
||||||
|
).item()
|
||||||
|
|
||||||
|
def _fidelity_plus_prob(self, exp: Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
# one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS)
|
||||||
|
prob_initial = softmax(self.s_initial_data)
|
||||||
|
prob_exp = softmax(self.s_exp_sub_c)
|
||||||
|
|
||||||
|
size = int(exp.y.size(0))
|
||||||
|
prob_initial = prob_initial[torch.arange(size), exp.y]
|
||||||
|
prob_exp = prob_exp[torch.arange(size), exp.y]
|
||||||
|
|
||||||
|
return torch.mean(
|
||||||
|
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||||
|
).item()
|
||||||
|
|
||||||
|
def _fidelity_minus_prob(self, exp: Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
prob_initial = softmax(self.s_initial_data)
|
||||||
|
prob_exp = softmax(self.s_exp_sub)
|
||||||
|
|
||||||
|
size = int(exp.y.size(0))
|
||||||
|
prob_initial = prob_initial[torch.arange(size), exp.y]
|
||||||
|
prob_exp = prob_exp[torch.arange(size), exp.y]
|
||||||
|
|
||||||
|
return torch.mean(
|
||||||
|
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||||
|
).item()
|
||||||
|
|
||||||
def score(self, exp):
|
def score(self, exp):
|
||||||
self.exp_sub = exp.get_explanation_subgraph()
|
self.exp_sub = exp.get_explanation_subgraph()
|
||||||
|
@ -53,14 +100,16 @@ class Fidelity(Metric):
|
||||||
def load_metric(name):
|
def load_metric(name):
|
||||||
if name in self.authorized_metric:
|
if name in self.authorized_metric:
|
||||||
if name == "fidelity_plus":
|
if name == "fidelity_plus":
|
||||||
|
self.metric = lambda exp: self._fidelity_plus(exp)
|
||||||
self.metric = eval("sklearn.metric.{name}")
|
if name == "fidelity_minus":
|
||||||
|
self.metric = lambda exp: self._fidelity_minus(exp)
|
||||||
|
if name == "fidelity_plus_prob":
|
||||||
|
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||||
|
if name == "fidelity_minus_prob":
|
||||||
|
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{name} is not supported")
|
raise ValueError(f"{name} is not supported")
|
||||||
|
return self.metric
|
||||||
def compute(self, mask, target: Tensor) -> float:
|
|
||||||
if mask.type() == torch.bool and target.type() == torch.bool:
|
|
||||||
return self.metric(y_pred=mask, y_true=target)
|
|
||||||
|
|
||||||
def __call__(self, exp: Explanation):
|
def __call__(self, exp: Explanation):
|
||||||
pass
|
return self.metric(exp)
|
||||||
|
|
8
explaining_framework/metric/sparsity.py
Normal file
8
explaining_framework/metric/sparsity.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Sparsity(Metric):
|
||||||
|
def __init__(self,name):
|
||||||
|
super().__init__(name=name,model=None)
|
||||||
|
|
||||||
|
def
|
Loading…
Add table
Reference in a new issue