New fixes + reformat
This commit is contained in:
parent
b73f087a6a
commit
2aacd9fddd
|
@ -0,0 +1,231 @@
|
||||||
|
import logging

import torch
from torch import Tensor

from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
                         FeaturePermutation, GradientShap, GuidedBackprop,
                         GuidedGradCam, InputXGradient, IntegratedGradients,
                         Lime, Occlusion, Saliency)

from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
                                              to_captum_input, to_captum_model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_FeatureAblation(model):
    """Return a Captum ``FeatureAblation`` attribution object for *model*.

    Callers (``CaptumWrapper.forward``) invoke ``.attribute(...)`` on the
    returned value, so this must be the instantiated attribution method,
    not a factory lambda.
    """
    return FeatureAblation(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_LRP(model):
    """Loader stub for Captum LRP; always raises.

    Raises:
        ValueError: LRP is not yet usable through this wrapper.
    """
    raise ValueError("Captum LRP is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_DeepLift(model):
    """Loader stub for Captum DeepLift; always raises.

    Raises:
        ValueError: DeepLift is not yet usable through this wrapper.
    """
    raise ValueError("Captum DeepLift is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_DeepLiftShap(model):
    """Loader stub for Captum DeepLiftShap; always raises.

    Raises:
        ValueError: DeepLiftShap is not yet usable through this wrapper.
    """
    raise ValueError("Captum DeepLiftShap is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_FeaturePermutation(model):
    """Loader stub for Captum FeaturePermutation; always raises.

    Raises:
        ValueError: FeaturePermutation is not yet usable through this wrapper.
    """
    raise ValueError("Captum FeaturePermutation is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GradientShap(model):
    """Loader stub for Captum GradientShap; always raises.

    Raises:
        ValueError: GradientShap is not yet usable through this wrapper.
    """
    raise ValueError("Captum GradientShap is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GuidedBackPropagation(model):
    """Return a Captum ``GuidedBackprop`` attribution object for *model*.

    Callers (``CaptumWrapper.forward``) invoke ``.attribute(...)`` on the
    returned value, so this must be the instantiated attribution method,
    not a factory lambda.
    """
    return GuidedBackprop(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GuidedGradCam(model):
    """Loader stub for Captum GuidedGradCam; always raises.

    Raises:
        ValueError: GuidedGradCam is not yet usable through this wrapper.
    """
    raise ValueError("Captum GuidedGradCam is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_InputXGradient(model):
    """Return a Captum ``InputXGradient`` attribution object for *model*.

    Callers (``CaptumWrapper.forward``) invoke ``.attribute(...)`` on the
    returned value, so this must be the instantiated attribution method,
    not a factory lambda.
    """
    return InputXGradient(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_Lime(model):
    """Return a Captum ``Lime`` attribution object for *model*.

    Callers (``CaptumWrapper.forward``) invoke ``.attribute(...)`` on the
    returned value, so this must be the instantiated attribution method,
    not a factory lambda.
    """
    return Lime(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_Saliency(model):
    """Return a Captum ``Saliency`` attribution object for *model*.

    Callers (``CaptumWrapper.forward``) invoke ``.attribute(...)`` on the
    returned value, so this must be the instantiated attribution method,
    not a factory lambda.
    """
    return Saliency(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_Occlusion(model):
    """Loader stub for Captum Occlusion; always raises.

    Raises:
        ValueError: Occlusion is not yet usable through this wrapper.
    """
    raise ValueError("Captum Occlusion is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_IntegratedGradients(model):
    """Loader stub for Captum IntegratedGradients; always raises.

    Raises:
        ValueError: IntegratedGradients is not yet usable through this wrapper.
    """
    raise ValueError("Captum IntegratedGradients is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
class CaptumWrapper(ExplainerAlgorithm):
    """Explainer algorithm that delegates attribution to a Captum method.

    Args:
        name (str): Name of the Captum attribution method to use
            (e.g. ``"Saliency"``, ``"InputXGradient"``).
    """

    # Method names this wrapper knows how to load (some loaders still raise).
    _CAPTUM_METHODS = (
        "LRP",
        "DeepLift",
        "DeepLiftShap",
        "FeatureAblation",
        "FeaturePermutation",
        "GradientShap",
        "GuidedBackPropagation",
        "GuidedGradCam",
        "InputXGradient",
        "IntegratedGradients",
        "Lime",
        "Occlusion",
        "Saliency",
    )

    def __init__(self, name):
        super().__init__()
        self.name = name

    def supports(self) -> bool:
        """Check the current explainer/model configuration is usable.

        Returns ``False`` (after logging an error) when the task level or a
        mask type is outside what the Captum conversion supports.
        """
        task_level = self.model_config.task_level
        if self.name in self._CAPTUM_METHODS:
            # Captum conversion only handles node- and graph-level tasks.
            if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
                logging.error(f"Task level '{task_level.value}' not supported")
                return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(
                f"Edge mask type '{edge_mask_type.value}' not " f"supported"
            )
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
                MaskType.common_attributes,
                MaskType.object,
                MaskType.attributes,
                None,
        ]:
            logging.error(
                f"Node mask type '{node_mask_type.value}' not " f"supported."
            )
            return False

        return True

    def _get_mask_type(self):
        """Derive the Captum mask type from the explainer configuration.

        Returns:
            str: ``"node"``, ``"edge"`` or ``"node_and_edge"`` (also stored
            on ``self.mask_type``).

        Raises:
            ValueError: If neither a node nor an edge mask is configured.
        """
        edge_mask_type = self.explainer_config.edge_mask_type
        node_mask_type = self.explainer_config.node_mask_type

        if edge_mask_type is None and node_mask_type is None:
            raise ValueError("You need to provide a masking config")

        if edge_mask_type is not None and node_mask_type is None:
            self.mask_type = "edge"
        elif edge_mask_type is None and node_mask_type is not None:
            self.mask_type = "node"
        else:
            self.mask_type = "node_and_edge"

        _raise_on_invalid_mask_type(self.mask_type)
        return self.mask_type

    def _load_captum_method(self, model):
        """Instantiate the Captum attribution method named by ``self.name``.

        Raises:
            ValueError: If the name is unknown, or if the selected loader
                is itself still unsupported.
        """
        # Dispatch table replaces the long if/elif chain.
        loaders = {
            "LRP": _load_LRP,
            "DeepLift": _load_DeepLift,
            "DeepLiftShap": _load_DeepLiftShap,
            "FeatureAblation": _load_FeatureAblation,
            "FeaturePermutation": _load_FeaturePermutation,
            "GradientShap": _load_GradientShap,
            "GuidedBackPropagation": _load_GuidedBackPropagation,
            "GuidedGradCam": _load_GuidedGradCam,
            "InputXGradient": _load_InputXGradient,
            "IntegratedGradients": _load_IntegratedGradients,
            "Lime": _load_Lime,
            "Occlusion": _load_Occlusion,
            "Saliency": _load_Saliency,
        }
        try:
            loader = loaders[self.name]
        except KeyError:
            raise ValueError(
                f"{self.name} is not a supported Captum method yet !"
            ) from None
        return loader(model)

    def _parse_attr(self, attr):
        """Split a Captum attribution result into node and edge masks.

        Fixes over the previous version: uses ``self.mask_type`` (the bare
        ``mask_type`` name was undefined), matches the actual
        ``"node_and_edge"`` tag (was the typo ``"node_and_mask"``), and
        returns the pair that ``forward`` unpacks.

        Returns:
            tuple: ``(node_mask, edge_mask)``; either element may be ``None``.
        """
        if self.mask_type == "node":
            node_mask, edge_mask = attr, None
        elif self.mask_type == "edge":
            node_mask, edge_mask = None, attr
        elif self.mask_type == "node_and_edge":
            node_mask, edge_mask = attr[0], attr[1]
        else:
            raise ValueError(f"Invalid mask type '{self.mask_type}'")
        return node_mask, edge_mask

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target: int,
        **kwargs,
    ):
        """Run the Captum attribution and wrap the result in an Explanation.

        Args:
            model: The GNN being explained.
            x: Node feature matrix.
            edge_index: Graph connectivity.
            target: Output index/class to attribute.
        """
        mask_type = self._get_mask_type()
        converted_model = to_captum_model(model, mask_type=mask_type,
                                          output=target)
        self.captum_method = self._load_captum_method(converted_model)
        inputs, additional_forward_args = to_captum_input(
            x, edge_index, mask_type=mask_type
        )
        if self.name in [
                "InputXGradient", "Lime", "Saliency", "GuidedBackPropagation"
        ]:
            attr = self.captum_method.attribute(
                inputs=inputs,
                additional_forward_args=additional_forward_args,
                target=target,
            )
        elif self.name == "FeatureAblation":
            attr = self.captum_method.attribute(
                inputs=inputs,
                additional_forward_args=additional_forward_args,
            )
        else:
            # Previously `attr` was left undefined here, causing a NameError
            # below; fail explicitly instead.
            raise ValueError(
                f"{self.name} is not a supported Captum method yet !"
            )
        node_mask, edge_mask = self._parse_attr(attr)

        return Explanation(
            x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
        )
|
Loading…
Reference in New Issue