Fixing import bug
This commit is contained in:
parent
d1c2565003
commit
7da28de955
|
@ -0,0 +1,536 @@
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from math import sqrt
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import CrossEntropyLoss, MSELoss, ReLU, Sequential
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
|
||||||
|
from torch_geometric.explain.algorithm import ExplainerAlgorithm
|
||||||
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
|
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
|
||||||
|
from torch_geometric.explain.config import (ExplainerConfig, ExplanationType,
|
||||||
|
MaskType, ModelConfig, ModelMode,
|
||||||
|
ModelTaskLevel)
|
||||||
|
from torch_geometric.nn import Linear
|
||||||
|
from torch_geometric.nn.inits import reset
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_passing_embeddings(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
) -> List[Tensor]:
|
||||||
|
"""Returns the output embeddings of all
|
||||||
|
:class:`~torch_geometric.nn.conv.MessagePassing` layers in
|
||||||
|
:obj:`model`.
|
||||||
|
|
||||||
|
Internally, this method registers forward hooks on all
|
||||||
|
:class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
|
||||||
|
and runs the forward pass of the :obj:`model` by calling
|
||||||
|
:obj:`model(*args, **kwargs)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The message passing model.
|
||||||
|
*args: Arguments passed to the model.
|
||||||
|
**kwargs (optional): Additional keyword arguments passed to the model.
|
||||||
|
"""
|
||||||
|
from torch_geometric.nn import MessagePassing
|
||||||
|
|
||||||
|
embeddings: List[Tensor] = []
|
||||||
|
|
||||||
|
def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
|
||||||
|
# Clone output in case it will be later modified in-place:
|
||||||
|
outputs = outputs[0] if isinstance(outputs, tuple) else outputs
|
||||||
|
assert isinstance(outputs, Tensor)
|
||||||
|
embeddings.append(outputs.clone())
|
||||||
|
|
||||||
|
hook_handles = []
|
||||||
|
for module in model.modules(): # Register forward hooks:
|
||||||
|
if isinstance(module, MessagePassing):
|
||||||
|
hook_handles.append(module.register_forward_hook(hook))
|
||||||
|
|
||||||
|
if len(hook_handles) == 0:
|
||||||
|
warnings.warn("The 'model' does not have any 'MessagePassing' layers")
|
||||||
|
|
||||||
|
training = model.training
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
model(*args, **kwargs)
|
||||||
|
model.train(training)
|
||||||
|
|
||||||
|
for handle in hook_handles: # Remove hooks:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class PGExplainer(ExplainerAlgorithm):
|
||||||
|
r"""The PGExplainer model from the `"Parameterized Explainer for Graph
|
||||||
|
Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper.
|
||||||
|
Internally, it utilizes a neural network to identify subgraph structures
|
||||||
|
that play a crucial role in the predictions made by a GNN.
|
||||||
|
Importantly, the :class:`PGExplainer` needs to be trained via
|
||||||
|
:meth:`~PGExplainer.train` before being able to generate explanations:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
explainer = Explainer(
|
||||||
|
model=model,
|
||||||
|
algorithm=PGExplainer(epochs=30, lr=0.003),
|
||||||
|
explanation_type='phenomenon',
|
||||||
|
edge_mask_type='object',
|
||||||
|
model_config=ModelConfig(...),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train against a variety of node-level or graph-level predictions:
|
||||||
|
for epoch in range(30):
|
||||||
|
for index in [...]: # Indices to train against.
|
||||||
|
loss = explainer.algorithm.train(epoch, model, x, edge_index,
|
||||||
|
target=target, index=index)
|
||||||
|
|
||||||
|
# Get the final explanations:
|
||||||
|
explanation = explainer(x, edge_index, target=target, index=0)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epochs (int): The number of epochs to train.
|
||||||
|
lr (float, optional): The learning rate to apply.
|
||||||
|
(default: :obj:`0.003`).
|
||||||
|
**kwargs (optional): Additional hyper-parameters to override default
|
||||||
|
settings in
|
||||||
|
:attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
coeffs = {
|
||||||
|
"edge_size": 0.05,
|
||||||
|
"edge_ent": 1.0,
|
||||||
|
"temp": [5.0, 2.0],
|
||||||
|
"bias": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.epochs = epochs
|
||||||
|
self.lr = lr
|
||||||
|
self.coeffs.update(kwargs)
|
||||||
|
|
||||||
|
self.mlp = Sequential(
|
||||||
|
Linear(-1, 64),
|
||||||
|
ReLU(),
|
||||||
|
Linear(64, 1),
|
||||||
|
)
|
||||||
|
self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
|
||||||
|
self._curr_epoch = -1
|
||||||
|
|
||||||
|
def _get_hard_masks(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
index: Optional[Union[int, Tensor]],
|
||||||
|
edge_index: Tensor,
|
||||||
|
num_nodes: int,
|
||||||
|
) -> Tuple[Optional[Tensor], Optional[Tensor]]:
|
||||||
|
r"""Returns hard node and edge masks that only include the nodes and
|
||||||
|
edges visited during message passing."""
|
||||||
|
if index is None:
|
||||||
|
return None, None # Consider all nodes and edges.
|
||||||
|
|
||||||
|
index, _, _, edge_mask = k_hop_subgraph(
|
||||||
|
index,
|
||||||
|
num_hops=self._num_hops(model),
|
||||||
|
edge_index=edge_index,
|
||||||
|
num_nodes=num_nodes,
|
||||||
|
flow=self._flow(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
|
||||||
|
node_mask[index] = True
|
||||||
|
|
||||||
|
return node_mask, edge_mask
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
reset(self.mlp)
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
epoch: int,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
*,
|
||||||
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""Trains the underlying explainer model.
|
||||||
|
Needs to be called before being able to make predictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (int): The current epoch of the training phase.
|
||||||
|
model (torch.nn.Module): The model to explain.
|
||||||
|
x (torch.Tensor): The input node features of a
|
||||||
|
homogeneous graph.
|
||||||
|
edge_index (torch.Tensor): The input edge indices of a homogeneous
|
||||||
|
graph.
|
||||||
|
target (torch.Tensor): The target of the model.
|
||||||
|
index (int or torch.Tensor, optional): The index of the model
|
||||||
|
output to explain. Needs to be a single index.
|
||||||
|
(default: :obj:`None`)
|
||||||
|
**kwargs (optional): Additional keyword arguments passed to
|
||||||
|
:obj:`model`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
temperature = self._get_temperature(epoch)
|
||||||
|
|
||||||
|
inputs = self._get_inputs(z, edge_index, index)
|
||||||
|
logits = self.mlp(inputs).view(-1)
|
||||||
|
edge_mask = self._concrete_sample(logits, temperature)
|
||||||
|
set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
|
||||||
|
|
||||||
|
if self.model_config.task_level == ModelTaskLevel.node:
|
||||||
|
_, hard_edge_mask = self._get_hard_masks(
|
||||||
|
model, index, edge_index, num_nodes=x.size(0)
|
||||||
|
)
|
||||||
|
edge_mask = edge_mask[hard_edge_mask]
|
||||||
|
|
||||||
|
y_hat, y = model(x, edge_index, **kwargs), target
|
||||||
|
|
||||||
|
if index is not None:
|
||||||
|
y_hat, y = y_hat[index], y[index]
|
||||||
|
|
||||||
|
loss = self._loss(y_hat, y, edge_mask)
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
clear_masks(model)
|
||||||
|
self._curr_epoch = epoch
|
||||||
|
|
||||||
|
return float(loss)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
*,
|
||||||
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Explanation:
|
||||||
|
if self._curr_epoch < self.epochs - 1: # Safety check:
|
||||||
|
raise ValueError(
|
||||||
|
f"'{self.__class__.__name__}' is not yet fully "
|
||||||
|
f"trained (got {self._curr_epoch + 1} epochs "
|
||||||
|
f"from {self.epochs} epochs). Please first train "
|
||||||
|
f"the underlying explainer model by running "
|
||||||
|
f"`explainer.algorithm.train(...)`."
|
||||||
|
)
|
||||||
|
|
||||||
|
hard_edge_mask = None
|
||||||
|
_, hard_edge_mask = self._get_hard_masks(
|
||||||
|
model, index, edge_index, num_nodes=x.size(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
for epoch in range(self.epochs):
|
||||||
|
loss = self.train(epoch, model, x, edge_index, target=target, index=index)
|
||||||
|
z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]
|
||||||
|
|
||||||
|
inputs = self._get_inputs(z, edge_index, index)
|
||||||
|
logits = self.mlp(inputs).view(-1)
|
||||||
|
|
||||||
|
edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True)
|
||||||
|
|
||||||
|
return Explanation(edge_mask=edge_mask)
|
||||||
|
|
||||||
|
###########################################################################
|
||||||
|
|
||||||
|
def _get_inputs(
|
||||||
|
self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None
|
||||||
|
) -> Tensor:
|
||||||
|
zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
|
||||||
|
if self.model_config.task_level == ModelTaskLevel.node:
|
||||||
|
assert index is not None
|
||||||
|
zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
|
||||||
|
return torch.cat(zs, dim=-1)
|
||||||
|
|
||||||
|
def _get_temperature(self, epoch: int) -> float:
|
||||||
|
temp = self.coeffs["temp"]
|
||||||
|
return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
|
||||||
|
|
||||||
|
def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor:
|
||||||
|
bias = self.coeffs["bias"]
|
||||||
|
eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
|
||||||
|
return (eps.log() - (1 - eps).log() + logits) / temperature
|
||||||
|
|
||||||
|
def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
|
||||||
|
if self.model_config.mode == ModelMode.binary_classification:
|
||||||
|
loss = self._loss_binary_classification(y_hat, y)
|
||||||
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
||||||
|
loss = self._loss_multiclass_classification(y_hat, y)
|
||||||
|
elif self.model_config.mode == ModelMode.regression:
|
||||||
|
loss = self._loss_regression(y_hat, y)
|
||||||
|
|
||||||
|
# Regularization loss:
|
||||||
|
mask = edge_mask.sigmoid()
|
||||||
|
size_loss = mask.sum() * self.coeffs["edge_size"]
|
||||||
|
mask = 0.99 * mask + 0.005
|
||||||
|
mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
|
||||||
|
mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"]
|
||||||
|
|
||||||
|
return loss + size_loss + mask_ent_loss
|
||||||
|
|
||||||
|
|
||||||
|
class GNNExplainer(ExplainerAlgorithm):
|
||||||
|
r"""The GNN-Explainer model from the `"GNNExplainer: Generating
|
||||||
|
Explanations for Graph Neural Networks"
|
||||||
|
<https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
|
||||||
|
structures and node features that play a crucial role in the predictions
|
||||||
|
made by a GNN.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
For an example of using :class:`GNNExplainer`, see
|
||||||
|
`examples/gnn_explainer.py <https://github.com/pyg-team/
|
||||||
|
pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and
|
||||||
|
`examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/
|
||||||
|
pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epochs (int, optional): The number of epochs to train.
|
||||||
|
(default: :obj:`100`)
|
||||||
|
lr (float, optional): The learning rate to apply.
|
||||||
|
(default: :obj:`0.01`)
|
||||||
|
**kwargs (optional): Additional hyper-parameters to override default
|
||||||
|
settings in
|
||||||
|
:attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
coeffs = {
|
||||||
|
"edge_size": 0.005,
|
||||||
|
"edge_reduction": "sum",
|
||||||
|
"node_feat_size": 1.0,
|
||||||
|
"node_feat_reduction": "mean",
|
||||||
|
"edge_ent": 1.0,
|
||||||
|
"node_feat_ent": 0.1,
|
||||||
|
"EPS": 1e-15,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.epochs = epochs
|
||||||
|
self.lr = lr
|
||||||
|
self.coeffs.update(kwargs)
|
||||||
|
|
||||||
|
self.node_mask = self.edge_mask = None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
*,
|
||||||
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Explanation:
|
||||||
|
|
||||||
|
hard_node_mask = hard_edge_mask = None
|
||||||
|
hard_node_mask, hard_edge_mask = self._get_hard_masks(
|
||||||
|
model, index, edge_index, num_nodes=x.size(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._train(model, x, edge_index, target=target, index=index, **kwargs)
|
||||||
|
|
||||||
|
node_mask = self._post_process_mask(
|
||||||
|
self.node_mask, hard_node_mask, apply_sigmoid=True
|
||||||
|
)
|
||||||
|
edge_mask = self._post_process_mask(
|
||||||
|
self.edge_mask, hard_edge_mask, apply_sigmoid=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return Explanation(node_mask=node_mask, edge_mask=edge_mask)
|
||||||
|
|
||||||
|
def _train(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
*,
|
||||||
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self._initialize_masks(x, edge_index)
|
||||||
|
|
||||||
|
parameters = [self.node_mask] # We always learn a node mask.
|
||||||
|
if self.explainer_config.edge_mask_type is not None:
|
||||||
|
set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
|
||||||
|
parameters.append(self.edge_mask)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(parameters, lr=self.lr)
|
||||||
|
|
||||||
|
for _ in range(self.epochs):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
h = x * self.node_mask.sigmoid()
|
||||||
|
y_hat, y = model(h, edge_index, **kwargs), target
|
||||||
|
|
||||||
|
if index is not None:
|
||||||
|
y_hat, y = y_hat[index], y[index]
|
||||||
|
|
||||||
|
loss = self._loss(y_hat, y)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
def _initialize_masks(self, x: Tensor, edge_index: Tensor):
|
||||||
|
node_mask_type = self.explainer_config.node_mask_type
|
||||||
|
edge_mask_type = self.explainer_config.edge_mask_type
|
||||||
|
|
||||||
|
device = x.device
|
||||||
|
(N, F), E = x.size(), edge_index.size(1)
|
||||||
|
|
||||||
|
std = 0.1
|
||||||
|
if node_mask_type == MaskType.object:
|
||||||
|
self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
|
||||||
|
elif node_mask_type == MaskType.attributes:
|
||||||
|
self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
|
||||||
|
elif node_mask_type == MaskType.common_attributes:
|
||||||
|
self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
if edge_mask_type == MaskType.object:
|
||||||
|
std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N))
|
||||||
|
self.edge_mask = Parameter(torch.randn(E, device=device) * std)
|
||||||
|
elif edge_mask_type is not None:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
|
||||||
|
if self.model_config.mode == ModelMode.binary_classification:
|
||||||
|
loss = self._loss_binary_classification(y_hat, y)
|
||||||
|
elif self.model_config.mode == ModelMode.multiclass_classification:
|
||||||
|
loss = self._loss_multiclass_classification(y_hat, y)
|
||||||
|
elif self.model_config.mode == ModelMode.regression:
|
||||||
|
loss = self._loss_regression(y_hat, y)
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
if self.explainer_config.edge_mask_type is not None:
|
||||||
|
m = self.edge_mask.sigmoid()
|
||||||
|
edge_reduce = getattr(torch, self.coeffs["edge_reduction"])
|
||||||
|
loss = loss + self.coeffs["edge_size"] * edge_reduce(m)
|
||||||
|
ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
|
||||||
|
1 - m + self.coeffs["EPS"]
|
||||||
|
)
|
||||||
|
loss = loss + self.coeffs["edge_ent"] * ent.mean()
|
||||||
|
|
||||||
|
m = self.node_mask.sigmoid()
|
||||||
|
node_feat_reduce = getattr(torch, self.coeffs["node_feat_reduction"])
|
||||||
|
loss = loss + self.coeffs["node_feat_size"] * node_feat_reduce(m)
|
||||||
|
ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
|
||||||
|
1 - m + self.coeffs["EPS"]
|
||||||
|
)
|
||||||
|
loss = loss + self.coeffs["node_feat_ent"] * ent.mean()
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _clean_model(self, model):
|
||||||
|
clear_masks(model)
|
||||||
|
self.node_mask = None
|
||||||
|
self.edge_mask = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_GNNExplainer():
|
||||||
|
return GNNExplainer()
|
||||||
|
|
||||||
|
|
||||||
|
def load_PGExplainer():
|
||||||
|
return PGExplainer(epochs=30)
|
||||||
|
|
||||||
|
|
||||||
|
class PYGWrapper(ExplainerAlgorithm):
|
||||||
|
def __init__(self, name, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def supports(self) -> bool:
|
||||||
|
task_level = self.model_config.task_level
|
||||||
|
if self.name == "PGExplainer":
|
||||||
|
explanation_type = self.explainer_config.explanation_type
|
||||||
|
if explanation_type != ExplanationType.phenomenon:
|
||||||
|
logging.error(
|
||||||
|
f"'{self.__class__.__name__}' only supports "
|
||||||
|
f"phenomenon explanations "
|
||||||
|
f"got (`explanation_type={explanation_type.value}`)"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
task_level = self.model_config.task_level
|
||||||
|
if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
|
||||||
|
logging.error(
|
||||||
|
f"'{self.__class__.__name__}' only supports "
|
||||||
|
f"node-level or graph-level explanations "
|
||||||
|
f"got (`task_level={task_level.value}`)"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
node_mask_type = self.explainer_config.node_mask_type
|
||||||
|
if node_mask_type is not None:
|
||||||
|
logging.error(
|
||||||
|
f"'{self.__class__.__name__}' does not support "
|
||||||
|
f"explaining input node features "
|
||||||
|
f"got (`node_mask_type={node_mask_type.value}`)"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
elif self.name == "GNNExplainer":
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _load_pyg_method(self):
|
||||||
|
if self.name == "GNNExplainer":
|
||||||
|
return load_GNNExplainer()
|
||||||
|
if self.name == "PGExplainer":
|
||||||
|
return load_PGExplainer()
|
||||||
|
|
||||||
|
def _parse_attr(self, attr):
|
||||||
|
data = attr.to_dict()
|
||||||
|
node_mask = data.get("node_mask")
|
||||||
|
edge_mask = data.get("edge_mask")
|
||||||
|
node_feat_mask = data.get("node_feat_mask")
|
||||||
|
edge_feat_mask = data.get("edge_feat_mask")
|
||||||
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
x: Tensor,
|
||||||
|
edge_index: Tensor,
|
||||||
|
target: Tensor,
|
||||||
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.pyg_method = self._load_pyg_method()
|
||||||
|
attr = self.pyg_method.forward(
|
||||||
|
model=model,
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
index=index,
|
||||||
|
target=target,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
|
||||||
|
return Explanation(
|
||||||
|
x=x,
|
||||||
|
edge_index=edge_index,
|
||||||
|
y=target,
|
||||||
|
edge_mask=edge_mask,
|
||||||
|
node_mask=node_mask,
|
||||||
|
node_feat_mask=node_feat_mask,
|
||||||
|
edge_feat_mask=edge_feat_mask,
|
||||||
|
)
|
Loading…
Reference in New Issue