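"""Expose Captum attribution methods as PyG ``ExplainerAlgorithm`` instances.

``CaptumWrapper`` selects a Captum method by name, converts the model and its
inputs to Captum's format, runs the attribution, and re-packages the result
as a PyG ``Explanation``.
"""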
import logging

import torch
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
                         FeaturePermutation, GradientShap, GuidedBackprop,
                         GuidedGradCam, InputXGradient, IntegratedGradients,
                         Lime, Occlusion, Saliency)
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
                                              to_captum_input, to_captum_model)


def _load_FeatureAblation(model):
    return FeatureAblation(model)


def _load_LRP(model):
    # return LRP(model)
    raise ValueError("Captum LRP is not supported yet")


def _load_DeepLift(model):
    # return DeepLift(model)
    raise ValueError("Captum DeepLift is not supported yet")


def _load_DeepLiftShap(model):
    # return DeepLiftShap(model)
    raise ValueError("Captum DeepLiftShap is not supported yet")


def _load_FeaturePermutation(model):
    # return FeaturePermutation(model)
    raise ValueError("Captum FeaturePermutation is not supported yet")


def _load_GradientShap(model):
    # return GradientShap(model)
    raise ValueError("Captum GradientShap is not supported yet")


def _load_GuidedBackPropagation(model):
    return GuidedBackprop(model)


def _load_GuidedGradCam(model):
    # return GuidedGradCam(model)
    raise ValueError("Captum GuidedGradCam is not supported yet")


def _load_InputXGradient(model):
    return InputXGradient(model)


def _load_Lime(model):
    return Lime(model)


def _load_Saliency(model):
    return Saliency(model)


def _load_Occlusion(model):
    # return Occlusion(model)
    raise ValueError("Captum Occlusion is not supported yet")


def _load_IntegratedGradients(model):
    # return IntegratedGradients(model)
    raise ValueError("Captum IntegratedGradients is not supported yet")


class CaptumWrapper(ExplainerAlgorithm):
    """Wraps a Captum attribution method, selected by ``name``, as a
    PyG :class:`ExplainerAlgorithm` for graph-level explanations."""
    def __init__(self, name):
        super().__init__()
        self.name = name

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if self.name in [
            "LRP",
            "DeepLift",
            "DeepLiftShap",
            "FeatureAblation",
            "FeaturePermutation",
            "GradientShap",
            "GuidedBackPropagation",
            "GuidedGradCam",
            "InputXGradient",
            "IntegratedGradients",
            "Lime",
            "Occlusion",
            "Saliency",
        ]:
            if task_level not in [ModelTaskLevel.graph]:
                logging.error(f"Task level '{task_level.value}' not supported")
                return False

            edge_mask_type = self.explainer_config.edge_mask_type
            if edge_mask_type not in [MaskType.object, None]:
                logging.error(
                    f"Edge mask type '{edge_mask_type.value}' not supported")
                return False

            node_mask_type = self.explainer_config.node_mask_type
            if node_mask_type not in [
                MaskType.common_attributes,
                MaskType.object,
                MaskType.attributes,
                None,
            ]:
                logging.error(
                    f"Node mask type '{node_mask_type.value}' not supported")
                return False

        return True

    def _get_mask_type(self):
        """Map the explainer config onto a Captum mask-type string."""
        edge_mask_type = self.explainer_config.edge_mask_type
        node_mask_type = self.explainer_config.node_mask_type

        if edge_mask_type is None and node_mask_type is None:
            raise ValueError("You need to provide a masking config")
        elif node_mask_type is None:
            self.mask_type = "edge"
        elif edge_mask_type is None:
            self.mask_type = "node"
        else:
            self.mask_type = "node_and_edge"

        _raise_on_invalid_mask_type(self.mask_type)
        return self.mask_type

    def _load_captum_method(self, model):
        if self.name == "LRP":
            return _load_LRP(model)
        elif self.name == "DeepLift":
            return _load_DeepLift(model)
        elif self.name == "DeepLiftShap":
            return _load_DeepLiftShap(model)
        elif self.name == "FeatureAblation":
            return _load_FeatureAblation(model)
        elif self.name == "FeaturePermutation":
            return _load_FeaturePermutation(model)
        elif self.name == "GradientShap":
            return _load_GradientShap(model)
        elif self.name == "GuidedBackPropagation":
            return _load_GuidedBackPropagation(model)
        elif self.name == "GuidedGradCam":
            return _load_GuidedGradCam(model)
        elif self.name == "InputXGradient":
            return _load_InputXGradient(model)
        elif self.name == "IntegratedGradients":
            return _load_IntegratedGradients(model)
        elif self.name == "Lime":
            return _load_Lime(model)
        elif self.name == "Occlusion":
            return _load_Occlusion(model)
        elif self.name == "Saliency":
            return _load_Saliency(model)
        else:
            raise ValueError(
                f"'{self.name}' is not a supported Captum method yet!")

    def _parse_attr(self, attr):
        """Split raw Captum attributions into PyG mask tensors."""
        # Captum returns a tuple of attributions (one entry per input
        # tensor); tuples do not support item assignment, so rebuild it as
        # a list while dropping the singleton batch dimension.
        attr = [a.squeeze() for a in attr]

        if self.mask_type == "node":
            node_mask = attr[0]
            edge_mask = None
        elif self.mask_type == "edge":
            node_mask = None
            edge_mask = attr[0]
        elif self.mask_type == "node_and_edge":
            node_mask = attr[0]
            edge_mask = attr[1]
        else:
            raise ValueError(f"Invalid mask type '{self.mask_type}'")

        edge_feat_mask = None
        node_feat_mask = None

        return node_mask, edge_mask, node_feat_mask, edge_feat_mask

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        index: int,
        target: int,
        **kwargs,
    ):
        mask_type = self._get_mask_type()
        converted_model = to_captum_model(model, mask_type=mask_type,
                                          output_idx=index)
        self.captum_method = self._load_captum_method(converted_model)
        inputs, additional_forward_args = to_captum_input(
            x, edge_index, mask_type=mask_type)

        if self.name in [
            "InputXGradient",
            "Lime",
            "Saliency",
            "GuidedBackPropagation",
            "FeatureAblation",
        ]:
            attr = self.captum_method.attribute(
                inputs=inputs,
                additional_forward_args=additional_forward_args,
                target=target,
            )
        else:
            # Guard against methods whose `attribute()` call is not wired
            # up yet; otherwise `attr` would be undefined below.
            raise ValueError(
                f"'{self.name}' does not have a supported attribute call yet")

        node_mask, edge_mask, node_feat_mask, edge_feat_mask = \
            self._parse_attr(attr)

        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
        )
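

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the wrapper).
# It assumes PyG's high-level `torch_geometric.explain.Explainer` front end
# (available in PyG >= 2.2; exact keywords may differ across versions). The
# `ToyGraphClassifier` below is a hypothetical stand-in model that exists
# only to make the example self-contained.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn.functional as F

    from torch_geometric.explain import Explainer
    from torch_geometric.nn import GCNConv

    class ToyGraphClassifier(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = GCNConv(3, 16)
            self.lin = torch.nn.Linear(16, 2)

        def forward(self, x, edge_index):
            x = F.relu(self.conv(x, edge_index))
            # Mean-pool node embeddings into a single graph embedding.
            return self.lin(x.mean(dim=0, keepdim=True))

    x = torch.randn(4, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

    explainer = Explainer(
        model=ToyGraphClassifier(),
        algorithm=CaptumWrapper("Saliency"),
        explanation_type="phenomenon",
        node_mask_type="attributes",
        model_config=dict(
            mode="multiclass_classification",
            task_level="graph",
            return_type="raw",
        ),
    )
    # Explain graph output 0 with respect to class 1; `node_mask` then holds
    # one saliency score per node feature.
    explanation = explainer(x, edge_index, index=0, target=torch.tensor([1]))
    print(explanation.node_mask)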
