explaining_framework/explaining_framework/explainers/wrappers/from_graphxai.py
2022-12-14 18:28:12 +01:00

268 lines
8.4 KiB
Python

import inspect
import logging
from typing import Dict, Optional, Tuple, Union

import torch
from graphxai.explainers.cam import CAM, GradCAM
from graphxai.explainers.gnn_explainer import GNNExplainer
from graphxai.explainers.gnn_lrp import GNN_LRP
from graphxai.explainers.grad import GradExplainer
from graphxai.explainers.graphlime import GraphLIME
from graphxai.explainers.guidedbp import GuidedBP
from graphxai.explainers.integrated_grad import IntegratedGradExplainer
from graphxai.explainers.pg_explainer import PGExplainer
from graphxai.explainers.pgm_explainer import PGMExplainer
from graphxai.explainers.random import RandomExplainer
from graphxai.explainers.subgraphx import SubgraphX
from torch import Tensor
from torch.nn import CrossEntropyLoss, MSELoss
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import (
    ExplainerConfig,
    MaskType,
    ModelConfig,
    ModelMode,
    ModelTaskLevel,
)
def _load_CAM(model):
    """Return a GraphXAI ``CAM`` explainer bound to ``model``."""
    explainer = CAM(model)
    return explainer
def _load_GradCAM(model):
    """Return a GraphXAI ``GradCAM`` explainer bound to ``model``."""
    explainer = GradCAM(model)
    return explainer
def _load_GNN_LRP(model):
# return lambda model: GNN_LRP(model)
raise ValueError("GraphXAI GNN_LRP is not supported yet")
def _load_GuidedBackPropagation(model, criterion):
# return lambda model: GuidedBP(model, criterion)
raise ValueError(
"GraphXAI GuidedBackPropagation is discarded since already available in Captum for Pytorch Geometric (see CaptumWrapper)"
)
def _load_IntegratedGradients(model, criterion):
    """Return a GraphXAI ``IntegratedGradExplainer`` for ``model``.

    ``criterion`` is handed to the explainer unchanged.
    """
    explainer = IntegratedGradExplainer(model, criterion)
    return explainer
def _load_GradExplainer(model, criterion):
    """Return a GraphXAI ``GradExplainer`` for ``model``.

    ``criterion`` is handed to the explainer unchanged.
    """
    explainer = GradExplainer(model, criterion)
    return explainer
def _load_PGExplainer(model, explain_graph=None, in_channels=None):
    """Return a GraphXAI ``PGExplainer`` for ``model``.

    ``explain_graph`` and ``in_channels`` are forwarded as keyword arguments
    of the same name.
    """
    options = {"explain_graph": explain_graph, "in_channels": in_channels}
    return PGExplainer(model, **options)
def _load_PGMExplainer(model, explain_graph=None):
    """Return a GraphXAI ``PGMExplainer`` for ``model``.

    ``explain_graph`` is passed positionally as the constructor's second
    argument.
    """
    explainer = PGMExplainer(model, explain_graph)
    return explainer
def _load_RandomExplainer(model):
    """Return a GraphXAI ``RandomExplainer`` (random baseline) for ``model``."""
    explainer = RandomExplainer(model)
    return explainer
def _load_SubgraphX(model):
    """Return a GraphXAI ``SubgraphX`` explainer for ``model``."""
    explainer = SubgraphX(model)
    return explainer
def _load_GNNExplainer(model):
    """Return a GraphXAI ``GNNExplainer`` for ``model``."""
    explainer = GNNExplainer(model)
    return explainer
def _load_GraphLIME(model):
    """Return a GraphXAI ``GraphLIME`` explainer for ``model``."""
    explainer = GraphLIME(model)
    return explainer
class GraphXAIWrapper(ExplainerAlgorithm):
    """Expose a GraphXAI explainer through PyG's ``ExplainerAlgorithm`` interface.

    The concrete GraphXAI method is selected by ``name`` and is instantiated
    on every :meth:`forward` call through the module-level ``_load_*``
    helpers. Its output is converted into a PyG ``Explanation``.

    Args:
        name: GraphXAI method name; must be one of ``_SUPPORTED_METHODS``.
        **kwargs: ``criterion`` (required, ``"mse"`` or ``"cross-entropy"``)
            and ``in_channels`` (optional; only consumed by ``PGExplainer``).

    Raises:
        KeyError: if ``criterion`` is missing from ``kwargs``.
        ValueError: if ``criterion`` is not a supported value.
    """

    # Method names understood by ``_load_graphxai_method``. ``supports`` checks
    # against this same tuple so the two lists cannot drift apart.
    _SUPPORTED_METHODS = (
        "CAM",
        "GradCAM",
        "GNNExplainer",
        "GNN_LRP",
        "GradExplainer",
        "GraphLIME",
        "GuidedBackPropagation",
        "IntegratedGradients",
        "PGExplainer",
        "PGMExplainer",
        "RandomExplainer",
        "SubgraphX",
    )

    def __init__(self, name, **kwargs):
        super().__init__()
        self.name = name
        self.criterion = self._determine_criterion(kwargs["criterion"])
        # ``in_channels`` is only used by PGExplainer, so it is optional here.
        self.in_channels = self._determine_in_channels(kwargs.get("in_channels"))

    def _determine_criterion(self, criterion):
        """Map a criterion name to a ``torch.nn`` loss module.

        Raises:
            ValueError: if ``criterion`` is not ``"mse"`` or ``"cross-entropy"``.
        """
        if criterion == "mse":
            return MSELoss()
        if criterion == "cross-entropy":
            return CrossEntropyLoss()
        raise ValueError(f"{criterion} criterion is not implemented")

    def _determine_in_channels(self, in_channels):
        """Return the ``in_channels`` value to hand to the GraphXAI method.

        For ``PGExplainer`` the given value is doubled; every other method gets
        it unchanged. ``None`` (not provided) is passed through as-is rather
        than failing on ``2 * None``.
        """
        # NOTE(review): the doubling presumably matches PGExplainer's use of
        # concatenated node embeddings per edge; confirm against GraphXAI.
        if self.name == "PGExplainer" and in_channels is not None:
            return 2 * in_channels
        return in_channels

    def supports(self) -> bool:
        """Check that the connected configs are usable with this method.

        Presumably invoked by PyG once the explainer/model configs are
        connected. Every failed check logs the reason and returns ``False``.
        On success, ``self.explain_graph`` is set to ``True`` for graph-level
        tasks and ``False`` otherwise.
        """
        task_level = self.model_config.task_level
        if self.name not in self._SUPPORTED_METHODS:
            logging.error(f"GraphXAI method '{self.name}' not supported")
            return False
        if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False
        if self.name == "GraphLIME" and task_level == ModelTaskLevel.graph:
            logging.error("GraphLIME does not support graph-level tasks")
            return False
        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(
                f"Edge mask type '{edge_mask_type.value}' not " f"supported"
            )
            return False
        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
            None,
        ]:
            logging.error(
                f"Node mask type '{node_mask_type.value}' not " f"supported."
            )
            return False
        self.explain_graph = task_level == ModelTaskLevel.graph
        return True

    def _get_mask_type(self):
        """Classify the configured masks as ``"node"``, ``"edge"`` or ``"node_and_edge"``.

        The result is also stored on ``self.mask_type``.

        Raises:
            ValueError: if neither an edge nor a node mask type is configured.
        """
        has_edge_mask = self.explainer_config.edge_mask_type is not None
        has_node_mask = self.explainer_config.node_mask_type is not None
        if not has_edge_mask and not has_node_mask:
            raise ValueError("You need to provide a masking config")
        if has_edge_mask and has_node_mask:
            self.mask_type = "node_and_edge"
        elif has_edge_mask:
            self.mask_type = "edge"
        else:
            self.mask_type = "node"
        return self.mask_type

    def _load_graphxai_method(self, model):
        """Instantiate the GraphXAI explainer named by ``self.name`` for ``model``.

        Raises:
            ValueError: for unknown names, and for methods whose loader
                refuses them (``GNN_LRP``, ``GuidedBackPropagation``).
        """
        name = self.name
        if name in ("PGExplainer", "PGMExplainer"):
            # Derive from the config here instead of relying on ``supports``
            # having been called first to set the attribute.
            self.explain_graph = self.model_config.task_level == ModelTaskLevel.graph
        loaders = {
            "CAM": lambda: _load_CAM(model),
            "GradCAM": lambda: _load_GradCAM(model),
            "GNNExplainer": lambda: _load_GNNExplainer(model),
            "GNN_LRP": lambda: _load_GNN_LRP(model),
            "GradExplainer": lambda: _load_GradExplainer(model, self.criterion),
            "GraphLIME": lambda: _load_GraphLIME(model),
            "GuidedBackPropagation": lambda: _load_GuidedBackPropagation(
                model, self.criterion
            ),
            "IntegratedGradients": lambda: _load_IntegratedGradients(
                model, self.criterion
            ),
            "PGExplainer": lambda: _load_PGExplainer(
                model, explain_graph=self.explain_graph, in_channels=self.in_channels
            ),
            "PGMExplainer": lambda: _load_PGMExplainer(
                model, explain_graph=self.explain_graph
            ),
            "RandomExplainer": lambda: _load_RandomExplainer(model),
            "SubgraphX": lambda: _load_SubgraphX(model),
        }
        loader = loaders.get(name)
        if loader is None:
            raise ValueError(f"{name} is not supported yet !")
        return loader()

    def _parse_attr(self, attr):
        """Extract PyG-style masks from a GraphXAI explanation object.

        Reads ``attr.node_imp``, ``attr.edge_imp`` and ``attr.feature_imp``.
        Any missing attribute maps to ``None``. This wrapper reads no
        edge-feature importance, so ``edge_feat_mask`` is always ``None``.

        Returns:
            ``(node_mask, edge_mask, node_feat_mask, edge_feat_mask)``.
        """
        node_mask = getattr(attr, "node_imp", None)
        edge_mask = getattr(attr, "edge_imp", None)
        node_feat_mask = getattr(attr, "feature_imp", None)
        edge_feat_mask = None
        return node_mask, edge_mask, node_feat_mask, edge_feat_mask

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        """Run the configured GraphXAI method and wrap its output as an ``Explanation``.

        Args:
            model: Model to explain.
            x: Node features; forwarded to GraphXAI and stored on the result.
            edge_index: Graph connectivity; forwarded and stored likewise.
            target: Passed to GraphXAI as both ``label`` and ``y``.
            index: Node to explain for node-level tasks, forwarded as
                ``node_idx``. NOTE(review): GraphXAI may expect a plain int
                here; confirm how Tensor indices are handled.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: if no mask type is configured, the method cannot be
                loaded, or the task level is neither node nor graph.
        """
        # Called for its validation side effect: raises when no mask is set.
        self._get_mask_type()
        self.graphxai_method = self._load_graphxai_method(model)
        task_level = self.model_config.task_level
        if task_level == ModelTaskLevel.node:
            attr = self.graphxai_method.get_explanation_node(
                x=x,
                edge_index=edge_index,
                label=target,
                node_idx=index,
                y=target,
            )
        elif task_level == ModelTaskLevel.graph:
            attr = self.graphxai_method.get_explanation_graph(
                x=x,
                edge_index=edge_index,
                label=target,
                y=target,
            )
        else:
            # Link-level explanations (``get_explanation_link``) are not wired
            # up yet, so edge-level tasks fail with a clear error.
            raise ValueError(f"{task_level} is not supported yet")
        node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
        )