New fixes
This commit is contained in:
parent
8067185d1a
commit
6cf1d64d3a
|
@ -1,3 +1,9 @@
|
||||||
|
import torch
|
||||||
|
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||||
|
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||||
|
GuidedGradCam, InputXGradient, IntegratedGradients,
|
||||||
|
Lime, Occlusion, Saliency)
|
||||||
|
from torch import Tensor
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.explain import Explanation
|
from torch_geometric.explain import Explanation
|
||||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
|
@ -7,14 +13,9 @@ from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||||
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
|
from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type,
|
||||||
to_captum_input, to_captum_model)
|
to_captum_input, to_captum_model)
|
||||||
|
|
||||||
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
|
||||||
FeaturePermutation, GradientShap, GuidedBackprop,
|
|
||||||
GuidedGradCam, InputXGradient, IntegratedGradients,
|
|
||||||
Lime, Occlusion, Saliency)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_FeatureAblation(model):
|
def _load_FeatureAblation(model):
|
||||||
return lambda model: FeatureAblation(model)
|
return FeatureAblation(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_LRP(model):
|
def _load_LRP(model):
|
||||||
|
@ -43,7 +44,7 @@ def _load_GradientShap(model):
|
||||||
|
|
||||||
|
|
||||||
def _load_GuidedBackPropagation(model):
|
def _load_GuidedBackPropagation(model):
|
||||||
return lambda model: GuidedBackprop(model)
|
return GuidedBackprop(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_GuidedGradCam(model):
|
def _load_GuidedGradCam(model):
|
||||||
|
@ -52,15 +53,15 @@ def _load_GuidedGradCam(model):
|
||||||
|
|
||||||
|
|
||||||
def _load_InputXGradient(model):
|
def _load_InputXGradient(model):
|
||||||
return lambda model: InputXGradient(model)
|
return InputXGradient(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_Lime(model):
|
def _load_Lime(model):
|
||||||
return lambda model: Lime(model)
|
return Lime(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_Saliency(model):
|
def _load_Saliency(model):
|
||||||
return lambda model: Saliency(model)
|
return Saliency(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_Occlusion(model):
|
def _load_Occlusion(model):
|
||||||
|
@ -212,7 +213,7 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
mask_type = self._get_mask_type()
|
mask_type = self._get_mask_type()
|
||||||
converted_model = to_captum_model(model, mask_type=mask_type, output=target)
|
converted_model = to_captum_model(model, mask_type=mask_type, output_idx=target)
|
||||||
self.captum_method = self._load_captum_method(converted_model)
|
self.captum_method = self._load_captum_method(converted_model)
|
||||||
inputs, additional_forward_args = to_captum_input(
|
inputs, additional_forward_args = to_captum_input(
|
||||||
x, edge_index, mask_type=mask_type
|
x, edge_index, mask_type=mask_type
|
|
@ -1,10 +1,7 @@
|
||||||
from torch_geometric.data import Data
|
import inspect
|
||||||
from torch_geometric.explain import Explanation
|
from typing import Dict, Optional, Tuple, Union
|
||||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
|
||||||
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
|
||||||
ModelConfig, ModelMode,
|
|
||||||
ModelTaskLevel)
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from graphxai.explainers.cam import CAM, GradCAM
|
from graphxai.explainers.cam import CAM, GradCAM
|
||||||
from graphxai.explainers.gnn_explainer import GNNExplainer
|
from graphxai.explainers.gnn_explainer import GNNExplainer
|
||||||
from graphxai.explainers.gnn_lrp import GNN_LRP
|
from graphxai.explainers.gnn_lrp import GNN_LRP
|
||||||
|
@ -16,14 +13,22 @@ from graphxai.explainers.pg_explainer import PGExplainer
|
||||||
from graphxai.explainers.pgm_explainer import PGMExplainer
|
from graphxai.explainers.pgm_explainer import PGMExplainer
|
||||||
from graphxai.explainers.random import RandomExplainer
|
from graphxai.explainers.random import RandomExplainer
|
||||||
from graphxai.explainers.subgraphx import SubgraphX
|
from graphxai.explainers.subgraphx import SubgraphX
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import CrossEntropyLoss, KLDivLoss, MSELoss
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.explain import Explanation
|
||||||
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
|
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||||
|
ModelConfig, ModelMode,
|
||||||
|
ModelTaskLevel)
|
||||||
|
|
||||||
|
|
||||||
def _load_CAM(model):
|
def _load_CAM(model):
|
||||||
return lambda model: CAM(model)
|
return CAM(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_GradCAM(model):
|
def _load_GradCAM(model):
|
||||||
return lambda model: GradCAM(model)
|
return GradCAM(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_GNN_LRP(model):
|
def _load_GNN_LRP(model):
|
||||||
|
@ -39,48 +44,60 @@ def _load_GuidedBackPropagation(model, criterion):
|
||||||
|
|
||||||
|
|
||||||
def _load_IntegratedGradients(model, criterion):
|
def _load_IntegratedGradients(model, criterion):
|
||||||
return lambda model: IntegratedGradExplainer(model, criterion)
|
return IntegratedGradExplainer(model, criterion)
|
||||||
|
|
||||||
|
|
||||||
def _load_GradExplainer(model, criterion):
|
def _load_GradExplainer(model, criterion):
|
||||||
return lambda model: GradExplainer(model, criterion)
|
return GradExplainer(model, criterion)
|
||||||
|
|
||||||
|
|
||||||
def _load_PGExplainer(model, explain_graph=None, in_channels=None):
|
def _load_PGExplainer(model, explain_graph=None, in_channels=None):
|
||||||
return lambda model: PGExplainer(
|
return PGExplainer(model, explain_graph=explain_graph, in_channels=in_channels)
|
||||||
model, explain_graph=explain_graph, in_channels=in_channels
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_PGMExplainer(model, explain_graph=None):
|
def _load_PGMExplainer(model, explain_graph=None):
|
||||||
return lambda model: PGMExplainer(model, explain_graph)
|
return PGMExplainer(model, explain_graph)
|
||||||
|
|
||||||
|
|
||||||
def _load_RandomExplainer(model):
|
def _load_RandomExplainer(model):
|
||||||
return lambda model: RandomExplainer(model)
|
return RandomExplainer(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_SubgraphX(model):
|
def _load_SubgraphX(model):
|
||||||
return lambda model: SubgraphX(model)
|
return SubgraphX(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_GNNExplainer(model):
|
def _load_GNNExplainer(model):
|
||||||
return lambda model: GNNExplainer(model)
|
return GNNExplainer(model)
|
||||||
|
|
||||||
|
|
||||||
def _load_GraphLIME(model):
|
def _load_GraphLIME(model):
|
||||||
return lambda model: GraphLIME(model)
|
return GraphLIME(model)
|
||||||
|
|
||||||
|
|
||||||
class GraphXAIWrapper(ExplainerAlgorithm):
|
class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
def __init__(self, name, criterion=None, in_channels=None):
|
def __init__(self, name, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.criterion = criterion
|
self.criterion = self._determine_criterion(kwargs["criterion"])
|
||||||
self.explain_graph = (
|
self.in_channels = self._determine_in_channels(kwargs["in_channels"])
|
||||||
True if self.model_config.task_level == ModelTaskLevel.graph else False
|
|
||||||
)
|
def _determine_criterion(self, criterion):
|
||||||
self.in_channels = in_channels
|
if criterion == "mse":
|
||||||
|
loss = MSELoss()
|
||||||
|
return loss
|
||||||
|
elif criterion == "cross-entropy":
|
||||||
|
loss = CrossEntropyLoss()
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{criterion} criterion is not implemented")
|
||||||
|
|
||||||
|
def _determine_in_channels(self, in_channels):
|
||||||
|
if self.name == "PGExplainer":
|
||||||
|
in_channels = 2 * in_channels
|
||||||
|
return in_channels
|
||||||
|
else:
|
||||||
|
return in_channels
|
||||||
|
|
||||||
def supports(self) -> bool:
|
def supports(self) -> bool:
|
||||||
task_level = self.model_config.task_level
|
task_level = self.model_config.task_level
|
||||||
|
@ -122,6 +139,10 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
if self.name == "GraphLIME" and task_level == ModelTaskLevel.graph:
|
if self.name == "GraphLIME" and task_level == ModelTaskLevel.graph:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
self.explain_graph = (
|
||||||
|
True if self.model_config.task_level == ModelTaskLevel.graph else False
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_mask_type(self):
|
def _get_mask_type(self):
|
||||||
|
@ -158,22 +179,24 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
return _load_GNN_LRP(model)
|
return _load_GNN_LRP(model)
|
||||||
|
|
||||||
elif self.name == "GradExplainer":
|
elif self.name == "GradExplainer":
|
||||||
return _load_GradExplainer(model,self.criterion)
|
return _load_GradExplainer(model, self.criterion)
|
||||||
|
|
||||||
elif self.name == "GraphLIME":
|
elif self.name == "GraphLIME":
|
||||||
return _load_GraphLIME(model)
|
return _load_GraphLIME(model)
|
||||||
|
|
||||||
elif self.name == "GuidedBackPropagation":
|
elif self.name == "GuidedBackPropagation":
|
||||||
return _load_GuidedBackPropagation(model,self.criterion)
|
return _load_GuidedBackPropagation(model, self.criterion)
|
||||||
|
|
||||||
elif self.name == "IntegratedGradients":
|
elif self.name == "IntegratedGradients":
|
||||||
return _load_IntegratedGradients(model,self.criterion)
|
return _load_IntegratedGradients(model, self.criterion)
|
||||||
|
|
||||||
elif self.name == "PGExplainer":
|
elif self.name == "PGExplainer":
|
||||||
return _load_PGExplainer(model,explain_graph=self.explain_graph,in_channels=self.in_channels)
|
return _load_PGExplainer(
|
||||||
|
model, explain_graph=self.explain_graph, in_channels=self.in_channels
|
||||||
|
)
|
||||||
|
|
||||||
elif self.name == "PGMExplainer":
|
elif self.name == "PGMExplainer":
|
||||||
return _load_PGMExplainer(model,explain_graph=self.explain_graph)
|
return _load_PGMExplainer(model, explain_graph=self.explain_graph)
|
||||||
|
|
||||||
elif self.name == "RandomExplainer":
|
elif self.name == "RandomExplainer":
|
||||||
return _load_RandomExplainer(model)
|
return _load_RandomExplainer(model)
|
||||||
|
@ -181,45 +204,89 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
elif self.name == "SubgraphX":
|
elif self.name == "SubgraphX":
|
||||||
return _load_SubgraphX(model)
|
return _load_SubgraphX(model)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
raise ValueError(f"{self.name} is not supported yet !")
|
||||||
|
|
||||||
def _parse_attr(self, attr):
|
def _parse_attr(self, attr):
|
||||||
if self.mask_type == "node":
|
node_mask, node_feat_mask, edge_mask, edge_feat_mask, = (
|
||||||
node_mask = attr[0].squeeze(0)
|
None,
|
||||||
edge_mask = None
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
for k, v in attr.__dict__.items():
|
||||||
|
if k == "feature_imp":
|
||||||
|
node_feat_mask = v
|
||||||
|
|
||||||
if self.mask_type == "edge":
|
elif k == "node_imp":
|
||||||
node_mask = None
|
node_mask = v
|
||||||
edge_mask = attr[0]
|
|
||||||
|
|
||||||
if self.mask_type == "node_and_edge":
|
elif k == "edge_imp":
|
||||||
node_mask = attr[0].squeeze(0)
|
edge_mask = v
|
||||||
edge_mask = attr[1]
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
edge_feat_mask = None
|
|
||||||
node_feat_mask = None
|
|
||||||
|
|
||||||
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
|
def _parse_method_args(self, method, **kwargs):
|
||||||
|
signature = inspect.signature(method)
|
||||||
|
args = tuple(
|
||||||
|
[
|
||||||
|
kwargs[k.name]
|
||||||
|
for k in signature.parameters.values()
|
||||||
|
if k.name in kwargs.keys()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
edge_index: Tensor,
|
edge_index: Tensor,
|
||||||
target: int,
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
target_index: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
mask_type = self._get_mask_type()
|
mask_type = self._get_mask_type()
|
||||||
self.graphxai_method = self._load_graphxai_method(model)
|
self.graphxai_method = self._load_graphxai_method(model)
|
||||||
if self.explain_graph:
|
# IF CRITERION = MSE:
|
||||||
attr = self.graphxai_method.get_explanation_graph(
|
# if (
|
||||||
attr = self.captum_method.attribute(
|
# self.name in ["IntegratedGradients", "GradExplainer"]
|
||||||
inputs=inputs,
|
# and "label" in kwargs.keys()
|
||||||
additional_forward_args=additional_forward_args,
|
# ):
|
||||||
target=target,
|
# kwargs["label"] = kwargs["label"].float()
|
||||||
|
if (
|
||||||
|
self.name in ["PGMExplainer", "RandomExplainer"]
|
||||||
|
and "label" in kwargs.keys()
|
||||||
|
):
|
||||||
|
kwargs.pop("label")
|
||||||
|
|
||||||
|
if self.model_config.task_level == ModelTaskLevel.node:
|
||||||
|
args = self._parse_method_args(
|
||||||
|
self.graphxai_method.get_explanation_node,
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
node_idx=target,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
attr = self.graphxai_method.get_explanation_node(*args, **kwargs)
|
||||||
|
elif self.model_config.task_level == ModelTaskLevel.graph:
|
||||||
|
args = self._parse_method_args(
|
||||||
|
self.graphxai_method.get_explanation_graph,
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
attr = self.graphxai_method.get_explanation_graph(*args, **kwargs)
|
||||||
|
elif self.model_config.task_level == ModelTaskLevel.edge:
|
||||||
|
args = self._parse_method_args(
|
||||||
|
self.graphxai_method.get_explanation_link,
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
attr = self.graphxai_method.get_explanation_link(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{self.model_config.task_level} is not supported yet")
|
||||||
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
|
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
|
||||||
return Explanation(
|
return Explanation(
|
||||||
x=x,
|
x=x,
|
Loading…
Reference in New Issue