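"""Fidelity metrics for graph explanations.

Implements accuracy- and probability-based fidelity+/-, a KL-divergence
infidelity score, and characterization scores that combine both fidelity
components, on top of the torch_geometric ``Explanation`` API.
"""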
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import KLDivLoss, Softmax

from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg

from explaining_framework.metric.base import Metric

NUM_CLASS = cfg.share.dim_out


def softmax(data: Tensor) -> Tensor:
    return Softmax(dim=1)(data)


def kl(data1: Tensor, data2: Tensor) -> Tensor:
    # KLDivLoss expects the input (`data1`) in log-space and the target
    # (`data2`) as probabilities; `_infidelity_KL` passes log_softmax
    # output as `data1` accordingly.
    kld = KLDivLoss(reduction="batchmean")
    return kld(data1, data2)
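
# Fidelity conventions used by the metrics below: fidelity+ compares the
# prediction on the full graph with the prediction on the *complement* of
# the explanation (is the explanation necessary?), while fidelity- compares
# it with the prediction on the explanation subgraph itself (is the
# explanation sufficient?).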
class Fidelity(Metric):
    """Fidelity-style metrics comparing model predictions on the full
    graph, the explanation subgraph, and its complement."""

    def __init__(self, name: str, model: torch.nn.Module):
        super().__init__(name=name, model=model)
        self.authorized_metric = [
            "fidelity_plus",
            "fidelity_minus",
            "fidelity_plus_prob",
            "fidelity_minus_prob",
            "infidelity_KL",
            "characterization",
            "characterization_prob",
        ]

        # Caches filled in by `score`: the explanation subgraph, its
        # complement, and the model scores on each plus the full input.
        self.exp_sub = None
        self.exp_sub_c = None
        self.s_exp_sub = None
        self.s_exp_sub_c = None
        self.s_initial_data = None
        self.metric = self.load_metric(name)

    def _score_check(self, exp: Explanation) -> None:
        # Recompute the cached predictions if any of them is missing.
        # (`exp` must be passed in explicitly; it is not available here
        # otherwise.)
        if any(
            attr is None
            for attr in [
                self.exp_sub,
                self.exp_sub_c,
                self.s_exp_sub,
                self.s_exp_sub_c,
                self.s_initial_data,
            ]
        ):
            self.score(exp)

    def _fidelity_plus(self, exp: Explanation) -> float:
        # Accuracy-based fidelity+: accuracy on the full graph minus
        # accuracy on the complement of the explanation.
        self._score_check(exp)
        inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
        inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
        return torch.mean(
            (exp.y == inferred_class_initial).float()
            - (exp.y == inferred_class_exp).float()
        ).item()

    def _fidelity_minus(self, exp: Explanation) -> float:
        # Accuracy-based fidelity-: accuracy on the full graph minus
        # accuracy on the explanation subgraph itself.
        self._score_check(exp)
        inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
        inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
        return torch.mean(
            (exp.y == inferred_class_initial).float()
            - (exp.y == inferred_class_exp).float()
        ).item()

    def _fidelity_plus_prob(self, exp: Explanation) -> float:
        # Probability-based fidelity+, computed on the probability mass
        # assigned to the ground-truth class `exp.y`.
        self._score_check(exp)
        # one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = softmax(self.s_exp_sub_c)

        size = int(exp.y.size(0))
        prob_initial = prob_initial[torch.arange(size), exp.y]
        prob_exp = prob_exp[torch.arange(size), exp.y]

        return torch.mean(
            torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
        ).item()

    def _fidelity_minus_prob(self, exp: Explanation) -> float:
        # Probability-based fidelity-, analogous to `_fidelity_plus_prob`
        # but using the explanation subgraph instead of its complement.
        self._score_check(exp)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = softmax(self.s_exp_sub)

        size = int(exp.y.size(0))
        prob_initial = prob_initial[torch.arange(size), exp.y]
        prob_exp = prob_exp[torch.arange(size), exp.y]

        return torch.mean(
            torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
        ).item()

    def _infidelity_KL(self, exp: Explanation) -> float:
        # KL divergence between the prediction on the explanation subgraph
        # and the prediction on the full graph, squashed into [0, 1) via
        # 1 - exp(-KL).
        self._score_check(exp)
        prob_initial = softmax(self.s_initial_data)
        prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
        return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
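
    # The characterization scores below combine both fidelity components as
    # a weighted harmonic mean:
    #     1 / (pos_weight / fid_plus + neg_weight / (1 - fid_minus))
    # which is high only when fid_plus is high and fid_minus is low.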

    def _characterization_prob(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Optional[float]:
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                "The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus_prob(exp)
        neg_fidelity = self._fidelity_minus_prob(exp)
        # Degenerate fidelity values would make the score undefined.
        if pos_fidelity in (0, 1) or neg_fidelity in (0, 1):
            return None
        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        if denom == 0:
            return None
        return 1.0 / denom

    def _characterization(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Optional[float]:
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                "The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus(exp)
        neg_fidelity = self._fidelity_minus(exp)
        # Degenerate fidelity values would make the score undefined.
        if pos_fidelity in (0, 1) or neg_fidelity in (0, 1):
            return None
        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        if denom == 0:
            return None
        return 1.0 / denom

    def score(self, exp: Explanation) -> None:
        # Cache the explanation subgraph, its complement, and the model
        # predictions on each of them plus on the full input graph.
        self.exp_sub = exp.get_explanation_subgraph()
        self.exp_sub_c = exp.get_complement_subgraph()
        self.s_exp_sub = self.get_prediction(
            x=self.exp_sub.x, edge_index=self.exp_sub.edge_index
        )
        self.s_exp_sub_c = self.get_prediction(
            x=self.exp_sub_c.x, edge_index=self.exp_sub_c.edge_index
        )
        self.s_initial_data = self.get_prediction(x=exp.x, edge_index=exp.edge_index)

    def load_metric(self, name: str):
        # Dispatch table of bound methods; behavior matches the original
        # if-chain but avoids the redundant lambda wrappers.
        if name not in self.authorized_metric:
            raise ValueError(f"{name} is not supported")
        dispatch = {
            "fidelity_plus": self._fidelity_plus,
            "fidelity_minus": self._fidelity_minus,
            "fidelity_plus_prob": self._fidelity_plus_prob,
            "fidelity_minus_prob": self._fidelity_minus_prob,
            "infidelity_KL": self._infidelity_KL,
            "characterization": self._characterization,
            "characterization_prob": self._characterization_prob,
        }
        self.metric = dispatch[name]
        return self.metric

    def forward(self, exp: Explanation):
        self.score(exp)
        return self.metric(exp)
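
# Minimal usage sketch (assumes a trained GraphGym model `model` and a
# torch_geometric `Explanation` object `exp` carrying `x`, `edge_index`,
# `y`, and node/edge masks):
#
#     metric = Fidelity(name="fidelity_plus", model=model)
#     value = metric.forward(exp)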