Fixing bugs and adding new metric
This commit is contained in:
parent
20db23d307
commit
e9ef1cca9a
@ -10,6 +10,8 @@ NUM_CLASS = 5
|
||||
|
||||
def softmax(data):
|
||||
return Softmax(dim=1)(data)
|
||||
def kl(data1,data2):
|
||||
return KLDivLoss(dim=1)(data1,data2)
|
||||
|
||||
|
||||
class Fidelity(Metric):
|
||||
@ -23,12 +25,12 @@ class Fidelity(Metric):
|
||||
"infidelity_KL",
|
||||
]
|
||||
|
||||
self.metric = self.load_metric(name)
|
||||
self.exp_sub = None
|
||||
self.exp_sub_c = None
|
||||
self.s_exp_sub = None
|
||||
self.s_exp_sub_c = None
|
||||
self.s_initial_data = None
|
||||
self.metric = self.load_metric(name)
|
||||
|
||||
def _score_check(self):
|
||||
if any(
|
||||
@ -90,6 +92,14 @@ class Fidelity(Metric):
|
||||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||
).item()
|
||||
|
||||
def _infidelity_KL(self, exp:Explanation) -> float:
|
||||
self._score_check()
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = softmax(self.s_exp_sub)
|
||||
return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
|
||||
|
||||
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
self.exp_sub_c = exp.get_complement_subgraph()
|
||||
@ -107,9 +117,12 @@ class Fidelity(Metric):
|
||||
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||
if name == "fidelity_minus_prob":
|
||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||
if name == "_infidelity_KL":
|
||||
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
return self.metric
|
||||
|
||||
def __call__(self, exp: Explanation):
|
||||
self.score(exp)
|
||||
return self.metric(exp)
|
||||
|
Loading…
Reference in New Issue
Block a user