diff --git a/explaining_framework/explainers/wrappers/from_pyg.py b/explaining_framework/explainers/wrappers/from_pyg.py
new file mode 100644
index 0000000..2d6dbdc
--- /dev/null
+++ b/explaining_framework/explainers/wrappers/from_pyg.py
@@ -0,0 +1,536 @@
+import logging
+import warnings
+from math import sqrt
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn import ReLU, Sequential
+from torch.nn.parameter import Parameter
+from torch_geometric.explain import Explanation
+from torch_geometric.explain.algorithm import ExplainerAlgorithm
+from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+from torch_geometric.explain.config import (ExplanationType, MaskType,
+                                            ModelMode, ModelTaskLevel)
+from torch_geometric.nn import Linear
+from torch_geometric.nn.inits import reset
+from torch_geometric.utils import k_hop_subgraph
+
+
+def get_message_passing_embeddings(
+    model: torch.nn.Module,
+    *args,
+    **kwargs,
+) -> List[Tensor]:
+    """Returns the output embeddings of all
+    :class:`~torch_geometric.nn.conv.MessagePassing` layers in
+    :obj:`model`.
+
+    Internally, this method registers forward hooks on all
+    :class:`~torch_geometric.nn.conv.MessagePassing` layers of a
+    :obj:`model`, and runs the forward pass of the :obj:`model` by calling
+    :obj:`model(*args, **kwargs)`.
+
+    Args:
+        model (torch.nn.Module): The message passing model.
+        *args: Arguments passed to the model.
+        **kwargs (optional): Additional keyword arguments passed to the
+            model.
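+
+    Example (a minimal sketch; :obj:`model`, :obj:`x` and :obj:`edge_index`
+    are assumed to be defined by the caller):
+
+    .. code-block:: python
+
+        # One embedding tensor per `MessagePassing` layer, in call order:
+        embeddings = get_message_passing_embeddings(model, x, edge_index)
+        z = embeddings[-1]  # Output of the final message passing layer.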
+    """
+    from torch_geometric.nn import MessagePassing
+
+    embeddings: List[Tensor] = []
+
+    def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
+        # Clone output in case it will be later modified in-place:
+        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
+        assert isinstance(outputs, Tensor)
+        embeddings.append(outputs.clone())
+
+    hook_handles = []
+    for module in model.modules():  # Register forward hooks:
+        if isinstance(module, MessagePassing):
+            hook_handles.append(module.register_forward_hook(hook))
+
+    if len(hook_handles) == 0:
+        warnings.warn("The 'model' does not have any 'MessagePassing' layers")
+
+    training = model.training
+    model.eval()
+    with torch.no_grad():
+        model(*args, **kwargs)
+    model.train(training)
+
+    for handle in hook_handles:  # Remove hooks:
+        handle.remove()
+
+    return embeddings
+
+
+class PGExplainer(ExplainerAlgorithm):
+    r"""The PGExplainer model from the `"Parameterized Explainer for Graph
+    Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper.
+
+    Internally, it utilizes a neural network to identify subgraph structures
+    that play a crucial role in the predictions made by a GNN.
+    Importantly, the :class:`PGExplainer` needs to be trained via
+    :meth:`~PGExplainer.train` before being able to generate explanations:
+
+    .. code-block:: python
+
+        explainer = Explainer(
+            model=model,
+            algorithm=PGExplainer(epochs=30, lr=0.003),
+            explanation_type='phenomenon',
+            edge_mask_type='object',
+            model_config=ModelConfig(...),
+        )
+
+        # Train against a variety of node-level or graph-level predictions:
+        for epoch in range(30):
+            for index in [...]:  # Indices to train against.
+                loss = explainer.algorithm.train(epoch, model, x, edge_index,
+                                                 target=target, index=index)
+
+        # Get the final explanations:
+        explanation = explainer(x, edge_index, target=target, index=0)
+
+    Args:
+        epochs (int): The number of epochs to train.
+        lr (float, optional): The learning rate to apply.
+            (default: :obj:`0.003`)
+        **kwargs (optional): Additional hyper-parameters to override default
+            settings in
+            :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.
+    """
+
+    coeffs = {
+        "edge_size": 0.05,
+        "edge_ent": 1.0,
+        "temp": [5.0, 2.0],
+        "bias": 0.0,
+    }
+
+    def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
+        super().__init__()
+        self.epochs = epochs
+        self.lr = lr
+        self.coeffs.update(kwargs)
+
+        self.mlp = Sequential(
+            Linear(-1, 64),
+            ReLU(),
+            Linear(64, 1),
+        )
+        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
+        self._curr_epoch = -1
+
+    def _get_hard_masks(
+        self,
+        model: torch.nn.Module,
+        index: Optional[Union[int, Tensor]],
+        edge_index: Tensor,
+        num_nodes: int,
+    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+        r"""Returns hard node and edge masks that only include the nodes and
+        edges visited during message passing."""
+        if index is None:
+            return None, None  # Consider all nodes and edges.
+
+        index, _, _, edge_mask = k_hop_subgraph(
+            index,
+            num_hops=self._num_hops(model),
+            edge_index=edge_index,
+            num_nodes=num_nodes,
+            flow=self._flow(model),
+        )
+
+        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)
+        node_mask[index] = True
+
+        return node_mask, edge_mask
+
+    def reset_parameters(self):
+        reset(self.mlp)
+
+    def train(
+        self,
+        epoch: int,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ):
+        r"""Trains the underlying explainer model.
+        Needs to be called before being able to make predictions.
+
+        Args:
+            epoch (int): The current epoch of the training phase.
+            model (torch.nn.Module): The model to explain.
+            x (torch.Tensor): The input node features of a homogeneous
+                graph.
+            edge_index (torch.Tensor): The input edge indices of a
+                homogeneous graph.
+            target (torch.Tensor): The target of the model.
+            index (int or torch.Tensor, optional): The index of the model
+                output to explain. Needs to be a single index.
+                (default: :obj:`None`)
+            **kwargs (optional): Additional keyword arguments passed to
+                :obj:`model`.
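+
+        .. code-block:: python
+
+            # A minimal training-loop sketch (assumes `explainer` wraps this
+            # algorithm and that `target` holds the predictions or labels to
+            # train the explanation against):
+            for epoch in range(explainer.algorithm.epochs):
+                loss = explainer.algorithm.train(
+                    epoch, model, x, edge_index, target=target, index=0)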
+ """ + + z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1] + + self.optimizer.zero_grad() + temperature = self._get_temperature(epoch) + + inputs = self._get_inputs(z, edge_index, index) + logits = self.mlp(inputs).view(-1) + edge_mask = self._concrete_sample(logits, temperature) + set_masks(model, edge_mask, edge_index, apply_sigmoid=True) + + if self.model_config.task_level == ModelTaskLevel.node: + _, hard_edge_mask = self._get_hard_masks( + model, index, edge_index, num_nodes=x.size(0) + ) + edge_mask = edge_mask[hard_edge_mask] + + y_hat, y = model(x, edge_index, **kwargs), target + + if index is not None: + y_hat, y = y_hat[index], y[index] + + loss = self._loss(y_hat, y, edge_mask) + loss.backward() + self.optimizer.step() + + clear_masks(model) + self._curr_epoch = epoch + + return float(loss) + + def forward( + self, + model: torch.nn.Module, + x: Tensor, + edge_index: Tensor, + *, + target: Tensor, + index: Optional[Union[int, Tensor]] = None, + **kwargs, + ) -> Explanation: + if self._curr_epoch < self.epochs - 1: # Safety check: + raise ValueError( + f"'{self.__class__.__name__}' is not yet fully " + f"trained (got {self._curr_epoch + 1} epochs " + f"from {self.epochs} epochs). Please first train " + f"the underlying explainer model by running " + f"`explainer.algorithm.train(...)`." + ) + + hard_edge_mask = None + _, hard_edge_mask = self._get_hard_masks( + model, index, edge_index, num_nodes=x.size(0) + ) + + for epoch in range(self.epochs): + loss = self.train(epoch, model, x, edge_index, target=target, index=index) + z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1] + + inputs = self._get_inputs(z, edge_index, index) + logits = self.mlp(inputs).view(-1) + + edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True) + + return Explanation(edge_mask=edge_mask) + + ########################################################################### + + def _get_inputs( + self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None + ) -> Tensor: + zs = [embedding[edge_index[0]], embedding[edge_index[1]]] + if self.model_config.task_level == ModelTaskLevel.node: + assert index is not None + zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1)) + return torch.cat(zs, dim=-1) + + def _get_temperature(self, epoch: int) -> float: + temp = self.coeffs["temp"] + return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs) + + def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor: + bias = self.coeffs["bias"] + eps = (1 - 2 * bias) * torch.rand_like(logits) + bias + return (eps.log() - (1 - eps).log() + logits) / temperature + + def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor: + if self.model_config.mode == ModelMode.binary_classification: + loss = self._loss_binary_classification(y_hat, y) + elif self.model_config.mode == ModelMode.multiclass_classification: + loss = self._loss_multiclass_classification(y_hat, y) + elif self.model_config.mode == ModelMode.regression: + loss = self._loss_regression(y_hat, y) + + # Regularization loss: + mask = edge_mask.sigmoid() + size_loss = mask.sum() * self.coeffs["edge_size"] + mask = 0.99 * mask + 0.005 + mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log() + mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"] + + return loss + size_loss + mask_ent_loss + + +class GNNExplainer(ExplainerAlgorithm): + r"""The GNN-Explainer model from the `"GNNExplainer: Generating + Explanations for Graph Neural Networks" + 
+
+
+class GNNExplainer(ExplainerAlgorithm):
+    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
+    Explanations for Graph Neural Networks"
+    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact
+    subgraph structures and node features that play a crucial role in the
+    predictions made by a GNN.
+
+    .. note::
+
+        For an example of using :class:`GNNExplainer`, see
+        `examples/gnn_explainer.py <https://github.com/pyg-team/
+        pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and
+        `examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/
+        pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_.
+
+    Args:
+        epochs (int, optional): The number of epochs to train.
+            (default: :obj:`100`)
+        lr (float, optional): The learning rate to apply.
+            (default: :obj:`0.01`)
+        **kwargs (optional): Additional hyper-parameters to override default
+            settings in
+            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
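+
+    .. code-block:: python
+
+        # A typical setup (a sketch; `model`, `x` and `edge_index` are
+        # assumed to be defined by the caller):
+        explainer = Explainer(
+            model=model,
+            algorithm=GNNExplainer(epochs=200),
+            explanation_type='model',
+            node_mask_type='attributes',
+            edge_mask_type='object',
+            model_config=ModelConfig(
+                mode='multiclass_classification',
+                task_level='node',
+                return_type='log_probs',
+            ),
+        )
+        explanation = explainer(x, edge_index)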
+    """
+
+    coeffs = {
+        "edge_size": 0.005,
+        "edge_reduction": "sum",
+        "node_feat_size": 1.0,
+        "node_feat_reduction": "mean",
+        "edge_ent": 1.0,
+        "node_feat_ent": 0.1,
+        "EPS": 1e-15,
+    }
+
+    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
+        super().__init__()
+        self.epochs = epochs
+        self.lr = lr
+        self.coeffs.update(kwargs)
+
+        self.node_mask = self.edge_mask = None
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Explanation:
+        hard_node_mask, hard_edge_mask = self._get_hard_masks(
+            model, index, edge_index, num_nodes=x.size(0))
+
+        self._train(model, x, edge_index, target=target, index=index,
+                    **kwargs)
+
+        node_mask = self._post_process_mask(self.node_mask, hard_node_mask,
+                                            apply_sigmoid=True)
+        edge_mask = self._post_process_mask(self.edge_mask, hard_edge_mask,
+                                            apply_sigmoid=True)
+
+        self._clean_model(model)  # Remove masks and reset learned state.
+
+        return Explanation(node_mask=node_mask, edge_mask=edge_mask)
+
+    def _train(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ):
+        self._initialize_masks(x, edge_index)
+
+        parameters = [self.node_mask]  # We always learn a node mask.
+        if self.explainer_config.edge_mask_type is not None:
+            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
+            parameters.append(self.edge_mask)
+
+        optimizer = torch.optim.Adam(parameters, lr=self.lr)
+
+        for _ in range(self.epochs):
+            optimizer.zero_grad()
+
+            h = x * self.node_mask.sigmoid()
+            y_hat, y = model(h, edge_index, **kwargs), target
+
+            if index is not None:
+                y_hat, y = y_hat[index], y[index]
+
+            loss = self._loss(y_hat, y)
+
+            loss.backward()
+            optimizer.step()
+
+    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
+        node_mask_type = self.explainer_config.node_mask_type
+        edge_mask_type = self.explainer_config.edge_mask_type
+
+        device = x.device
+        (N, F), E = x.size(), edge_index.size(1)
+
+        std = 0.1
+        if node_mask_type == MaskType.object:
+            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
+        elif node_mask_type == MaskType.attributes:
+            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
+        elif node_mask_type == MaskType.common_attributes:
+            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
+        else:
+            assert False
+
+        if edge_mask_type == MaskType.object:
+            std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N))
+            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
+        elif edge_mask_type is not None:
+            assert False
+
+    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
+        if self.model_config.mode == ModelMode.binary_classification:
+            loss = self._loss_binary_classification(y_hat, y)
+        elif self.model_config.mode == ModelMode.multiclass_classification:
+            loss = self._loss_multiclass_classification(y_hat, y)
+        elif self.model_config.mode == ModelMode.regression:
+            loss = self._loss_regression(y_hat, y)
+        else:
+            assert False
+
+        if self.explainer_config.edge_mask_type is not None:
+            m = self.edge_mask.sigmoid()
+            edge_reduce = getattr(torch, self.coeffs["edge_reduction"])
+            loss = loss + self.coeffs["edge_size"] * edge_reduce(m)
+            ent = -m * torch.log(m + self.coeffs["EPS"]) \
+                - (1 - m) * torch.log(1 - m + self.coeffs["EPS"])
+            loss = loss + self.coeffs["edge_ent"] * ent.mean()
+
+        m = self.node_mask.sigmoid()
+        node_feat_reduce = getattr(torch, self.coeffs["node_feat_reduction"])
+        loss = loss + self.coeffs["node_feat_size"] * node_feat_reduce(m)
+        ent = -m * torch.log(m + self.coeffs["EPS"]) \
+            - (1 - m) * torch.log(1 - m + self.coeffs["EPS"])
+        loss = loss + self.coeffs["node_feat_ent"] * ent.mean()
+
+        return loss
+
+    def _clean_model(self, model):
+        clear_masks(model)
+        self.node_mask = None
+        self.edge_mask = None
+
+
+def load_GNNExplainer():
+    return GNNExplainer()
+
+
+def load_PGExplainer():
+    return PGExplainer(epochs=30)
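+
+
+# A minimal sketch of how the loaders above are meant to be used (the
+# `Explainer` wrapper and `model` are assumed to come from the surrounding
+# framework):
+#
+#     algorithm = load_PGExplainer()  # Or: load_GNNExplainer()
+#     explainer = Explainer(model=model, algorithm=algorithm, ...)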
+
+
+class PYGWrapper(ExplainerAlgorithm):
+    def __init__(self, name, **kwargs):
+        super().__init__()
+        self.name = name
+
+    def supports(self) -> bool:
+        if self.name == "PGExplainer":
+            explanation_type = self.explainer_config.explanation_type
+            if explanation_type != ExplanationType.phenomenon:
+                logging.error(
+                    f"'{self.__class__.__name__}' only supports "
+                    f"phenomenon explanations "
+                    f"(got `explanation_type={explanation_type.value}`)"
+                )
+                return False
+
+            task_level = self.model_config.task_level
+            if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
+                logging.error(
+                    f"'{self.__class__.__name__}' only supports "
+                    f"node-level or graph-level explanations "
+                    f"(got `task_level={task_level.value}`)"
+                )
+                return False
+
+            node_mask_type = self.explainer_config.node_mask_type
+            if node_mask_type is not None:
+                logging.error(
+                    f"'{self.__class__.__name__}' does not support "
+                    f"explaining input node features "
+                    f"(got `node_mask_type={node_mask_type.value}`)"
+                )
+                return False
+        elif self.name == "GNNExplainer":
+            pass  # GNNExplainer supports all configurations.
+
+        return True
+
+    def _load_pyg_method(self):
+        if self.name == "GNNExplainer":
+            return load_GNNExplainer()
+        if self.name == "PGExplainer":
+            return load_PGExplainer()
+        raise ValueError(f"Unknown explainer name '{self.name}'")
+
+    def _parse_attr(self, attr):
+        data = attr.to_dict()
+        node_mask = data.get("node_mask")
+        edge_mask = data.get("edge_mask")
+        node_feat_mask = data.get("node_feat_mask")
+        edge_feat_mask = data.get("edge_feat_mask")
+        return node_mask, edge_mask, node_feat_mask, edge_feat_mask
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> Explanation:
+        self.pyg_method = self._load_pyg_method()
+        # Propagate this wrapper's configuration to the wrapped algorithm:
+        self.pyg_method.connect(self.explainer_config, self.model_config)
+        attr = self.pyg_method.forward(
+            model=model,
+            x=x,
+            edge_index=edge_index,
+            index=index,
+            target=target,
+            **kwargs,
+        )
+
+        node_mask, edge_mask, node_feat_mask, edge_feat_mask = \
+            self._parse_attr(attr)
+        return Explanation(
+            x=x,
+            edge_index=edge_index,
+            y=target,
+            edge_mask=edge_mask,
+            node_mask=node_mask,
+            node_feat_mask=node_feat_mask,
+            edge_feat_mask=edge_feat_mask,
+        )
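+
+
+# End-to-end usage sketch (hypothetical names; `my_gnn`, `data` and the
+# concrete config values below are assumptions, not part of this module):
+#
+#     from torch_geometric.explain import Explainer
+#
+#     explainer = Explainer(
+#         model=my_gnn,
+#         algorithm=PYGWrapper(name="GNNExplainer"),
+#         explanation_type="model",
+#         node_mask_type="attributes",
+#         edge_mask_type="object",
+#         model_config=dict(mode="multiclass_classification",
+#                           task_level="node", return_type="log_probs"),
+#     )
+#     explanation = explainer(data.x, data.edge_index)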