# explaining_framework/explaining_framework/metric/fidelity.py
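"""Fidelity metrics for graph explanations.

All scores are computed from three model predictions: on the original graph,
on the explanation subgraph, and on its complement. ``fidelity_plus`` measures
how much the prediction degrades once the explanation is removed,
``fidelity_minus`` how much it degrades when only the explanation is kept, and
the characterization scores combine both into a single value.
"""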
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import KLDivLoss, Softmax
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg

from explaining_framework.metric.base import Metric

# Number of output classes, taken from the GraphGym config.
NUM_CLASS = cfg.share.dim_out


def softmax(data: Tensor) -> Tensor:
    return Softmax(dim=1)(data)


def kl(data1: Tensor, data2: Tensor) -> Tensor:
    kld = KLDivLoss(reduction="batchmean")
    return kld(data1, data2)
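
# Note on argument order: KLDivLoss takes log-probabilities as its input and
# probabilities as its target, so kl(F.log_softmax(a, dim=1), softmax(b))
# computes KL(softmax(b) || softmax(a)), averaged over the batch.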

class Fidelity(Metric):
    def __init__(self, name: str, model: torch.nn.Module):
        super().__init__(name=name, model=model)
        self.authorized_metric = [
            "fidelity_plus",
            "fidelity_minus",
            "fidelity_plus_prob",
            "fidelity_minus_prob",
            "infidelity_KL",
            "characterization",
            "characterization_prob",
        ]
        # Cached artifacts of the last call to `score`.
        self.exp_sub = None
        self.exp_sub_c = None
        self.s_exp_sub = None
        self.s_exp_sub_c = None
        self.s_initial_data = None
        self.metric = self.load_metric(name)
    def _score_check(self, exp: Explanation) -> None:
        # Lazily compute the cached predictions if `score` has not been
        # called yet for this explanation.
        if any(
            attr is None
            for attr in [
                self.exp_sub,
                self.exp_sub_c,
                self.s_exp_sub,
                self.s_exp_sub_c,
                self.s_initial_data,
            ]
        ):
            self.score(exp)
    def _fidelity_plus(self, exp: Explanation) -> float:
        # Accuracy drop when the model only sees the complement of the
        # explanation: mean of 1[y == y_hat(G)] - 1[y == y_hat(G \ S)].
        self._score_check(exp)
        inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
        inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
        return torch.mean(
            (exp.y == inferred_class_initial).float()
            - (exp.y == inferred_class_exp).float()
        ).item()

    def _fidelity_minus(self, exp: Explanation) -> float:
        # Accuracy drop when the model only sees the explanation subgraph:
        # mean of 1[y == y_hat(G)] - 1[y == y_hat(S)].
        self._score_check(exp)
        inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
        inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
        return torch.mean(
            (exp.y == inferred_class_initial).float()
            - (exp.y == inferred_class_exp).float()
        ).item()
    def _fidelity_plus_prob(self, exp: Explanation) -> float:
        # Probability-based variant of fidelity+: per-sample change of the
        # ground-truth class probability on the complement subgraph, averaged
        # over samples (element-wise, mirroring the discrete version above).
        self._score_check(exp)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = softmax(self.s_exp_sub_c)
        size = int(exp.y.size(0))
        prob_initial = prob_initial[torch.arange(size), exp.y]
        prob_exp = prob_exp[torch.arange(size), exp.y]
        return torch.mean(
            torch.abs(1 - prob_initial) - torch.abs(1 - prob_exp)
        ).item()

    def _fidelity_minus_prob(self, exp: Explanation) -> float:
        # Probability-based variant of fidelity-: same as above, but on the
        # explanation subgraph instead of its complement.
        self._score_check(exp)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = softmax(self.s_exp_sub)
        size = int(exp.y.size(0))
        prob_initial = prob_initial[torch.arange(size), exp.y]
        prob_exp = prob_exp[torch.arange(size), exp.y]
        return torch.mean(
            torch.abs(1 - prob_initial) - torch.abs(1 - prob_exp)
        ).item()
    def _infidelity_KL(self, exp: Explanation) -> float:
        # KL divergence between the prediction distribution on the original
        # graph and the one on the explanation subgraph, mapped into [0, 1)
        # via 1 - exp(-KL). The log-probabilities go first, as KLDivLoss
        # requires.
        self._score_check(exp)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
        return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
    def _characterization_prob(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Optional[float]:
        # Weighted harmonic mean of fidelity+ and (1 - fidelity-). Returns
        # None when the score is undefined (degenerate fidelity values).
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                "The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus_prob(exp)
        neg_fidelity = self._fidelity_minus_prob(exp)
        if pos_fidelity in (0, 1) or neg_fidelity in (0, 1):
            return None
        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        if denom == 0:
            return None
        return 1.0 / denom
    def _characterization(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Optional[float]:
        # Same harmonic-mean combination as above, using the discrete
        # (accuracy-based) fidelity scores.
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                "The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus(exp)
        neg_fidelity = self._fidelity_minus(exp)
        if pos_fidelity in (0, 1) or neg_fidelity in (0, 1):
            return None
        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        if denom == 0:
            return None
        return 1.0 / denom
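
    # Worked example for the characterization scores: with the default
    # weights (0.5, 0.5), pos_fidelity = 0.8 and neg_fidelity = 0.2 give
    # 1 / (0.5 / 0.8 + 0.5 / (1 - 0.2)) = 1 / 1.25 = 0.8.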
    def score(self, exp):
        # Cache predictions on the explanation subgraph, its complement, and
        # the original input; all metric methods read from this cache.
        self.exp_sub = exp.get_explanation_subgraph()
        self.exp_sub_c = exp.get_complement_subgraph()
        self.s_exp_sub = self.get_prediction(
            x=self.exp_sub.x, edge_index=self.exp_sub.edge_index
        )
        self.s_exp_sub_c = self.get_prediction(
            x=self.exp_sub_c.x, edge_index=self.exp_sub_c.edge_index
        )
        self.s_initial_data = self.get_prediction(x=exp.x, edge_index=exp.edge_index)
    def load_metric(self, name):
        # Bind the requested metric implementation; an elif chain ensures the
        # unsupported-name error is only raised when no branch matches.
        if name not in self.authorized_metric:
            raise ValueError(f"{name} is not supported")
        if name == "fidelity_plus":
            self.metric = self._fidelity_plus
        elif name == "fidelity_minus":
            self.metric = self._fidelity_minus
        elif name == "fidelity_plus_prob":
            self.metric = self._fidelity_plus_prob
        elif name == "fidelity_minus_prob":
            self.metric = self._fidelity_minus_prob
        elif name == "infidelity_KL":
            self.metric = self._infidelity_KL
        elif name == "characterization":
            self.metric = self._characterization
        elif name == "characterization_prob":
            self.metric = self._characterization_prob
        return self.metric
    def forward(self, exp: Explanation):
        self.score(exp)
        return self.metric(exp)
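
# Minimal usage sketch (an illustration, not part of the module): assumes a
# trained GNN `model` and a PyG `Explanation` carrying the graph tensors and
# the learned masks; the variable names below are hypothetical.
#
#     explanation = Explanation(
#         x=data.x,
#         edge_index=data.edge_index,
#         y=data.y,
#         node_mask=node_mask,
#         edge_mask=edge_mask,
#     )
#     fid = Fidelity(name="fidelity_plus", model=model)
#     score = fid.forward(explanation)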