New fixes + reformat

This commit is contained in:
araison 2022-12-09 18:14:12 +01:00
parent b73f087a6a
commit 2aacd9fddd
1 changed files with 231 additions and 0 deletions

View File

@ -0,0 +1,231 @@
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
ModelConfig, ModelMode,
ModelTaskLevel)
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
to_captum_input, to_captum_model)
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
FeaturePermutation, GradientShap, GuidedBackprop,
GuidedGradCam, InputXGradient, IntegratedGradients,
Lime, Occlusion, Saliency)
def _load_FeatureAblation(model):
return lambda model: FeatureAblation(model)
def _load_LRP(model):
# return lambda model: LRP(model)
raise ValueError("Captum LRP is not supported yet")
def _load_DeepLift(model):
# return lambda model: DeepLift(model)
raise ValueError("Captum DeepLift is not supported yet")
def _load_DeepLiftShap(model):
# return lambda model: DeepLiftShap(model)
raise ValueError("Captum DeepLiftShap is not supported yet")
def _load_FeaturePermutation(model):
# return lambda model: FeaturePermutation(model)
raise ValueError("Captum FeaturePermutation is not supported yet")
def _load_GradientShap(model):
# return lambda model: GradientShap(model)
raise ValueError("Captum GradientShap is not supported yet")
def _load_GuidedBackPropagation(model):
return lambda model: GuidedBackprop(model)
def _load_GuidedGradCam(model):
# return lambda model: GuidedGradCam(model)
raise ValueError("Captum GuidedGradCam is not supported yet")
def _load_InputXGradient(model):
return lambda model: InputXGradient(model)
def _load_Lime(model):
return lambda model: Lime(model)
def _load_Saliency(model):
return lambda model: Saliency(model)
def _load_Occlusion(model):
# return lambda model: Occlusion(model)
raise ValueError("Captum Occlusion is not supported yet")
def _load_IntegratedGradients(model):
# return lambda model: IntegratedGradients(model)
raise ValueError("Captum IntegratedGradients is not supported yet")
class CaptumWrapper(ExplainerAlgorithm):
def __init__(self, name):
super().__init__()
self.name = name
def supports(self) -> bool:
task_level = self.model_config.task_level
if self.name in [
"LRP",
"DeepLift",
"DeepLiftShap",
"FeatureAblation",
"FeaturePermutation",
"GradientShap",
"GuidedBackPropagation",
"GuidedGradCam",
"InputXGradient",
"IntegratedGradients",
"Lime",
"Occlusion",
"Saliency",
]:
if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
logging.error(f"Task level '{task_level.value}' not supported")
return False
edge_mask_type = self.explainer_config.edge_mask_type
if edge_mask_type not in [MaskType.object, None]:
logging.error(
f"Edge mask type '{edge_mask_type.value}' not " f"supported"
)
return False
node_mask_type = self.explainer_config.node_mask_type
if node_mask_type not in [
MaskType.common_attributes,
MaskType.object,
MaskType.attributes,
None,
]:
logging.error(
f"Node mask type '{node_mask_type.value}' not " f"supported."
)
return False
return True
def _get_mask_type(self):
edge_mask_type = self.explainer_config.edge_mask_type
node_mask_type = self.explainer_config.node_mask_type
if edge_mask_type is None and node_mask_type is None:
raise ValueError("You need to provide a masking config")
if not edge_mask_type is None and node_mask_type is None:
self.mask_type = "edge"
if edge_mask_type is None and not node_mask_type is None:
self.mask_type = "node"
if not edge_mask_type is None and not node_mask_type is None:
self.mask_type = "node_and_edge"
_raise_on_invalid_mask_type(self.mask_type)
return self.mask_type
def _load_captum_method(self, model):
if self.name == "LRP":
return _load_LRP(model)
elif self.name == "DeepLift":
return _load_DeepLift(model)
elif self.name == "DeepLiftShap":
return _load_DeepLiftShap(model)
elif self.name == "FeatureAblation":
return _load_FeatureAblation(model)
elif self.name == "FeaturePermutation":
return _load_FeaturePermutation(model)
elif self.name == "GradientShap":
return _load_GradientShap(model)
elif self.name == "GuidedBackPropagation":
return _load_GuidedBackPropagation(model)
elif self.name == "GuidedGradCam":
return _load_GuidedGradCam(model)
elif self.name == "InputXGradient":
return _load_InputXGradient(model)
elif self.name == "IntegratedGradients":
return _load_IntegratedGradients(model)
elif self.name == "Lime":
return _load_Lime(model)
elif self.name == "Occlusion":
return _load_Occlusion(model)
elif self.name == "Saliency":
return _load_Saliency(model)
else:
raise ValueError(f"{self.name} is not a supported Captum method yet !")
def _parse_attr(self, attr):
if self.mask_type == "node":
node_mask = attr
edge_mask = None
if "edge" == mask_type:
node_mask = None
edge_mask = attr
if "node_and_mask" == mask_type:
node_mask = attr[0]
edge_mask = attr[1]
# TODO
pass
def forward(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
target: int,
**kwargs,
):
mask_type = self._get_mask_type()
converted_model = to_captum_model(model, mask_type=mask_type, output=target)
self.captum_method = self._load_captum_method(converted_model)
inputs, additional_forward_args = to_captum_input(
x, edge_index, mask_type=mask_type
)
if self.name in ["InputXGradient", "Lime", "Saliency", "GuidedBackPropagation"]:
attr = self.captum_method.attribute(
inputs=inputs,
additional_forward_args=additional_forward_args,
target=target,
)
elif self.name == "FeatureAblation":
attr = self.captum_method.attribute(
inputs=inputs,
additional_forward_args=additional_forward_args,
)
node_mask, edge_mask = self._parse_attr(attr)
return Explanation(
x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
)