Fixing import bug
This commit is contained in:
parent d1c2565003
commit 7da28de955

536 explaining_framework/explainers/wrappers/from_pyg.py (Normal file)

@@ -0,0 +1,536 @@
import logging
import warnings
from math import sqrt
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import ReLU, Sequential
from torch.nn.parameter import Parameter
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, ExplanationType,
                                            MaskType, ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.utils import k_hop_subgraph


def get_message_passing_embeddings(
    model: torch.nn.Module,
    *args,
    **kwargs,
) -> List[Tensor]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in
    :obj:`model`.

    Internally, this method registers forward hooks on all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
    and runs the forward pass of the :obj:`model` by calling
    :obj:`model(*args, **kwargs)`.

    Args:
        model (torch.nn.Module): The message passing model.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the model.
    """
    from torch_geometric.nn import MessagePassing

    embeddings: List[Tensor] = []

    def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
        # Clone output in case it will be later modified in-place:
        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
        assert isinstance(outputs, Tensor)
        embeddings.append(outputs.clone())

    hook_handles = []
    for module in model.modules():  # Register forward hooks:
        if isinstance(module, MessagePassing):
            hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn("The 'model' does not have any 'MessagePassing' layers")

    training = model.training
    model.eval()
    with torch.no_grad():
        model(*args, **kwargs)
    model.train(training)

    for handle in hook_handles:  # Remove hooks:
        handle.remove()

    return embeddings
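

# Example (sketch): grabbing the per-layer node embeddings of a small GCN;
# the model and random graph below are illustrative placeholders, not part
# of this module:
#
#     import torch
#     from torch_geometric.nn import GCN
#
#     model = GCN(in_channels=16, hidden_channels=32, num_layers=2,
#                 out_channels=7)
#     x = torch.randn(10, 16)
#     edge_index = torch.randint(0, 10, (2, 40))
#
#     embs = get_message_passing_embeddings(model, x, edge_index)
#     # One tensor per 'MessagePassing' layer: shapes [10, 32] and [10, 7].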


class PGExplainer(ExplainerAlgorithm):
    r"""The PGExplainer model from the `"Parameterized Explainer for Graph
    Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper.
    Internally, it utilizes a neural network to identify subgraph structures
    that play a crucial role in the predictions made by a GNN.
    Importantly, the :class:`PGExplainer` needs to be trained via
    :meth:`~PGExplainer.train` before it can generate explanations; in this
    wrapper, :meth:`forward` runs that training loop itself if the explainer
    has not been fully trained yet:

    .. code-block:: python

        explainer = Explainer(
            model=model,
            algorithm=PGExplainer(epochs=30, lr=0.003),
            explanation_type='phenomenon',
            edge_mask_type='object',
            model_config=ModelConfig(...),
        )

        # Train against a variety of node-level or graph-level predictions:
        for epoch in range(30):
            for index in [...]:  # Indices to train against.
                loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                                 target=target, index=index)

        # Get the final explanations:
        explanation = explainer(x, edge_index, target=target, index=0)

    Args:
        epochs (int): The number of epochs to train.
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.003`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.
    """

    coeffs = {
        "edge_size": 0.05,
        "edge_ent": 1.0,
        "temp": [5.0, 2.0],
        "bias": 0.0,
    }
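
    # 'edge_size' and 'edge_ent' weight the sparsity and entropy regularizers
    # in '_loss'; 'temp' holds the start/end temperatures for the annealing
    # schedule in '_get_temperature'; 'bias' shifts the uniform noise used
    # in '_concrete_sample'.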

    def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.coeffs.update(kwargs)

        self.mlp = Sequential(
            Linear(-1, 64),
            ReLU(),
            Linear(64, 1),
        )
        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
        self._curr_epoch = -1

    def supports(self) -> bool:
        # Required by the abstract 'ExplainerAlgorithm' base class; the
        # configuration checks for this algorithm are performed in
        # 'PYGWrapper.supports' below.
        return True

    def _get_hard_masks(
        self,
        model: torch.nn.Module,
        index: Optional[Union[int, Tensor]],
        edge_index: Tensor,
        num_nodes: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        r"""Returns hard node and edge masks that only include the nodes and
        edges visited during message passing."""
        if index is None:
            return None, None  # Consider all nodes and edges.

        index, _, _, edge_mask = k_hop_subgraph(
            index,
            num_hops=self._num_hops(model),
            edge_index=edge_index,
            num_nodes=num_nodes,
            flow=self._flow(model),
        )

        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
        node_mask[index] = True

        return node_mask, edge_mask
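
    # Note (sketch of the behaviour above): for a model with two
    # 'MessagePassing' layers and a single node 'index', 'k_hop_subgraph'
    # collects every node reachable within two message-passing steps
    # (respecting the model's flow direction) and returns a boolean mask over
    # the columns of 'edge_index' marking the traversed edges, so the
    # explanation is restricted to the node's receptive field.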

    def reset_parameters(self):
        reset(self.mlp)

    def train(
        self,
        epoch: int,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        r"""Trains the underlying explainer model.
        Needs to be called before being able to make predictions.

        Args:
            epoch (int): The current epoch of the training phase.
            model (torch.nn.Module): The model to explain.
            x (torch.Tensor): The input node features of a
                homogeneous graph.
            edge_index (torch.Tensor): The input edge indices of a homogeneous
                graph.
            target (torch.Tensor): The target of the model.
            index (int or torch.Tensor, optional): The index of the model
                output to explain. Needs to be a single index.
                (default: :obj:`None`)
            **kwargs (optional): Additional keyword arguments passed to
                :obj:`model`.
        """
        z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]

        self.optimizer.zero_grad()
        temperature = self._get_temperature(epoch)

        inputs = self._get_inputs(z, edge_index, index)
        logits = self.mlp(inputs).view(-1)
        edge_mask = self._concrete_sample(logits, temperature)
        set_masks(model, edge_mask, edge_index, apply_sigmoid=True)

        if self.model_config.task_level == ModelTaskLevel.node:
            _, hard_edge_mask = self._get_hard_masks(
                model, index, edge_index, num_nodes=x.size(0)
            )
            edge_mask = edge_mask[hard_edge_mask]

        y_hat, y = model(x, edge_index, **kwargs), target

        if index is not None:
            y_hat, y = y_hat[index], y[index]

        loss = self._loss(y_hat, y, edge_mask)
        loss.backward()
        self.optimizer.step()

        clear_masks(model)
        self._curr_epoch = epoch

        return float(loss)

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        # Train the underlying explainer model on the fly if it has not been
        # fully trained yet (the upstream implementation instead raises an
        # error here and requires a separate training phase):
        if self._curr_epoch < self.epochs - 1:
            for epoch in range(self.epochs):
                self.train(epoch, model, x, edge_index, target=target,
                           index=index, **kwargs)

        _, hard_edge_mask = self._get_hard_masks(
            model, index, edge_index, num_nodes=x.size(0)
        )

        z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1]

        inputs = self._get_inputs(z, edge_index, index)
        logits = self.mlp(inputs).view(-1)

        edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True)

        return Explanation(edge_mask=edge_mask)

    ###########################################################################

    def _get_inputs(
        self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None
    ) -> Tensor:
        zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
        if self.model_config.task_level == ModelTaskLevel.node:
            assert index is not None
            zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
        return torch.cat(zs, dim=-1)

    def _get_temperature(self, epoch: int) -> float:
        temp = self.coeffs["temp"]
        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)

    def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor:
        bias = self.coeffs["bias"]
        eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
        return (eps.log() - (1 - eps).log() + logits) / temperature
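
    # Note on the two helpers above: '_concrete_sample' draws logits of a
    # binary concrete (Gumbel-sigmoid) relaxation. With eps ~ U(bias, 1 - bias),
    # the soft mask applied downstream via 'set_masks(..., apply_sigmoid=True)'
    # is sigmoid((log(eps) - log(1 - eps) + logits) / temperature), and
    # '_get_temperature' anneals the temperature exponentially from
    # 'coeffs["temp"][0]' (5.0) towards 'coeffs["temp"][1]' (2.0) over the
    # course of training, moving the mask from soft towards near-discrete.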

    def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False

        # Regularization loss:
        mask = edge_mask.sigmoid()
        size_loss = mask.sum() * self.coeffs["edge_size"]
        mask = 0.99 * mask + 0.005
        mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
        mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"]

        return loss + size_loss + mask_ent_loss
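
    # Worked example for the regularization terms above (illustrative numbers,
    # not from the source): for a two-edge mask with sigmoid values [0.9, 0.1]
    # and the default coefficients, the size loss is 0.05 * (0.9 + 0.1) = 0.05,
    # and the entropy term evaluates -m*log(m) - (1-m)*log(1-m) at the smoothed
    # values 0.99*m + 0.005 (~0.33 for both edges), adding ~0.33 on top of the
    # prediction loss.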


class GNNExplainer(ExplainerAlgorithm):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and node features that play a crucial role in the predictions
    made by a GNN.

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and
        `examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    coeffs = {
        "edge_size": 0.005,
        "edge_reduction": "sum",
        "node_feat_size": 1.0,
        "node_feat_reduction": "mean",
        "edge_ent": 1.0,
        "node_feat_ent": 0.1,
        "EPS": 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.coeffs.update(kwargs)

        self.node_mask = self.edge_mask = None

    def supports(self) -> bool:
        # Required by the abstract 'ExplainerAlgorithm' base class; the
        # configuration checks are performed in 'PYGWrapper.supports' below.
        return True

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        hard_node_mask, hard_edge_mask = self._get_hard_masks(
            model, index, edge_index, num_nodes=x.size(0)
        )

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        node_mask = self._post_process_mask(
            self.node_mask, hard_node_mask, apply_sigmoid=True
        )
        edge_mask = self._post_process_mask(
            self.edge_mask, hard_edge_mask, apply_sigmoid=True
        )

        # Remove the masks from the model and reset the learned parameters:
        self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        self._initialize_masks(x, edge_index)

        parameters = [self.node_mask]  # We always learn a node mask.
        if self.explainer_config.edge_mask_type is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)

        optimizer = torch.optim.Adam(parameters, lr=self.lr)

        for _ in range(self.epochs):
            optimizer.zero_grad()

            h = x * self.node_mask.sigmoid()
            y_hat, y = model(h, edge_index, **kwargs), target

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y)

            loss.backward()
            optimizer.step()
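
    # Note on '_train': 'set_masks' registers the learnable 'edge_mask' on
    # every 'MessagePassing' layer of 'model', so gradients flow through the
    # model's own forward pass into both masks; the node mask is applied
    # directly to the input features via 'x * self.node_mask.sigmoid()'.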

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type == MaskType.object:
            # Glorot-style scale over the two endpoints of each edge
            # (fan-in + fan-out = 2 * N):
            std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N))
            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
        elif edge_mask_type is not None:
            assert False

    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False

        if self.explainer_config.edge_mask_type is not None:
            m = self.edge_mask.sigmoid()
            edge_reduce = getattr(torch, self.coeffs["edge_reduction"])
            loss = loss + self.coeffs["edge_size"] * edge_reduce(m)
            ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
                1 - m + self.coeffs["EPS"]
            )
            loss = loss + self.coeffs["edge_ent"] * ent.mean()

        m = self.node_mask.sigmoid()
        node_feat_reduce = getattr(torch, self.coeffs["node_feat_reduction"])
        loss = loss + self.coeffs["node_feat_size"] * node_feat_reduce(m)
        ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log(
            1 - m + self.coeffs["EPS"]
        )
        loss = loss + self.coeffs["node_feat_ent"] * ent.mean()

        return loss

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = None
        self.edge_mask = None


def load_GNNExplainer():
    return GNNExplainer()


def load_PGExplainer():
    return PGExplainer(epochs=30)


class PYGWrapper(ExplainerAlgorithm):
    def __init__(self, name, **kwargs):
        super().__init__()
        self.name = name

    def supports(self) -> bool:
        if self.name == "PGExplainer":
            explanation_type = self.explainer_config.explanation_type
            if explanation_type != ExplanationType.phenomenon:
                logging.error(
                    f"'{self.__class__.__name__}' only supports "
                    f"phenomenon explanations "
                    f"(got `explanation_type={explanation_type.value}`)"
                )
                return False

            task_level = self.model_config.task_level
            if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
                logging.error(
                    f"'{self.__class__.__name__}' only supports "
                    f"node-level or graph-level explanations "
                    f"(got `task_level={task_level.value}`)"
                )
                return False

            node_mask_type = self.explainer_config.node_mask_type
            if node_mask_type is not None:
                logging.error(
                    f"'{self.__class__.__name__}' does not support "
                    f"explaining input node features "
                    f"(got `node_mask_type={node_mask_type.value}`)"
                )
                return False
        elif self.name == "GNNExplainer":
            pass  # 'GNNExplainer' supports all configurations.

        return True

    def _load_pyg_method(self):
        if self.name == "GNNExplainer":
            return load_GNNExplainer()
        if self.name == "PGExplainer":
            return load_PGExplainer()
        raise ValueError(f"Unknown explainer name (got '{self.name}')")

    def _parse_attr(self, attr):
        data = attr.to_dict()
        node_mask = data.get("node_mask")
        edge_mask = data.get("edge_mask")
        node_feat_mask = data.get("node_feat_mask")
        edge_feat_mask = data.get("edge_feat_mask")
        return node_mask, edge_mask, node_feat_mask, edge_feat_mask
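
    # The wrapped algorithms only populate a subset of these masks; any
    # missing entry is returned as 'None' here.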

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        self.pyg_method = self._load_pyg_method()
        # Propagate the explainer and model configurations to the wrapped
        # algorithm, which relies on them internally:
        self.pyg_method.connect(self.explainer_config, self.model_config)
        attr = self.pyg_method.forward(
            model=model,
            x=x,
            edge_index=edge_index,
            index=index,
            target=target,
            **kwargs,
        )

        node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
        )
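

# Example (sketch): using the wrapper through PyG's 'Explainer' front end.
# 'model', 'x' and 'edge_index' below are illustrative placeholders, not
# part of this module:
#
#     from torch_geometric.explain import Explainer
#
#     explainer = Explainer(
#         model=model,
#         algorithm=PYGWrapper(name="GNNExplainer"),
#         explanation_type="model",
#         node_mask_type="attributes",
#         edge_mask_type="object",
#         model_config=dict(mode="multiclass_classification",
#                           task_level="node", return_type="raw"),
#     )
#     explanation = explainer(x, edge_index, index=0)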