New fixes + reformat
This commit is contained in:
		
							parent
							
								
									b73f087a6a
								
							
						
					
					
						commit
						2aacd9fddd
					
				
					 1 changed files with 231 additions and 0 deletions
				
			
		|  | @ -0,0 +1,231 @@ | |||
| from torch_geometric.data import Data | ||||
| from torch_geometric.explain import Explanation | ||||
| from torch_geometric.explain.algorithm.base import ExplainerAlgorithm | ||||
| from torch_geometric.explain.config import (ExplainerConfig, MaskType, | ||||
|                                             ModelConfig, ModelMode, | ||||
|                                             ModelTaskLevel) | ||||
| from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type, | ||||
|                                               to_captum_input, to_captum_model) | ||||
| 
 | ||||
| from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation, | ||||
|                          FeaturePermutation, GradientShap, GuidedBackprop, | ||||
|                          GuidedGradCam, InputXGradient, IntegratedGradients, | ||||
|                          Lime, Occlusion, Saliency) | ||||
| 
 | ||||
| 
 | ||||
| def _load_FeatureAblation(model): | ||||
|     return lambda model: FeatureAblation(model) | ||||
| 
 | ||||
| 
 | ||||
| def _load_LRP(model): | ||||
|     # return lambda model: LRP(model) | ||||
|     raise ValueError("Captum LRP is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_DeepLift(model): | ||||
|     # return lambda model: DeepLift(model) | ||||
|     raise ValueError("Captum DeepLift is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_DeepLiftShap(model): | ||||
|     # return lambda model: DeepLiftShap(model) | ||||
|     raise ValueError("Captum DeepLiftShap is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_FeaturePermutation(model): | ||||
|     # return lambda model: FeaturePermutation(model) | ||||
|     raise ValueError("Captum FeaturePermutation is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_GradientShap(model): | ||||
|     # return lambda model: GradientShap(model) | ||||
|     raise ValueError("Captum GradientShap is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_GuidedBackPropagation(model): | ||||
|     return lambda model: GuidedBackprop(model) | ||||
| 
 | ||||
| 
 | ||||
| def _load_GuidedGradCam(model): | ||||
|     # return lambda model: GuidedGradCam(model) | ||||
|     raise ValueError("Captum GuidedGradCam is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_InputXGradient(model): | ||||
|     return lambda model: InputXGradient(model) | ||||
| 
 | ||||
| 
 | ||||
| def _load_Lime(model): | ||||
|     return lambda model: Lime(model) | ||||
| 
 | ||||
| 
 | ||||
| def _load_Saliency(model): | ||||
|     return lambda model: Saliency(model) | ||||
| 
 | ||||
| 
 | ||||
| def _load_Occlusion(model): | ||||
|     # return lambda model: Occlusion(model) | ||||
|     raise ValueError("Captum Occlusion is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| def _load_IntegratedGradients(model): | ||||
|     # return lambda model: IntegratedGradients(model) | ||||
|     raise ValueError("Captum IntegratedGradients is not supported yet") | ||||
| 
 | ||||
| 
 | ||||
| class CaptumWrapper(ExplainerAlgorithm): | ||||
|     def __init__(self, name): | ||||
|         super().__init__() | ||||
|         self.name = name | ||||
| 
 | ||||
|     def supports(self) -> bool: | ||||
|         task_level = self.model_config.task_level | ||||
|         if self.name in [ | ||||
|             "LRP", | ||||
|             "DeepLift", | ||||
|             "DeepLiftShap", | ||||
|             "FeatureAblation", | ||||
|             "FeaturePermutation", | ||||
|             "GradientShap", | ||||
|             "GuidedBackPropagation", | ||||
|             "GuidedGradCam", | ||||
|             "InputXGradient", | ||||
|             "IntegratedGradients", | ||||
|             "Lime", | ||||
|             "Occlusion", | ||||
|             "Saliency", | ||||
|         ]: | ||||
|             if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]: | ||||
|                 logging.error(f"Task level '{task_level.value}' not supported") | ||||
|                 return False | ||||
| 
 | ||||
|             edge_mask_type = self.explainer_config.edge_mask_type | ||||
|             if edge_mask_type not in [MaskType.object, None]: | ||||
|                 logging.error( | ||||
|                     f"Edge mask type '{edge_mask_type.value}' not " f"supported" | ||||
|                 ) | ||||
|                 return False | ||||
| 
 | ||||
|             node_mask_type = self.explainer_config.node_mask_type | ||||
|             if node_mask_type not in [ | ||||
|                 MaskType.common_attributes, | ||||
|                 MaskType.object, | ||||
|                 MaskType.attributes, | ||||
|                 None, | ||||
|             ]: | ||||
|                 logging.error( | ||||
|                     f"Node mask type '{node_mask_type.value}' not " f"supported." | ||||
|                 ) | ||||
|                 return False | ||||
| 
 | ||||
|         return True | ||||
| 
 | ||||
|     def _get_mask_type(self): | ||||
| 
 | ||||
|         edge_mask_type = self.explainer_config.edge_mask_type | ||||
|         node_mask_type = self.explainer_config.node_mask_type | ||||
| 
 | ||||
|         if edge_mask_type is None and node_mask_type is None: | ||||
|             raise ValueError("You need to provide a masking config") | ||||
| 
 | ||||
|         if not edge_mask_type is None and node_mask_type is None: | ||||
|             self.mask_type = "edge" | ||||
| 
 | ||||
|         if edge_mask_type is None and not node_mask_type is None: | ||||
|             self.mask_type = "node" | ||||
| 
 | ||||
|         if not edge_mask_type is None and not node_mask_type is None: | ||||
|             self.mask_type = "node_and_edge" | ||||
| 
 | ||||
|         _raise_on_invalid_mask_type(self.mask_type) | ||||
|         return self.mask_type | ||||
| 
 | ||||
|     def _load_captum_method(self, model): | ||||
| 
 | ||||
|         if self.name == "LRP": | ||||
|             return _load_LRP(model) | ||||
| 
 | ||||
|         elif self.name == "DeepLift": | ||||
|             return _load_DeepLift(model) | ||||
| 
 | ||||
|         elif self.name == "DeepLiftShap": | ||||
|             return _load_DeepLiftShap(model) | ||||
| 
 | ||||
|         elif self.name == "FeatureAblation": | ||||
|             return _load_FeatureAblation(model) | ||||
| 
 | ||||
|         elif self.name == "FeaturePermutation": | ||||
|             return _load_FeaturePermutation(model) | ||||
| 
 | ||||
|         elif self.name == "GradientShap": | ||||
|             return _load_GradientShap(model) | ||||
| 
 | ||||
|         elif self.name == "GuidedBackPropagation": | ||||
|             return _load_GuidedBackPropagation(model) | ||||
| 
 | ||||
|         elif self.name == "GuidedGradCam": | ||||
|             return _load_GuidedGradCam(model) | ||||
| 
 | ||||
|         elif self.name == "InputXGradient": | ||||
|             return _load_InputXGradient(model) | ||||
| 
 | ||||
|         elif self.name == "IntegratedGradients": | ||||
|             return _load_IntegratedGradients(model) | ||||
| 
 | ||||
|         elif self.name == "Lime": | ||||
|             return _load_Lime(model) | ||||
| 
 | ||||
|         elif self.name == "Occlusion": | ||||
|             return _load_Occlusion(model) | ||||
| 
 | ||||
|         elif self.name == "Saliency": | ||||
|             return _load_Saliency(model) | ||||
|         else: | ||||
|             raise ValueError(f"{self.name} is not a supported Captum method yet !") | ||||
| 
 | ||||
|     def _parse_attr(self, attr): | ||||
|         if self.mask_type == "node": | ||||
|             node_mask = attr | ||||
|             edge_mask = None | ||||
| 
 | ||||
|         if "edge" == mask_type: | ||||
|             node_mask = None | ||||
|             edge_mask = attr | ||||
| 
 | ||||
|         if "node_and_mask" == mask_type: | ||||
|             node_mask = attr[0] | ||||
|             edge_mask = attr[1] | ||||
| 
 | ||||
|         # TODO | ||||
|         pass | ||||
| 
 | ||||
|     def forward( | ||||
|         self, | ||||
|         model: torch.nn.Module, | ||||
|         x: Tensor, | ||||
|         edge_index: Tensor, | ||||
|         target: int, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         mask_type = self._get_mask_type() | ||||
|         converted_model = to_captum_model(model, mask_type=mask_type, output=target) | ||||
|         self.captum_method = self._load_captum_method(converted_model) | ||||
|         inputs, additional_forward_args = to_captum_input( | ||||
|             x, edge_index, mask_type=mask_type | ||||
|         ) | ||||
|         if self.name in ["InputXGradient", "Lime", "Saliency", "GuidedBackPropagation"]: | ||||
|             attr = self.captum_method.attribute( | ||||
|                 inputs=inputs, | ||||
|                 additional_forward_args=additional_forward_args, | ||||
|                 target=target, | ||||
|             ) | ||||
|         elif self.name == "FeatureAblation": | ||||
|             attr = self.captum_method.attribute( | ||||
|                 inputs=inputs, | ||||
|                 additional_forward_args=additional_forward_args, | ||||
|             ) | ||||
|         node_mask, edge_mask = self._parse_attr(attr) | ||||
| 
 | ||||
|         return Explanation( | ||||
|             x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask | ||||
|         ) | ||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison