New fixes and starting to develop GraphXAI wrapper
This commit is contained in:
parent
2aacd9fddd
commit
8067185d1a
2 changed files with 256 additions and 15 deletions
|
@ -185,19 +185,23 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
|
|
||||||
def _parse_attr(self, attr):
|
def _parse_attr(self, attr):
|
||||||
if self.mask_type == "node":
|
if self.mask_type == "node":
|
||||||
node_mask = attr
|
node_mask = attr[0].squeeze(0)
|
||||||
edge_mask = None
|
edge_mask = None
|
||||||
|
|
||||||
if "edge" == mask_type:
|
if self.mask_type == "edge":
|
||||||
node_mask = None
|
node_mask = None
|
||||||
edge_mask = attr
|
edge_mask = attr[0]
|
||||||
|
|
||||||
if "node_and_mask" == mask_type:
|
if self.mask_type == "node_and_edge":
|
||||||
node_mask = attr[0]
|
node_mask = attr[0].squeeze(0)
|
||||||
edge_mask = attr[1]
|
edge_mask = attr[1]
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
# TODO
|
edge_feat_mask = None
|
||||||
pass
|
node_feat_mask = None
|
||||||
|
|
||||||
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -213,19 +217,25 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
inputs, additional_forward_args = to_captum_input(
|
inputs, additional_forward_args = to_captum_input(
|
||||||
x, edge_index, mask_type=mask_type
|
x, edge_index, mask_type=mask_type
|
||||||
)
|
)
|
||||||
if self.name in ["InputXGradient", "Lime", "Saliency", "GuidedBackPropagation"]:
|
if self.name in [
|
||||||
|
"InputXGradient",
|
||||||
|
"Lime",
|
||||||
|
"Saliency",
|
||||||
|
"GuidedBackPropagation",
|
||||||
|
"FeatureAblation",
|
||||||
|
]:
|
||||||
attr = self.captum_method.attribute(
|
attr = self.captum_method.attribute(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
additional_forward_args=additional_forward_args,
|
additional_forward_args=additional_forward_args,
|
||||||
target=target,
|
target=target,
|
||||||
)
|
)
|
||||||
elif self.name == "FeatureAblation":
|
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
|
||||||
attr = self.captum_method.attribute(
|
|
||||||
inputs=inputs,
|
|
||||||
additional_forward_args=additional_forward_args,
|
|
||||||
)
|
|
||||||
node_mask, edge_mask = self._parse_attr(attr)
|
|
||||||
|
|
||||||
return Explanation(
|
return Explanation(
|
||||||
x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
edge_mask=edge_mask,
|
||||||
|
node_mask=node_mask,
|
||||||
|
node_feat_mask=node_feat_mask,
|
||||||
|
edge_feat_mask=edge_feat_mask,
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,231 @@
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.explain import Explanation
|
||||||
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
|
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||||
|
ModelConfig, ModelMode,
|
||||||
|
ModelTaskLevel)
|
||||||
|
|
||||||
|
from graphxai.explainers.cam import CAM, GradCAM
|
||||||
|
from graphxai.explainers.gnn_explainer import GNNExplainer
|
||||||
|
from graphxai.explainers.gnn_lrp import GNN_LRP
|
||||||
|
from graphxai.explainers.grad import GradExplainer
|
||||||
|
from graphxai.explainers.graphlime import GraphLIME
|
||||||
|
from graphxai.explainers.guidedbp import GuidedBP
|
||||||
|
from graphxai.explainers.integrated_grad import IntegratedGradExplainer
|
||||||
|
from graphxai.explainers.pg_explainer import PGExplainer
|
||||||
|
from graphxai.explainers.pgm_explainer import PGMExplainer
|
||||||
|
from graphxai.explainers.random import RandomExplainer
|
||||||
|
from graphxai.explainers.subgraphx import SubgraphX
|
||||||
|
|
||||||
|
|
||||||
|
def _load_CAM(model):
|
||||||
|
return lambda model: CAM(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GradCAM(model):
|
||||||
|
return lambda model: GradCAM(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GNN_LRP(model):
|
||||||
|
# return lambda model: GNN_LRP(model)
|
||||||
|
raise ValueError("GraphXAI GNN_LRP is not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GuidedBackPropagation(model, criterion):
|
||||||
|
# return lambda model: GuidedBP(model, criterion)
|
||||||
|
raise ValueError(
|
||||||
|
"GraphXAI GuidedBackPropagation is discarded since already available in Captum for Pytorch Geometric (see CaptumWrapper)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_IntegratedGradients(model, criterion):
|
||||||
|
return lambda model: IntegratedGradExplainer(model, criterion)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GradExplainer(model, criterion):
|
||||||
|
return lambda model: GradExplainer(model, criterion)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_PGExplainer(model, explain_graph=None, in_channels=None):
|
||||||
|
return lambda model: PGExplainer(
|
||||||
|
model, explain_graph=explain_graph, in_channels=in_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_PGMExplainer(model, explain_graph=None):
|
||||||
|
return lambda model: PGMExplainer(model, explain_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_RandomExplainer(model):
|
||||||
|
return lambda model: RandomExplainer(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_SubgraphX(model):
|
||||||
|
return lambda model: SubgraphX(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GNNExplainer(model):
|
||||||
|
return lambda model: GNNExplainer(model)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_GraphLIME(model):
|
||||||
|
return lambda model: GraphLIME(model)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
|
def __init__(self, name, criterion=None, in_channels=None):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
self.criterion = criterion
|
||||||
|
self.explain_graph = (
|
||||||
|
True if self.model_config.task_level == ModelTaskLevel.graph else False
|
||||||
|
)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
def supports(self) -> bool:
|
||||||
|
task_level = self.model_config.task_level
|
||||||
|
if self.name in [
|
||||||
|
"CAM",
|
||||||
|
"GradCAM",
|
||||||
|
"GNN_LRP",
|
||||||
|
"GradExplainer",
|
||||||
|
"GuidedBP",
|
||||||
|
"IntegratedGradExplainer",
|
||||||
|
"PGExplainer",
|
||||||
|
"PGMExplainer",
|
||||||
|
"RandomExplainer",
|
||||||
|
"SubgraphX",
|
||||||
|
"GNNExplainer",
|
||||||
|
]:
|
||||||
|
if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
|
||||||
|
logging.error(f"Task level '{task_level.value}' not supported")
|
||||||
|
return False
|
||||||
|
|
||||||
|
edge_mask_type = self.explainer_config.edge_mask_type
|
||||||
|
if edge_mask_type not in [MaskType.object, None]:
|
||||||
|
logging.error(
|
||||||
|
f"Edge mask type '{edge_mask_type.value}' not " f"supported"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
node_mask_type = self.explainer_config.node_mask_type
|
||||||
|
if node_mask_type not in [
|
||||||
|
MaskType.common_attributes,
|
||||||
|
MaskType.object,
|
||||||
|
MaskType.attributes,
|
||||||
|
None,
|
||||||
|
]:
|
||||||
|
logging.error(
|
||||||
|
f"Node mask type '{node_mask_type.value}' not " f"supported."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if self.name == "GraphLIME" and task_level == ModelTaskLevel.graph:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_mask_type(self):
|
||||||
|
|
||||||
|
edge_mask_type = self.explainer_config.edge_mask_type
|
||||||
|
node_mask_type = self.explainer_config.node_mask_type
|
||||||
|
|
||||||
|
if edge_mask_type is None and node_mask_type is None:
|
||||||
|
raise ValueError("You need to provide a masking config")
|
||||||
|
|
||||||
|
if not edge_mask_type is None and node_mask_type is None:
|
||||||
|
self.mask_type = "edge"
|
||||||
|
|
||||||
|
if edge_mask_type is None and not node_mask_type is None:
|
||||||
|
self.mask_type = "node"
|
||||||
|
|
||||||
|
if not edge_mask_type is None and not node_mask_type is None:
|
||||||
|
self.mask_type = "node_and_edge"
|
||||||
|
|
||||||
|
return self.mask_type
|
||||||
|
|
||||||
|
def _load_graphxai_method(self, model):
|
||||||
|
|
||||||
|
if self.name == "CAM":
|
||||||
|
return _load_CAM(model)
|
||||||
|
|
||||||
|
elif self.name == "GradCAM":
|
||||||
|
return _load_GradCAM(model)
|
||||||
|
|
||||||
|
elif self.name == "GNNExplainer":
|
||||||
|
return _load_GNNExplainer(model)
|
||||||
|
|
||||||
|
elif self.name == "GNN_LRP":
|
||||||
|
return _load_GNN_LRP(model)
|
||||||
|
|
||||||
|
elif self.name == "GradExplainer":
|
||||||
|
return _load_GradExplainer(model,self.criterion)
|
||||||
|
|
||||||
|
elif self.name == "GraphLIME":
|
||||||
|
return _load_GraphLIME(model)
|
||||||
|
|
||||||
|
elif self.name == "GuidedBackPropagation":
|
||||||
|
return _load_GuidedBackPropagation(model,self.criterion)
|
||||||
|
|
||||||
|
elif self.name == "IntegratedGradients":
|
||||||
|
return _load_IntegratedGradients(model,self.criterion)
|
||||||
|
|
||||||
|
elif self.name == "PGExplainer":
|
||||||
|
return _load_PGExplainer(model,explain_graph=self.explain_graph,in_channels=self.in_channels)
|
||||||
|
|
||||||
|
elif self.name == "PGMExplainer":
|
||||||
|
return _load_PGMExplainer(model,explain_graph=self.explain_graph)
|
||||||
|
|
||||||
|
elif self.name == "RandomExplainer":
|
||||||
|
return _load_RandomExplainer(model)
|
||||||
|
|
||||||
|
elif self.name == "SubgraphX":
|
||||||
|
return _load_SubgraphX(model)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{self.name} is not a supported Captum method yet !")
|
||||||
|
|
||||||
|
def _parse_attr(self, attr):
|
||||||
|
if self.mask_type == "node":
|
||||||
|
node_mask = attr[0].squeeze(0)
|
||||||
|
edge_mask = None
|
||||||
|
|
||||||
|
if self.mask_type == "edge":
|
||||||
|
node_mask = None
|
||||||
|
edge_mask = attr[0]
|
||||||
|
|
||||||
|
if self.mask_type == "node_and_edge":
|
||||||
|
node_mask = attr[0].squeeze(0)
|
||||||
|
edge_mask = attr[1]
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
edge_feat_mask = None
|
||||||
|
node_feat_mask = None
|
||||||
|
|
||||||
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
target: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
mask_type = self._get_mask_type()
|
||||||
|
self.graphxai_method = self._load_graphxai_method(model)
|
||||||
|
if self.explain_graph:
|
||||||
|
attr = self.graphxai_method.get_explanation_graph(
|
||||||
|
attr = self.captum_method.attribute(
|
||||||
|
inputs=inputs,
|
||||||
|
additional_forward_args=additional_forward_args,
|
||||||
|
target=target,
|
||||||
|
)
|
||||||
|
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
|
||||||
|
return Explanation(
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
edge_mask=edge_mask,
|
||||||
|
node_mask=node_mask,
|
||||||
|
node_feat_mask=node_feat_mask,
|
||||||
|
edge_feat_mask=edge_feat_mask,
|
||||||
|
)
|
Loading…
Add table
Reference in a new issue