Fixing import bug

araison 2023-01-09 19:23:43 +01:00
parent d1c2565003
commit 7da28de955

@@ -0,0 +1,536 @@
import logging
import warnings
from math import sqrt
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, MSELoss, ReLU, Sequential
from torch.nn.parameter import Parameter
from torch_geometric.data import Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, ExplanationType,
                                            MaskType, ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.utils import k_hop_subgraph


def get_message_passing_embeddings(
model: torch.nn.Module,
*args,
**kwargs,
) -> List[Tensor]:
"""Returns the output embeddings of all
:class:`~torch_geometric.nn.conv.MessagePassing` layers in
:obj:`model`.
Internally, this method registers forward hooks on all
:class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
and runs the forward pass of the :obj:`model` by calling
:obj:`model(*args, **kwargs)`.
Args:
model (torch.nn.Module): The message passing model.
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the model.
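
    Example (a minimal sketch; the two-layer GNN below is illustrative and
    not part of this module):

    .. code-block:: python

        import torch
        from torch_geometric.nn import GCNConv

        class GNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = GCNConv(16, 32)
                self.conv2 = GCNConv(32, 7)

            def forward(self, x, edge_index):
                return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

        model = GNN()
        x = torch.randn(10, 16)
        edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])

        # One embedding per `MessagePassing` layer, in call order:
        embs = get_message_passing_embeddings(model, x, edge_index)
        assert len(embs) == 2 and embs[-1].size() == (10, 7)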
"""
from torch_geometric.nn import MessagePassing
embeddings: List[Tensor] = []
def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
# Clone output in case it will be later modified in-place:
outputs = outputs[0] if isinstance(outputs, tuple) else outputs
assert isinstance(outputs, Tensor)
embeddings.append(outputs.clone())
hook_handles = []
for module in model.modules(): # Register forward hooks:
if isinstance(module, MessagePassing):
hook_handles.append(module.register_forward_hook(hook))
if len(hook_handles) == 0:
warnings.warn("The 'model' does not have any 'MessagePassing' layers")
training = model.training
model.eval()
with torch.no_grad():
model(*args, **kwargs)
model.train(training)
for handle in hook_handles: # Remove hooks:
handle.remove()
return embeddings


class PGExplainer(ExplainerAlgorithm):
r"""The PGExplainer model from the `"Parameterized Explainer for Graph
Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper.
Internally, it utilizes a neural network to identify subgraph structures
that play a crucial role in the predictions made by a GNN.
Importantly, the :class:`PGExplainer` needs to be trained via
:meth:`~PGExplainer.train` before being able to generate explanations:
.. code-block:: python
explainer = Explainer(
model=model,
algorithm=PGExplainer(epochs=30, lr=0.003),
explanation_type='phenomenon',
edge_mask_type='object',
model_config=ModelConfig(...),
)
# Train against a variety of node-level or graph-level predictions:
for epoch in range(30):
for index in [...]: # Indices to train against.
loss = explainer.algorithm.train(epoch, model, x, edge_index,
target=target, index=index)
# Get the final explanations:
explanation = explainer(x, edge_index, target=target, index=0)
Args:
epochs (int): The number of epochs to train.
lr (float, optional): The learning rate to apply.
(default: :obj:`0.003`).
**kwargs (optional): Additional hyper-parameters to override default
settings in
:attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.
"""
coeffs = {
"edge_size": 0.05,
"edge_ent": 1.0,
"temp": [5.0, 2.0],
"bias": 0.0,
}
def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
super().__init__()
self.epochs = epochs
self.lr = lr
self.coeffs.update(kwargs)
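        # `Linear(-1, 64)` infers its input dimensionality lazily on the
        # first forward pass, so the MLP adapts to whatever embedding size
        # `model` produces: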
self.mlp = Sequential(
Linear(-1, 64),
ReLU(),
Linear(64, 1),
)
self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
self._curr_epoch = -1
    def _get_hard_masks(
        self,
        model: torch.nn.Module,
        index: Optional[Union[int, Tensor]],
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
r"""Returns hard node and edge masks that only include the nodes and
edges visited during message passing."""
if index is None:
return None, None # Consider all nodes and edges.
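        # Restrict the explanation to the k-hop neighborhood of `index`,
        # where k is the number of message passing layers in `model`: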
index, _, _, edge_mask = k_hop_subgraph(
index,
num_hops=self._num_hops(model),
edge_index=edge_index,
num_nodes=num_nodes,
flow=self._flow(model),
)
node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
node_mask[index] = True
return node_mask, edge_mask
def reset_parameters(self):
reset(self.mlp)
def train(
self,
epoch: int,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
):
r"""Trains the underlying explainer model.
Needs to be called before being able to make predictions.
Args:
epoch (int): The current epoch of the training phase.
model (torch.nn.Module): The model to explain.
x (torch.Tensor): The input node features of a
homogeneous graph.
edge_index (torch.Tensor): The input edge indices of a homogeneous
graph.
target (torch.Tensor): The target of the model.
index (int or torch.Tensor, optional): The index of the model
output to explain. Needs to be a single index.
(default: :obj:`None`)
**kwargs (optional): Additional keyword arguments passed to
:obj:`model`.
"""
z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]
self.optimizer.zero_grad()
temperature = self._get_temperature(epoch)
inputs = self._get_inputs(z, edge_index, index)
logits = self.mlp(inputs).view(-1)
edge_mask = self._concrete_sample(logits, temperature)
set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
if self.model_config.task_level == ModelTaskLevel.node:
_, hard_edge_mask = self._get_hard_masks(
model, index, edge_index, num_nodes=x.size(0)
)
edge_mask = edge_mask[hard_edge_mask]
y_hat, y = model(x, edge_index, **kwargs), target
if index is not None:
y_hat, y = y_hat[index], y[index]
loss = self._loss(y_hat, y, edge_mask)
loss.backward()
self.optimizer.step()
clear_masks(model)
self._curr_epoch = epoch
return float(loss)
def forward(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> Explanation:
        # Train the underlying explainer model before generating explanations
        # (the safety check below would otherwise fail on the first call):
        for epoch in range(self.epochs):
            self.train(epoch, model, x, edge_index, target=target,
                       index=index, **kwargs)

        if self._curr_epoch < self.epochs - 1:  # Safety check:
            raise ValueError(
                f"'{self.__class__.__name__}' is not yet fully "
                f"trained (got {self._curr_epoch + 1} epochs "
                f"from {self.epochs} epochs). Please first train "
                f"the underlying explainer model by running "
                f"`explainer.algorithm.train(...)`."
            )

        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            # Only nodes and edges visited during message passing can
            # influence the prediction for `index`:
            _, hard_edge_mask = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0)
            )

        z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]
        inputs = self._get_inputs(z, edge_index, index)
        logits = self.mlp(inputs).view(-1)
        edge_mask = self._post_process_mask(logits, hard_edge_mask,
                                            apply_sigmoid=True)
        return Explanation(edge_mask=edge_mask)
###########################################################################
def _get_inputs(
self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None
) -> Tensor:
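        # Build one input per edge by concatenating the embeddings of its two
        # endpoints; for node-level tasks, the embedding of the node being
        # explained is appended to every edge: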
zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
if self.model_config.task_level == ModelTaskLevel.node:
assert index is not None
zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
return torch.cat(zs, dim=-1)
def _get_temperature(self, epoch: int) -> float:
temp = self.coeffs["temp"]
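        # Anneal the temperature geometrically from `temp[0]` down to
        # `temp[1]` over the course of training: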
return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor:
bias = self.coeffs["bias"]
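        # Sample from a binary concrete (relaxed Bernoulli) distribution:
        # uniform noise in [bias, 1 - bias] is passed through a logit and
        # combined with the learned logits, scaled by the temperature.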
eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
return (eps.log() - (1 - eps).log() + logits) / temperature
def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
if self.model_config.mode == ModelMode.binary_classification:
loss = self._loss_binary_classification(y_hat, y)
elif self.model_config.mode == ModelMode.multiclass_classification:
loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False
# Regularization loss:
mask = edge_mask.sigmoid()
size_loss = mask.sum() * self.coeffs["edge_size"]
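        # Entropy regularization pushes mask values towards 0 or 1; the mask
        # is rescaled slightly to keep the log terms finite: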
mask = 0.99 * mask + 0.005
mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"]
return loss + size_loss + mask_ent_loss


class GNNExplainer(ExplainerAlgorithm):
r"""The GNN-Explainer model from the `"GNNExplainer: Generating
Explanations for Graph Neural Networks"
<https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
structures and node features that play a crucial role in the predictions
made by a GNN.
.. note::
For an example of using :class:`GNNExplainer`, see
`examples/gnn_explainer.py <https://github.com/pyg-team/
pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and
`examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/
pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_.
Args:
epochs (int, optional): The number of epochs to train.
(default: :obj:`100`)
lr (float, optional): The learning rate to apply.
(default: :obj:`0.01`)
**kwargs (optional): Additional hyper-parameters to override default
settings in
:attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
"""
coeffs = {
"edge_size": 0.005,
"edge_reduction": "sum",
"node_feat_size": 1.0,
"node_feat_reduction": "mean",
"edge_ent": 1.0,
"node_feat_ent": 0.1,
"EPS": 1e-15,
}
def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
super().__init__()
self.epochs = epochs
self.lr = lr
self.coeffs.update(kwargs)
self.node_mask = self.edge_mask = None
def forward(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> Explanation:
        hard_node_mask = hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            hard_node_mask, hard_edge_mask = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0)
            )

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        node_mask = self._post_process_mask(
            self.node_mask, hard_node_mask, apply_sigmoid=True
        )
        edge_mask = self._post_process_mask(
            self.edge_mask, hard_edge_mask, apply_sigmoid=True
        )

        self._clean_model(model)  # Remove the injected masks from the model.
        return Explanation(node_mask=node_mask, edge_mask=edge_mask)
def _train(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
):
self._initialize_masks(x, edge_index)
parameters = [self.node_mask] # We always learn a node mask.
if self.explainer_config.edge_mask_type is not None:
set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
parameters.append(self.edge_mask)
optimizer = torch.optim.Adam(parameters, lr=self.lr)
for _ in range(self.epochs):
optimizer.zero_grad()
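            # Apply the (soft) node mask to the input features; the edge
            # mask, if any, was injected into the model via `set_masks` above: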
h = x * self.node_mask.sigmoid()
y_hat, y = model(h, edge_index, **kwargs), target
if index is not None:
y_hat, y = y_hat[index], y[index]
loss = self._loss(y_hat, y)
loss.backward()
optimizer.step()
def _initialize_masks(self, x: Tensor, edge_index: Tensor):
node_mask_type = self.explainer_config.node_mask_type
edge_mask_type = self.explainer_config.edge_mask_type
device = x.device
(N, F), E = x.size(), edge_index.size(1)
std = 0.1
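        # The node mask shape depends on the requested mask type: one scalar
        # per node, one value per node feature, or a single feature mask
        # shared across all nodes: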
if node_mask_type == MaskType.object:
self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
elif node_mask_type == MaskType.attributes:
self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
elif node_mask_type == MaskType.common_attributes:
self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
else:
assert False
if edge_mask_type == MaskType.object:
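            # Gain-scaled (Xavier-style) initialization, treating each edge
            # as connecting two of the graph's N nodes: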
std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N))
self.edge_mask = Parameter(torch.randn(E, device=device) * std)
elif edge_mask_type is not None:
assert False
def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
if self.model_config.mode == ModelMode.binary_classification:
loss = self._loss_binary_classification(y_hat, y)
elif self.model_config.mode == ModelMode.multiclass_classification:
loss = self._loss_multiclass_classification(y_hat, y)
elif self.model_config.mode == ModelMode.regression:
loss = self._loss_regression(y_hat, y)
else:
assert False
if self.explainer_config.edge_mask_type is not None:
m = self.edge_mask.sigmoid()
edge_reduce = getattr(torch, self.coeffs["edge_reduction"])
loss = loss + self.coeffs["edge_size"] * edge_reduce(m)
ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
1 - m + self.coeffs["EPS"]
)
loss = loss + self.coeffs["edge_ent"] * ent.mean()
m = self.node_mask.sigmoid()
node_feat_reduce = getattr(torch, self.coeffs["node_feat_reduction"])
loss = loss + self.coeffs["node_feat_size"] * node_feat_reduce(m)
ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
1 - m + self.coeffs["EPS"]
)
loss = loss + self.coeffs["node_feat_ent"] * ent.mean()
return loss
def _clean_model(self, model):
clear_masks(model)
self.node_mask = None
self.edge_mask = None


def load_GNNExplainer():
    return GNNExplainer()


def load_PGExplainer():
    return PGExplainer(epochs=30)


class PYGWrapper(ExplainerAlgorithm):
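    r"""A thin wrapper that dispatches to one of the explainer algorithms
    above by :obj:`name` (:obj:`"GNNExplainer"` or :obj:`"PGExplainer"`) and
    repackages the result as an
    :class:`~torch_geometric.explain.Explanation`.

    A minimal usage sketch, mirroring the :class:`Explainer` pattern shown
    for :class:`PGExplainer` above (arguments are illustrative):

    .. code-block:: python

        explainer = Explainer(
            model=model,
            algorithm=PYGWrapper(name="GNNExplainer"),
            explanation_type='model',
            node_mask_type='attributes',
            edge_mask_type='object',
            model_config=ModelConfig(...),
        )
        explanation = explainer(x, edge_index, target=target)
    """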
def __init__(self, name, **kwargs):
super().__init__()
self.name = name
def supports(self) -> bool:
task_level = self.model_config.task_level
if self.name == "PGExplainer":
explanation_type = self.explainer_config.explanation_type
if explanation_type != ExplanationType.phenomenon:
                logging.error(
                    f"'{self.__class__.__name__}' only supports "
                    f"phenomenon explanations "
                    f"(got `explanation_type={explanation_type.value}`)"
                )
return False
if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
                logging.error(
                    f"'{self.__class__.__name__}' only supports "
                    f"node-level or graph-level explanations "
                    f"(got `task_level={task_level.value}`)"
                )
return False
node_mask_type = self.explainer_config.node_mask_type
if node_mask_type is not None:
                logging.error(
                    f"'{self.__class__.__name__}' does not support "
                    f"explaining input node features "
                    f"(got `node_mask_type={node_mask_type.value}`)"
                )
return False
elif self.name == "GNNExplainer":
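            # GNNExplainer supports all explainer and model configurations.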
pass
return True
def _load_pyg_method(self):
if self.name == "GNNExplainer":
return load_GNNExplainer()
        if self.name == "PGExplainer":
            return load_PGExplainer()
        raise ValueError(f"Unknown explainer name: '{self.name}'")
def _parse_attr(self, attr):
data = attr.to_dict()
node_mask = data.get("node_mask")
edge_mask = data.get("edge_mask")
node_feat_mask = data.get("node_feat_mask")
edge_feat_mask = data.get("edge_feat_mask")
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        self.pyg_method = self._load_pyg_method()
        # Propagate the configs to the wrapped algorithm so that it can
        # access `self.model_config` etc. internally (relies on the
        # `ExplainerAlgorithm.connect` API of PyG 2.2):
        self.pyg_method.connect(self.explainer_config, self.model_config)
        attr = self.pyg_method(
            model=model,
            x=x,
            edge_index=edge_index,
            target=target,
            index=index,
            **kwargs,
        )
node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
return Explanation(
x=x,
edge_index=edge_index,
y=target,
edge_mask=edge_mask,
node_mask=node_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
)