New fixes

araison 2022-12-13 13:23:01 +01:00
parent 8067185d1a
commit 6cf1d64d3a
2 changed files with 132 additions and 64 deletions

View File

@@ -1,3 +1,9 @@
+import torch
+from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
+                         FeaturePermutation, GradientShap, GuidedBackprop,
+                         GuidedGradCam, InputXGradient, IntegratedGradients,
+                         Lime, Occlusion, Saliency)
+from torch import Tensor
 from torch_geometric.data import Data
 from torch_geometric.explain import Explanation
 from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
@@ -7,14 +13,9 @@ from torch_geometric.explain.config import (ExplainerConfig, MaskType,
 from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
                                               to_captum_input, to_captum_model)
-from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
-                         FeaturePermutation, GradientShap, GuidedBackprop,
-                         GuidedGradCam, InputXGradient, IntegratedGradients,
-                         Lime, Occlusion, Saliency)
 
 
 def _load_FeatureAblation(model):
-    return lambda model: FeatureAblation(model)
+    return FeatureAblation(model)
 
 
 def _load_LRP(model):
@@ -43,7 +44,7 @@ def _load_GradientShap(model):
 def _load_GuidedBackPropagation(model):
-    return lambda model: GuidedBackprop(model)
+    return GuidedBackprop(model)
 
 
 def _load_GuidedGradCam(model):
@@ -52,15 +53,15 @@ def _load_GuidedGradCam(model):
 def _load_InputXGradient(model):
-    return lambda model: InputXGradient(model)
+    return InputXGradient(model)
 
 
 def _load_Lime(model):
-    return lambda model: Lime(model)
+    return Lime(model)
 
 
 def _load_Saliency(model):
-    return lambda model: Saliency(model)
+    return Saliency(model)
 
 
 def _load_Occlusion(model):
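
The recurring change in these hunks fixes a real bug: each loader used to return `lambda model: <Explainer>(model)`, so calling `_load_X(model)` produced a function rather than a Captum attribution object, and the later `.attribute(...)` call had nothing to dispatch to. A minimal, standalone sketch of the failure mode, using only Captum:

    import torch
    from captum.attr import FeatureAblation

    def buggy_load(model):
        return lambda model: FeatureAblation(model)  # yields a function

    def fixed_load(model):
        return FeatureAblation(model)  # yields the attribution object

    model = torch.nn.Linear(4, 2)
    assert not hasattr(buggy_load(model), "attribute")  # a lambda has no .attribute
    attr = fixed_load(model).attribute(torch.rand(1, 4), target=0)  # (1, 4) scores
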
@@ -212,7 +213,7 @@ class CaptumWrapper(ExplainerAlgorithm):
         **kwargs,
     ):
         mask_type = self._get_mask_type()
-        converted_model = to_captum_model(model, mask_type=mask_type, output=target)
+        converted_model = to_captum_model(model, mask_type=mask_type, output_idx=target)
         self.captum_method = self._load_captum_method(converted_model)
         inputs, additional_forward_args = to_captum_input(
             x, edge_index, mask_type=mask_type
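
The keyword rename in this hunk matters beyond style: `to_captum_model` takes the target output dimension as `output_idx`, not `output`, so the old call would fail with an unexpected-keyword error before any attribution ran. Put together with `to_captum_input`, the corrected call pattern looks roughly like this (a sketch using `FeatureAblation` as a concrete stand-in for the method picked by `_load_captum_method`; `model`, `x`, `edge_index`, and `target` come from the surrounding `forward`):

    from captum.attr import FeatureAblation
    from torch_geometric.nn.models.captum import to_captum_input, to_captum_model

    mask_type = "node"  # or "edge", "node_and_edge"
    converted_model = to_captum_model(model, mask_type=mask_type, output_idx=target)
    inputs, additional_forward_args = to_captum_input(x, edge_index, mask_type=mask_type)
    attr = FeatureAblation(converted_model).attribute(
        inputs=inputs,
        target=target,
        additional_forward_args=additional_forward_args,
    )
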

View File

@@ -1,10 +1,7 @@
-from torch_geometric.data import Data
-from torch_geometric.explain import Explanation
-from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
-from torch_geometric.explain.config import (ExplainerConfig, MaskType,
-                                            ModelConfig, ModelMode,
-                                            ModelTaskLevel)
+import inspect
+from typing import Dict, Optional, Tuple, Union
+
 import torch
 from graphxai.explainers.cam import CAM, GradCAM
 from graphxai.explainers.gnn_explainer import GNNExplainer
 from graphxai.explainers.gnn_lrp import GNN_LRP
@@ -16,14 +13,22 @@ from graphxai.explainers.pg_explainer import PGExplainer
 from graphxai.explainers.pgm_explainer import PGMExplainer
 from graphxai.explainers.random import RandomExplainer
 from graphxai.explainers.subgraphx import SubgraphX
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, KLDivLoss, MSELoss
+from torch_geometric.data import Data
+from torch_geometric.explain import Explanation
+from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
+from torch_geometric.explain.config import (ExplainerConfig, MaskType,
+                                            ModelConfig, ModelMode,
+                                            ModelTaskLevel)
 
 
 def _load_CAM(model):
-    return lambda model: CAM(model)
+    return CAM(model)
 
 
 def _load_GradCAM(model):
-    return lambda model: GradCAM(model)
+    return GradCAM(model)
 
 
 def _load_GNN_LRP(model):
@@ -39,48 +44,60 @@ def _load_GuidedBackPropagation(model, criterion):
 def _load_IntegratedGradients(model, criterion):
-    return lambda model: IntegratedGradExplainer(model, criterion)
+    return IntegratedGradExplainer(model, criterion)
 
 
 def _load_GradExplainer(model, criterion):
-    return lambda model: GradExplainer(model, criterion)
+    return GradExplainer(model, criterion)
 
 
 def _load_PGExplainer(model, explain_graph=None, in_channels=None):
-    return lambda model: PGExplainer(
-        model, explain_graph=explain_graph, in_channels=in_channels
-    )
+    return PGExplainer(model, explain_graph=explain_graph, in_channels=in_channels)
 
 
 def _load_PGMExplainer(model, explain_graph=None):
-    return lambda model: PGMExplainer(model, explain_graph)
+    return PGMExplainer(model, explain_graph)
 
 
 def _load_RandomExplainer(model):
-    return lambda model: RandomExplainer(model)
+    return RandomExplainer(model)
 
 
 def _load_SubgraphX(model):
-    return lambda model: SubgraphX(model)
+    return SubgraphX(model)
 
 
 def _load_GNNExplainer(model):
-    return lambda model: GNNExplainer(model)
+    return GNNExplainer(model)
 
 
 def _load_GraphLIME(model):
-    return lambda model: GraphLIME(model)
+    return GraphLIME(model)
 
 
 class GraphXAIWrapper(ExplainerAlgorithm):
-    def __init__(self, name, criterion=None, in_channels=None):
+    def __init__(self, name, **kwargs):
         super().__init__()
         self.name = name
-        self.criterion = criterion
-        self.explain_graph = (
-            True if self.model_config.task_level == ModelTaskLevel.graph else False
-        )
-        self.in_channels = in_channels
+        self.criterion = self._determine_criterion(kwargs["criterion"])
+        self.in_channels = self._determine_in_channels(kwargs["in_channels"])
+
+    def _determine_criterion(self, criterion):
+        if criterion == "mse":
+            loss = MSELoss()
+            return loss
+        elif criterion == "cross-entropy":
+            loss = CrossEntropyLoss()
+            return loss
+        else:
+            raise ValueError(f"{criterion} criterion is not implemented")
+
+    def _determine_in_channels(self, in_channels):
+        if self.name == "PGExplainer":
+            in_channels = 2 * in_channels
+            return in_channels
+        else:
+            return in_channels
 
     def supports(self) -> bool:
         task_level = self.model_config.task_level
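
With `__init__` now taking `**kwargs`, `criterion` and `in_channels` become effectively required keyword arguments: a missing key raises `KeyError` instead of silently defaulting to `None` as before, and only `"mse"` and `"cross-entropy"` are accepted. A hypothetical construction, with values chosen for illustration:

    algorithm = GraphXAIWrapper(
        "PGExplainer",
        criterion="cross-entropy",  # -> CrossEntropyLoss(); "mse" -> MSELoss()
        in_channels=64,             # doubled to 128 internally for PGExplainer
    )

The doubling is presumably because GraphXAI's PGExplainer scores an edge from the concatenated embeddings of its two endpoints; the diff itself only shows the `2 * in_channels` special case, so treat that rationale as an assumption.
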
@@ -122,6 +139,10 @@ class GraphXAIWrapper(ExplainerAlgorithm):
         if self.name == "GraphLIME" and task_level == ModelTaskLevel.graph:
             return False
 
+        self.explain_graph = (
+            True if self.model_config.task_level == ModelTaskLevel.graph else False
+        )
+
         return True
 
     def _get_mask_type(self):
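
Moving the `explain_graph` computation from `__init__` into `supports()` also fixes an ordering problem: in PyG's explainability framework, `model_config` is attached to the algorithm by the surrounding `Explainer` after the object is constructed, so the old constructor read `self.model_config.task_level` before it could exist, while `supports()` only runs once the configuration is connected. A small illustration of the timing, under that assumption:

    # at construction time, model_config is not yet available
    algorithm = GraphXAIWrapper("GNNExplainer", criterion="cross-entropy", in_channels=None)
    # only after a PyG Explainer attaches its ModelConfig does
    # algorithm.model_config.task_level (and hence explain_graph) become readable
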
@@ -158,22 +179,24 @@ class GraphXAIWrapper(ExplainerAlgorithm):
             return _load_GNN_LRP(model)
         elif self.name == "GradExplainer":
-            return _load_GradExplainer(model,self.criterion)
+            return _load_GradExplainer(model, self.criterion)
         elif self.name == "GraphLIME":
             return _load_GraphLIME(model)
         elif self.name == "GuidedBackPropagation":
-            return _load_GuidedBackPropagation(model,self.criterion)
+            return _load_GuidedBackPropagation(model, self.criterion)
         elif self.name == "IntegratedGradients":
-            return _load_IntegratedGradients(model,self.criterion)
+            return _load_IntegratedGradients(model, self.criterion)
         elif self.name == "PGExplainer":
-            return _load_PGExplainer(model,explain_graph=self.explain_graph,in_channels=self.in_channels)
+            return _load_PGExplainer(
+                model, explain_graph=self.explain_graph, in_channels=self.in_channels
+            )
         elif self.name == "PGMExplainer":
-            return _load_PGMExplainer(model,explain_graph=self.explain_graph)
+            return _load_PGMExplainer(model, explain_graph=self.explain_graph)
         elif self.name == "RandomExplainer":
             return _load_RandomExplainer(model)
@@ -181,45 +204,89 @@ class GraphXAIWrapper(ExplainerAlgorithm):
         elif self.name == "SubgraphX":
             return _load_SubgraphX(model)
         else:
-            raise ValueError(f"{self.name} is not a supported Captum method yet !")
+            raise ValueError(f"{self.name} is not supported yet !")
 
     def _parse_attr(self, attr):
-        if self.mask_type == "node":
-            node_mask = attr[0].squeeze(0)
-            edge_mask = None
-
-        if self.mask_type == "edge":
-            node_mask = None
-            edge_mask = attr[0]
-
-        if self.mask_type == "node_and_edge":
-            node_mask = attr[0].squeeze(0)
-            edge_mask = attr[1]
-        else:
-            raise ValueError
-
-        edge_feat_mask = None
-        node_feat_mask = None
+        node_mask, node_feat_mask, edge_mask, edge_feat_mask = (
+            None,
+            None,
+            None,
+            None,
+        )
+        for k, v in attr.__dict__.items():
+            if k == "feature_imp":
+                node_feat_mask = v
+            elif k == "node_imp":
+                node_mask = v
+            elif k == "edge_imp":
+                edge_mask = v
 
         return node_mask, edge_mask, node_feat_mask, edge_feat_mask
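
The rewritten `_parse_attr` no longer assumes a Captum-style tuple indexed by mask type (the old branch chain even fell through to a bare `raise ValueError` whenever the mask type was not `node_and_edge`). Instead it walks the attributes of the GraphXAI explanation object and maps its naming onto PyG's mask slots: `node_imp` to `node_mask`, `edge_imp` to `edge_mask`, `feature_imp` to `node_feat_mask`; nothing maps to `edge_feat_mask`, which stays `None`. A sketch with a stand-in object (GraphXAI's own explanation class carries these fields):

    import torch
    from types import SimpleNamespace

    attr = SimpleNamespace(
        node_imp=torch.rand(10),     # one importance score per node
        edge_imp=torch.rand(40),     # one score per edge
        feature_imp=torch.rand(16),  # one score per input feature
    )
    # wrapper._parse_attr(attr) -> (node_imp, edge_imp, feature_imp, None),
    # ordered as (node_mask, edge_mask, node_feat_mask, edge_feat_mask)
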
 
+    def _parse_method_args(self, method, **kwargs):
+        signature = inspect.signature(method)
+        args = tuple(
+            [
+                kwargs[k.name]
+                for k in signature.parameters.values()
+                if k.name in kwargs.keys()
+            ]
+        )
+        return args
+
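
`_parse_method_args` is what lets a single `forward` drive the differing GraphXAI signatures: it inspects the target method and keeps only those offered kwargs that the method actually declares, in the method's own parameter order (which matters, since the result is then splatted positionally). The same filtering in isolation, with hypothetical names:

    import inspect

    def filter_args(method, **offered):
        sig = inspect.signature(method)
        return tuple(offered[p.name] for p in sig.parameters.values()
                     if p.name in offered)

    def get_explanation_node(x, edge_index, node_idx, label=None):
        pass

    args = filter_args(get_explanation_node, x="X", edge_index="EI",
                       node_idx=3, y=7)
    # -> ("X", "EI", 3); 'y' is dropped since the method has no such parameter
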
     def forward(
         self,
         model: torch.nn.Module,
         x: Tensor,
         edge_index: Tensor,
-        target: int,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        target_index: Optional[int] = None,
         **kwargs,
     ):
         mask_type = self._get_mask_type()
         self.graphxai_method = self._load_graphxai_method(model)
-        if self.explain_graph:
-            attr = self.graphxai_method.get_explanation_graph(
-            attr = self.captum_method.attribute(
-                inputs=inputs,
-                additional_forward_args=additional_forward_args,
-                target=target,
-            )
+        # IF CRITERION = MSE:
+        # if (
+        #     self.name in ["IntegratedGradients", "GradExplainer"]
+        #     and "label" in kwargs.keys()
+        # ):
+        #     kwargs["label"] = kwargs["label"].float()
+        if (
+            self.name in ["PGMExplainer", "RandomExplainer"]
+            and "label" in kwargs.keys()
+        ):
+            kwargs.pop("label")
+
+        if self.model_config.task_level == ModelTaskLevel.node:
+            args = self._parse_method_args(
+                self.graphxai_method.get_explanation_node,
+                x=x,
+                edge_index=edge_index,
+                node_idx=target,
+            )
+            attr = self.graphxai_method.get_explanation_node(*args, **kwargs)
+        elif self.model_config.task_level == ModelTaskLevel.graph:
+            args = self._parse_method_args(
+                self.graphxai_method.get_explanation_graph,
+                x=x,
+                edge_index=edge_index,
+            )
+            attr = self.graphxai_method.get_explanation_graph(*args, **kwargs)
+        elif self.model_config.task_level == ModelTaskLevel.edge:
+            args = self._parse_method_args(
+                self.graphxai_method.get_explanation_link,
+                x=x,
+                edge_index=edge_index,
+            )
+            attr = self.graphxai_method.get_explanation_link(*args, **kwargs)
+        else:
+            raise ValueError(f"{self.model_config.task_level} is not supported yet")
 
         node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
 
         return Explanation(
             x=x,