explaining_framework/explaining_framework/explainers/wrappers/from_captum.py
2022-12-29 22:00:39 +01:00

248 lines
7.4 KiB
Python

import logging
import torch
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
FeaturePermutation, GradientShap, GuidedBackprop,
GuidedGradCam, InputXGradient, IntegratedGradients,
Lime, Occlusion, Saliency)
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
ModelConfig, ModelMode,
ModelTaskLevel)
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
to_captum_input, to_captum_model)
def _load_FeatureAblation(model):
    """Build Captum's ``FeatureAblation`` attribution method for *model*.

    Args:
        model: The forward model to attribute. ``CaptumWrapper`` passes the
            model returned by ``to_captum_model``.

    Returns:
        A ``captum.attr.FeatureAblation`` instance bound to *model*.
    """
    attribution_method = FeatureAblation(model)
    return attribution_method
def _load_LRP(model):
# return lambda model: LRP(model)
raise ValueError("Captum LRP is not supported yet")
def _load_DeepLift(model):
# return lambda model: DeepLift(model)
raise ValueError("Captum DeepLift is not supported yet")
def _load_DeepLiftShap(model):
# return lambda model: DeepLiftShap(model)
raise ValueError("Captum DeepLiftShap is not supported yet")
def _load_FeaturePermutation(model):
# return lambda model: FeaturePermutation(model)
raise ValueError("Captum FeaturePermutation is not supported yet")
def _load_GradientShap(model):
# return lambda model: GradientShap(model)
raise ValueError("Captum GradientShap is not supported yet")
def _load_GuidedBackPropagation(model):
    """Build Captum's ``GuidedBackprop`` attribution method for *model*.

    The wrapper exposes it under the name ``"GuidedBackPropagation"``.

    Args:
        model: The forward model to attribute. ``CaptumWrapper`` passes the
            model returned by ``to_captum_model``.

    Returns:
        A ``captum.attr.GuidedBackprop`` instance bound to *model*.
    """
    attribution_method = GuidedBackprop(model)
    return attribution_method
def _load_GuidedGradCam(model):
# return lambda model: GuidedGradCam(model)
raise ValueError("Captum GuidedGradCam is not supported yet")
def _load_InputXGradient(model):
    """Build Captum's ``InputXGradient`` attribution method for *model*.

    Args:
        model: The forward model to attribute. ``CaptumWrapper`` passes the
            model returned by ``to_captum_model``.

    Returns:
        A ``captum.attr.InputXGradient`` instance bound to *model*.
    """
    attribution_method = InputXGradient(model)
    return attribution_method
def _load_Lime(model):
    """Build Captum's ``Lime`` attribution method for *model*.

    Args:
        model: The forward model to attribute. ``CaptumWrapper`` passes the
            model returned by ``to_captum_model``.

    Returns:
        A ``captum.attr.Lime`` instance bound to *model*, using Captum's
        default settings.
    """
    attribution_method = Lime(model)
    return attribution_method
def _load_Saliency(model):
    """Build Captum's ``Saliency`` attribution method for *model*.

    Args:
        model: The forward model to attribute. ``CaptumWrapper`` passes the
            model returned by ``to_captum_model``.

    Returns:
        A ``captum.attr.Saliency`` instance bound to *model*.
    """
    attribution_method = Saliency(model)
    return attribution_method
def _load_Occlusion(model):
# return lambda model: Occlusion(model)
raise ValueError("Captum Occlusion is not supported yet")
def _load_IntegratedGradients(model):
# return lambda model: IntegratedGradients(model)
raise ValueError("Captum IntegratedGradients is not supported yet")
class CaptumWrapper(ExplainerAlgorithm):
    """Expose a Captum attribution method as a PyG ``ExplainerAlgorithm``.

    The Captum method is selected by ``name``. It is instantiated in
    :meth:`forward` around the model returned by ``to_captum_model``, then run
    on the inputs built by ``to_captum_input``. Its attributions are returned
    as an ``Explanation``.

    Args:
        name: Captum method name; see ``_CAPTUM_METHODS``. Several accepted
            names still raise ``ValueError`` at load time because their
            loaders are placeholders.
    """

    # Every name dispatched by ``_load_captum_method``.
    _CAPTUM_METHODS = (
        "LRP",
        "DeepLift",
        "DeepLiftShap",
        "FeatureAblation",
        "FeaturePermutation",
        "GradientShap",
        "GuidedBackPropagation",
        "GuidedGradCam",
        "InputXGradient",
        "IntegratedGradients",
        "Lime",
        "Occlusion",
        "Saliency",
    )
    # Subset whose ``attribute`` call is wired up in ``forward``.
    _ATTRIBUTE_METHODS = (
        "InputXGradient",
        "Lime",
        "Saliency",
        "GuidedBackPropagation",
        "FeatureAblation",
    )

    def __init__(self, name: str):
        super().__init__()
        # Captum method name, resolved lazily in ``_load_captum_method``.
        self.name = name

    def supports(self) -> bool:
        """Return whether this wrapper can handle the configured explainer.

        Reads ``self.model_config`` and ``self.explainer_config``, which must
        already be set when this is called. Problems are reported with
        ``logging.error`` and a ``False`` return value rather than raised.

        Returns:
            ``True`` only if all of the following hold:

            - ``name`` is a known Captum method;
            - the task level is graph-level;
            - the edge mask type is ``object`` or ``None``;
            - the node mask type is ``object``, ``attributes``,
              ``common_attributes`` or ``None``.
        """
        # Reject unknown names here instead of failing later in ``forward``.
        if self.name not in self._CAPTUM_METHODS:
            logging.error("Captum method '%s' not supported", self.name)
            return False

        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error("Task level '%s' not supported", task_level.value)
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error("Edge mask type '%s' not supported", edge_mask_type.value)
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
            None,
        ]:
            logging.error("Node mask type '%s' not supported.", node_mask_type.value)
            return False

        return True

    def _get_mask_type(self):
        """Translate the explainer's mask config into a PyG Captum mask type.

        Returns:
            ``"node"``, ``"edge"`` or ``"node_and_edge"``, depending on which
            of the node/edge mask types are configured. The value is also
            stored on ``self.mask_type`` for :meth:`_parse_attr`.

        Raises:
            ValueError: If neither a node nor an edge mask type is configured,
                or if ``_raise_on_invalid_mask_type`` rejects the result.
        """
        has_edge = self.explainer_config.edge_mask_type is not None
        has_node = self.explainer_config.node_mask_type is not None

        if has_node and has_edge:
            self.mask_type = "node_and_edge"
        elif has_node:
            self.mask_type = "node"
        elif has_edge:
            self.mask_type = "edge"
        else:
            raise ValueError("You need to provide a masking config")

        _raise_on_invalid_mask_type(self.mask_type)
        return self.mask_type

    def _load_captum_method(self, model):
        """Instantiate the Captum method named by ``self.name`` around *model*.

        Raises:
            ValueError: If ``self.name`` is unknown, or if its loader is still
                a placeholder.
        """
        loaders = {
            "LRP": _load_LRP,
            "DeepLift": _load_DeepLift,
            "DeepLiftShap": _load_DeepLiftShap,
            "FeatureAblation": _load_FeatureAblation,
            "FeaturePermutation": _load_FeaturePermutation,
            "GradientShap": _load_GradientShap,
            "GuidedBackPropagation": _load_GuidedBackPropagation,
            "GuidedGradCam": _load_GuidedGradCam,
            "InputXGradient": _load_InputXGradient,
            "IntegratedGradients": _load_IntegratedGradients,
            "Lime": _load_Lime,
            "Occlusion": _load_Occlusion,
            "Saliency": _load_Saliency,
        }
        loader = loaders.get(self.name)
        if loader is None:
            raise ValueError(f"{self.name} is not a supported Captum method yet !")
        return loader(model)

    def _parse_attr(self, attr):
        """Split Captum attributions into the masks stored on an ``Explanation``.

        Args:
            attr: Result of Captum's ``attribute``. This is presumably a tuple
                with one tensor per input built by ``to_captum_input`` (node
                input first for ``"node_and_edge"``) — TODO confirm. A bare
                tensor is also accepted.

        Returns:
            Tuple ``(node_mask, edge_mask, node_feat_mask, edge_feat_mask)``.
            The two feature masks are always ``None``.

        Raises:
            ValueError: If ``self.mask_type`` is not a recognised mask type.
        """
        if isinstance(attr, Tensor):
            attr = (attr,)
        # Captum returns a tuple for tuple inputs, which cannot be assigned
        # into, so build a new list. Only the leading (batch) dimension is
        # dropped; a bare ``squeeze()`` would also collapse legitimate size-1
        # axes, e.g. for a graph with a single edge or node.
        squeezed = [a.squeeze(0) for a in attr]

        # ``elif`` matters here: with independent ``if``s the final ``else``
        # would fire for "node" and "edge" and reject valid mask types.
        if self.mask_type == "node":
            node_mask, edge_mask = squeezed[0], None
        elif self.mask_type == "edge":
            node_mask, edge_mask = None, squeezed[0]
        elif self.mask_type == "node_and_edge":
            node_mask, edge_mask = squeezed[0], squeezed[1]
        else:
            raise ValueError(f"Unknown mask type '{self.mask_type}'")

        node_feat_mask = None
        edge_feat_mask = None
        return node_mask, edge_mask, node_feat_mask, edge_feat_mask

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        index: int,
        target: int,
        **kwargs,
    ) -> Explanation:
        """Run the selected Captum method and wrap its output in an ``Explanation``.

        Args:
            model: Model to explain. It is converted with ``to_captum_model``.
            x: Node features, forwarded to ``to_captum_input``.
            edge_index: Graph connectivity, forwarded to ``to_captum_input``.
            index: Passed to ``to_captum_model`` as ``output_idx``.
            target: Passed to Captum's ``attribute`` as ``target`` and stored
                as ``y`` on the returned ``Explanation``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            An ``Explanation`` carrying ``x``, ``edge_index``, ``y=target`` and
            the masks produced by :meth:`_parse_attr`.

        Raises:
            ValueError: If the mask config or Captum method is unsupported.
        """
        mask_type = self._get_mask_type()
        converted_model = to_captum_model(model, mask_type=mask_type, output_idx=index)
        self.captum_method = self._load_captum_method(converted_model)
        if self.name not in self._ATTRIBUTE_METHODS:
            # Guard against a loader that succeeds without a wired-up
            # ``attribute`` call; previously this left ``attr`` unbound.
            raise ValueError(f"Captum {self.name} is not supported yet")

        inputs, additional_forward_args = to_captum_input(
            x, edge_index, mask_type=mask_type
        )
        attr = self.captum_method.attribute(
            inputs=inputs,
            additional_forward_args=additional_forward_args,
            target=target,
        )

        node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
        )