New fixes + reformat
This commit is contained in:
parent
b73f087a6a
commit
2aacd9fddd
@ -0,0 +1,231 @@
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.explain import Explanation
|
||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||
ModelConfig, ModelMode,
|
||||
ModelTaskLevel)
|
||||
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
|
||||
to_captum_input, to_captum_model)
|
||||
|
||||
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||
GuidedGradCam, InputXGradient, IntegratedGradients,
|
||||
Lime, Occlusion, Saliency)
|
||||
|
||||
|
||||
def _load_FeatureAblation(model):
|
||||
return lambda model: FeatureAblation(model)
|
||||
|
||||
|
||||
def _load_LRP(model):
|
||||
# return lambda model: LRP(model)
|
||||
raise ValueError("Captum LRP is not supported yet")
|
||||
|
||||
|
||||
def _load_DeepLift(model):
|
||||
# return lambda model: DeepLift(model)
|
||||
raise ValueError("Captum DeepLift is not supported yet")
|
||||
|
||||
|
||||
def _load_DeepLiftShap(model):
|
||||
# return lambda model: DeepLiftShap(model)
|
||||
raise ValueError("Captum DeepLiftShap is not supported yet")
|
||||
|
||||
|
||||
def _load_FeaturePermutation(model):
|
||||
# return lambda model: FeaturePermutation(model)
|
||||
raise ValueError("Captum FeaturePermutation is not supported yet")
|
||||
|
||||
|
||||
def _load_GradientShap(model):
|
||||
# return lambda model: GradientShap(model)
|
||||
raise ValueError("Captum GradientShap is not supported yet")
|
||||
|
||||
|
||||
def _load_GuidedBackPropagation(model):
|
||||
return lambda model: GuidedBackprop(model)
|
||||
|
||||
|
||||
def _load_GuidedGradCam(model):
|
||||
# return lambda model: GuidedGradCam(model)
|
||||
raise ValueError("Captum GuidedGradCam is not supported yet")
|
||||
|
||||
|
||||
def _load_InputXGradient(model):
|
||||
return lambda model: InputXGradient(model)
|
||||
|
||||
|
||||
def _load_Lime(model):
|
||||
return lambda model: Lime(model)
|
||||
|
||||
|
||||
def _load_Saliency(model):
|
||||
return lambda model: Saliency(model)
|
||||
|
||||
|
||||
def _load_Occlusion(model):
|
||||
# return lambda model: Occlusion(model)
|
||||
raise ValueError("Captum Occlusion is not supported yet")
|
||||
|
||||
|
||||
def _load_IntegratedGradients(model):
|
||||
# return lambda model: IntegratedGradients(model)
|
||||
raise ValueError("Captum IntegratedGradients is not supported yet")
|
||||
|
||||
|
||||
class CaptumWrapper(ExplainerAlgorithm):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
def supports(self) -> bool:
|
||||
task_level = self.model_config.task_level
|
||||
if self.name in [
|
||||
"LRP",
|
||||
"DeepLift",
|
||||
"DeepLiftShap",
|
||||
"FeatureAblation",
|
||||
"FeaturePermutation",
|
||||
"GradientShap",
|
||||
"GuidedBackPropagation",
|
||||
"GuidedGradCam",
|
||||
"InputXGradient",
|
||||
"IntegratedGradients",
|
||||
"Lime",
|
||||
"Occlusion",
|
||||
"Saliency",
|
||||
]:
|
||||
if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
|
||||
logging.error(f"Task level '{task_level.value}' not supported")
|
||||
return False
|
||||
|
||||
edge_mask_type = self.explainer_config.edge_mask_type
|
||||
if edge_mask_type not in [MaskType.object, None]:
|
||||
logging.error(
|
||||
f"Edge mask type '{edge_mask_type.value}' not " f"supported"
|
||||
)
|
||||
return False
|
||||
|
||||
node_mask_type = self.explainer_config.node_mask_type
|
||||
if node_mask_type not in [
|
||||
MaskType.common_attributes,
|
||||
MaskType.object,
|
||||
MaskType.attributes,
|
||||
None,
|
||||
]:
|
||||
logging.error(
|
||||
f"Node mask type '{node_mask_type.value}' not " f"supported."
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_mask_type(self):
|
||||
|
||||
edge_mask_type = self.explainer_config.edge_mask_type
|
||||
node_mask_type = self.explainer_config.node_mask_type
|
||||
|
||||
if edge_mask_type is None and node_mask_type is None:
|
||||
raise ValueError("You need to provide a masking config")
|
||||
|
||||
if not edge_mask_type is None and node_mask_type is None:
|
||||
self.mask_type = "edge"
|
||||
|
||||
if edge_mask_type is None and not node_mask_type is None:
|
||||
self.mask_type = "node"
|
||||
|
||||
if not edge_mask_type is None and not node_mask_type is None:
|
||||
self.mask_type = "node_and_edge"
|
||||
|
||||
_raise_on_invalid_mask_type(self.mask_type)
|
||||
return self.mask_type
|
||||
|
||||
def _load_captum_method(self, model):
|
||||
|
||||
if self.name == "LRP":
|
||||
return _load_LRP(model)
|
||||
|
||||
elif self.name == "DeepLift":
|
||||
return _load_DeepLift(model)
|
||||
|
||||
elif self.name == "DeepLiftShap":
|
||||
return _load_DeepLiftShap(model)
|
||||
|
||||
elif self.name == "FeatureAblation":
|
||||
return _load_FeatureAblation(model)
|
||||
|
||||
elif self.name == "FeaturePermutation":
|
||||
return _load_FeaturePermutation(model)
|
||||
|
||||
elif self.name == "GradientShap":
|
||||
return _load_GradientShap(model)
|
||||
|
||||
elif self.name == "GuidedBackPropagation":
|
||||
return _load_GuidedBackPropagation(model)
|
||||
|
||||
elif self.name == "GuidedGradCam":
|
||||
return _load_GuidedGradCam(model)
|
||||
|
||||
elif self.name == "InputXGradient":
|
||||
return _load_InputXGradient(model)
|
||||
|
||||
elif self.name == "IntegratedGradients":
|
||||
return _load_IntegratedGradients(model)
|
||||
|
||||
elif self.name == "Lime":
|
||||
return _load_Lime(model)
|
||||
|
||||
elif self.name == "Occlusion":
|
||||
return _load_Occlusion(model)
|
||||
|
||||
elif self.name == "Saliency":
|
||||
return _load_Saliency(model)
|
||||
else:
|
||||
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
||||
|
||||
def _parse_attr(self, attr):
|
||||
if self.mask_type == "node":
|
||||
node_mask = attr
|
||||
edge_mask = None
|
||||
|
||||
if "edge" == mask_type:
|
||||
node_mask = None
|
||||
edge_mask = attr
|
||||
|
||||
if "node_and_mask" == mask_type:
|
||||
node_mask = attr[0]
|
||||
edge_mask = attr[1]
|
||||
|
||||
# TODO
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
x: Tensor,
|
||||
edge_index: Tensor,
|
||||
target: int,
|
||||
**kwargs,
|
||||
):
|
||||
mask_type = self._get_mask_type()
|
||||
converted_model = to_captum_model(model, mask_type=mask_type, output=target)
|
||||
self.captum_method = self._load_captum_method(converted_model)
|
||||
inputs, additional_forward_args = to_captum_input(
|
||||
x, edge_index, mask_type=mask_type
|
||||
)
|
||||
if self.name in ["InputXGradient", "Lime", "Saliency", "GuidedBackPropagation"]:
|
||||
attr = self.captum_method.attribute(
|
||||
inputs=inputs,
|
||||
additional_forward_args=additional_forward_args,
|
||||
target=target,
|
||||
)
|
||||
elif self.name == "FeatureAblation":
|
||||
attr = self.captum_method.attribute(
|
||||
inputs=inputs,
|
||||
additional_forward_args=additional_forward_args,
|
||||
)
|
||||
node_mask, edge_mask = self._parse_attr(attr)
|
||||
|
||||
return Explanation(
|
||||
x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
|
||||
)
|
Loading…
Reference in New Issue
Block a user