Fixing bugs and adding new metric
This commit is contained in:
parent
20db23d307
commit
e9ef1cca9a
|
@ -10,6 +10,8 @@ NUM_CLASS = 5
|
||||||
|
|
||||||
def softmax(data):
|
def softmax(data):
|
||||||
return Softmax(dim=1)(data)
|
return Softmax(dim=1)(data)
|
||||||
|
def kl(data1,data2):
|
||||||
|
return KLDivLoss(dim=1)(data1,data2)
|
||||||
|
|
||||||
|
|
||||||
class Fidelity(Metric):
|
class Fidelity(Metric):
|
||||||
|
@ -23,12 +25,12 @@ class Fidelity(Metric):
|
||||||
"infidelity_KL",
|
"infidelity_KL",
|
||||||
]
|
]
|
||||||
|
|
||||||
self.metric = self.load_metric(name)
|
|
||||||
self.exp_sub = None
|
self.exp_sub = None
|
||||||
self.exp_sub_c = None
|
self.exp_sub_c = None
|
||||||
self.s_exp_sub = None
|
self.s_exp_sub = None
|
||||||
self.s_exp_sub_c = None
|
self.s_exp_sub_c = None
|
||||||
self.s_initial_data = None
|
self.s_initial_data = None
|
||||||
|
self.metric = self.load_metric(name)
|
||||||
|
|
||||||
def _score_check(self):
|
def _score_check(self):
|
||||||
if any(
|
if any(
|
||||||
|
@ -90,6 +92,14 @@ class Fidelity(Metric):
|
||||||
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
|
||||||
).item()
|
).item()
|
||||||
|
|
||||||
|
def _infidelity_KL(self, exp:Explanation) -> float:
|
||||||
|
self._score_check()
|
||||||
|
prob_initial = softmax(self.s_initial_data)
|
||||||
|
prob_exp = softmax(self.s_exp_sub)
|
||||||
|
return torch.mean(1 - torch.exp(-kl(prob_exp,prob_initial))).item()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def score(self, exp):
|
def score(self, exp):
|
||||||
self.exp_sub = exp.get_explanation_subgraph()
|
self.exp_sub = exp.get_explanation_subgraph()
|
||||||
self.exp_sub_c = exp.get_complement_subgraph()
|
self.exp_sub_c = exp.get_complement_subgraph()
|
||||||
|
@ -107,9 +117,12 @@ class Fidelity(Metric):
|
||||||
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
self.metric = lambda exp: self._fidelity_plus_prob(exp)
|
||||||
if name == "fidelity_minus_prob":
|
if name == "fidelity_minus_prob":
|
||||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||||
|
if name == "_infidelity_KL":
|
||||||
|
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{name} is not supported")
|
raise ValueError(f"{name} is not supported")
|
||||||
return self.metric
|
return self.metric
|
||||||
|
|
||||||
def __call__(self, exp: Explanation):
|
def __call__(self, exp: Explanation):
|
||||||
|
self.score(exp)
|
||||||
return self.metric(exp)
|
return self.metric(exp)
|
||||||
|
|
Loading…
Reference in New Issue