Fixing bugs and adding new metric

This commit is contained in:
araison 2022-12-17 17:30:27 +01:00
parent 20db23d307
commit e9ef1cca9a
1 changed files with 14 additions and 1 deletions

View File

@ -10,6 +10,8 @@ NUM_CLASS = 5
def softmax(data): def softmax(data):
return Softmax(dim=1)(data) return Softmax(dim=1)(data)
def kl(data1,data2):
return KLDivLoss(dim=1)(data1,data2)
class Fidelity(Metric): class Fidelity(Metric):
@ -23,12 +25,12 @@ class Fidelity(Metric):
"infidelity_KL", "infidelity_KL",
] ]
self.metric = self.load_metric(name)
self.exp_sub = None self.exp_sub = None
self.exp_sub_c = None self.exp_sub_c = None
self.s_exp_sub = None self.s_exp_sub = None
self.s_exp_sub_c = None self.s_exp_sub_c = None
self.s_initial_data = None self.s_initial_data = None
self.metric = self.load_metric(name)
def _score_check(self): def _score_check(self):
if any( if any(
@ -90,6 +92,14 @@ class Fidelity(Metric):
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1) torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
).item() ).item()
def _infidelity_KL(self, exp:Explanation) -> float:
self._score_check()
prob_initial = softmax(self.s_initial_data)
prob_exp = softmax(self.s_exp_sub)
return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
def score(self, exp): def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph() self.exp_sub = exp.get_explanation_subgraph()
self.exp_sub_c = exp.get_complement_subgraph() self.exp_sub_c = exp.get_complement_subgraph()
@ -107,9 +117,12 @@ class Fidelity(Metric):
self.metric = lambda exp: self._fidelity_plus_prob(exp) self.metric = lambda exp: self._fidelity_plus_prob(exp)
if name == "fidelity_minus_prob": if name == "fidelity_minus_prob":
self.metric = lambda exp: self._fidelity_minus_prob(exp) self.metric = lambda exp: self._fidelity_minus_prob(exp)
if name == "_infidelity_KL":
self.metric = lambda exp: self._infidelity_KL(exp)
else: else:
raise ValueError(f"{name} is not supported") raise ValueError(f"{name} is not supported")
return self.metric return self.metric
def __call__(self, exp: Explanation): def __call__(self, exp: Explanation):
self.score(exp)
return self.metric(exp) return self.metric(exp)