Fixing import bug
This commit is contained in:
		
							parent
							
								
									d1c2565003
								
							
						
					
					
						commit
						7da28de955
					
				
					 1 changed files with 536 additions and 0 deletions
				
			
		
							
								
								
									
										536
									
								
								explaining_framework/explainers/wrappers/from_pyg.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										536
									
								
								explaining_framework/explainers/wrappers/from_pyg.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,536 @@ | ||||||
|  | import logging | ||||||
|  | import warnings | ||||||
|  | from math import sqrt | ||||||
|  | from typing import Any, List, Optional, Tuple, Union | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import Tensor | ||||||
|  | from torch.nn import CrossEntropyLoss, MSELoss, ReLU, Sequential | ||||||
|  | from torch.nn.parameter import Parameter | ||||||
|  | from torch_geometric.data import Data | ||||||
|  | from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig | ||||||
|  | from torch_geometric.explain.algorithm import ExplainerAlgorithm | ||||||
|  | from torch_geometric.explain.algorithm.base import ExplainerAlgorithm | ||||||
|  | from torch_geometric.explain.algorithm.utils import clear_masks, set_masks | ||||||
|  | from torch_geometric.explain.config import (ExplainerConfig, ExplanationType, | ||||||
|  |                                             MaskType, ModelConfig, ModelMode, | ||||||
|  |                                             ModelTaskLevel) | ||||||
|  | from torch_geometric.nn import Linear | ||||||
|  | from torch_geometric.nn.inits import reset | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_message_passing_embeddings( | ||||||
|  |     model: torch.nn.Module, | ||||||
|  |     *args, | ||||||
|  |     **kwargs, | ||||||
|  | ) -> List[Tensor]: | ||||||
|  |     """Returns the output embeddings of all | ||||||
|  |     :class:`~torch_geometric.nn.conv.MessagePassing` layers in | ||||||
|  |     :obj:`model`. | ||||||
|  | 
 | ||||||
|  |     Internally, this method registers forward hooks on all | ||||||
|  |     :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`, | ||||||
|  |     and runs the forward pass of the :obj:`model` by calling | ||||||
|  |     :obj:`model(*args, **kwargs)`. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         model (torch.nn.Module): The message passing model. | ||||||
|  |         *args: Arguments passed to the model. | ||||||
|  |         **kwargs (optional): Additional keyword arguments passed to the model. | ||||||
|  |     """ | ||||||
|  |     from torch_geometric.nn import MessagePassing | ||||||
|  | 
 | ||||||
|  |     embeddings: List[Tensor] = [] | ||||||
|  | 
 | ||||||
|  |     def hook(model: torch.nn.Module, inputs: Any, outputs: Any): | ||||||
|  |         # Clone output in case it will be later modified in-place: | ||||||
|  |         outputs = outputs[0] if isinstance(outputs, tuple) else outputs | ||||||
|  |         assert isinstance(outputs, Tensor) | ||||||
|  |         embeddings.append(outputs.clone()) | ||||||
|  | 
 | ||||||
|  |     hook_handles = [] | ||||||
|  |     for module in model.modules():  # Register forward hooks: | ||||||
|  |         if isinstance(module, MessagePassing): | ||||||
|  |             hook_handles.append(module.register_forward_hook(hook)) | ||||||
|  | 
 | ||||||
|  |     if len(hook_handles) == 0: | ||||||
|  |         warnings.warn("The 'model' does not have any 'MessagePassing' layers") | ||||||
|  | 
 | ||||||
|  |     training = model.training | ||||||
|  |     model.eval() | ||||||
|  |     with torch.no_grad(): | ||||||
|  |         model(*args, **kwargs) | ||||||
|  |     model.train(training) | ||||||
|  | 
 | ||||||
|  |     for handle in hook_handles:  # Remove hooks: | ||||||
|  |         handle.remove() | ||||||
|  | 
 | ||||||
|  |     return embeddings | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PGExplainer(ExplainerAlgorithm): | ||||||
|  |     r"""The PGExplainer model from the `"Parameterized Explainer for Graph | ||||||
|  |     Neural Network" <https://arxiv.org/abs/2011.04573>`_ paper. | ||||||
|  |     Internally, it utilizes a neural network to identify subgraph structures | ||||||
|  |     that play a crucial role in the predictions made by a GNN. | ||||||
|  |     Importantly, the :class:`PGExplainer` needs to be trained via | ||||||
|  |     :meth:`~PGExplainer.train` before being able to generate explanations: | ||||||
|  | 
 | ||||||
|  |     .. code-block:: python | ||||||
|  | 
 | ||||||
|  |         explainer = Explainer( | ||||||
|  |             model=model, | ||||||
|  |             algorithm=PGExplainer(epochs=30, lr=0.003), | ||||||
|  |             explanation_type='phenomenon', | ||||||
|  |             edge_mask_type='object', | ||||||
|  |             model_config=ModelConfig(...), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # Train against a variety of node-level or graph-level predictions: | ||||||
|  |         for epoch in range(30): | ||||||
|  |             for index in [...]:  # Indices to train against. | ||||||
|  |                 loss = explainer.algorithm.train(epoch, model, x, edge_index, | ||||||
|  |                                                  target=target, index=index) | ||||||
|  | 
 | ||||||
|  |         # Get the final explanations: | ||||||
|  |         explanation = explainer(x, edge_index, target=target, index=0) | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         epochs (int): The number of epochs to train. | ||||||
|  |         lr (float, optional): The learning rate to apply. | ||||||
|  |             (default: :obj:`0.003`). | ||||||
|  |         **kwargs (optional): Additional hyper-parameters to override default | ||||||
|  |             settings in | ||||||
|  |             :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     coeffs = { | ||||||
|  |         "edge_size": 0.05, | ||||||
|  |         "edge_ent": 1.0, | ||||||
|  |         "temp": [5.0, 2.0], | ||||||
|  |         "bias": 0.0, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     def __init__(self, epochs: int, lr: float = 0.003, **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  |         self.epochs = epochs | ||||||
|  |         self.lr = lr | ||||||
|  |         self.coeffs.update(kwargs) | ||||||
|  | 
 | ||||||
|  |         self.mlp = Sequential( | ||||||
|  |             Linear(-1, 64), | ||||||
|  |             ReLU(), | ||||||
|  |             Linear(64, 1), | ||||||
|  |         ) | ||||||
|  |         self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr) | ||||||
|  |         self._curr_epoch = -1 | ||||||
|  | 
 | ||||||
|  |     def _get_hard_masks( | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         index: Optional[Union[int, Tensor]], | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         num_nodes: int, | ||||||
|  |     ) -> Tuple[Optional[Tensor], Optional[Tensor]]: | ||||||
|  |         r"""Returns hard node and edge masks that only include the nodes and | ||||||
|  |         edges visited during message passing.""" | ||||||
|  |         if index is None: | ||||||
|  |             return None, None  # Consider all nodes and edges. | ||||||
|  | 
 | ||||||
|  |         index, _, _, edge_mask = k_hop_subgraph( | ||||||
|  |             index, | ||||||
|  |             num_hops=self._num_hops(model), | ||||||
|  |             edge_index=edge_index, | ||||||
|  |             num_nodes=num_nodes, | ||||||
|  |             flow=self._flow(model), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool) | ||||||
|  |         node_mask[index] = True | ||||||
|  | 
 | ||||||
|  |         return node_mask, edge_mask | ||||||
|  | 
 | ||||||
|  |     def reset_parameters(self): | ||||||
|  |         reset(self.mlp) | ||||||
|  | 
 | ||||||
|  |     def train( | ||||||
|  |         self, | ||||||
|  |         epoch: int, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         *, | ||||||
|  |         target: Tensor, | ||||||
|  |         index: Optional[Union[int, Tensor]] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         r"""Trains the underlying explainer model. | ||||||
|  |         Needs to be called before being able to make predictions. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             epoch (int): The current epoch of the training phase. | ||||||
|  |             model (torch.nn.Module): The model to explain. | ||||||
|  |             x (torch.Tensor): The input node features of a | ||||||
|  |                 homogeneous graph. | ||||||
|  |             edge_index (torch.Tensor): The input edge indices of a homogeneous | ||||||
|  |                 graph. | ||||||
|  |             target (torch.Tensor): The target of the model. | ||||||
|  |             index (int or torch.Tensor, optional): The index of the model | ||||||
|  |                 output to explain. Needs to be a single index. | ||||||
|  |                 (default: :obj:`None`) | ||||||
|  |             **kwargs (optional): Additional keyword arguments passed to | ||||||
|  |                 :obj:`model`. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1] | ||||||
|  | 
 | ||||||
|  |         self.optimizer.zero_grad() | ||||||
|  |         temperature = self._get_temperature(epoch) | ||||||
|  | 
 | ||||||
|  |         inputs = self._get_inputs(z, edge_index, index) | ||||||
|  |         logits = self.mlp(inputs).view(-1) | ||||||
|  |         edge_mask = self._concrete_sample(logits, temperature) | ||||||
|  |         set_masks(model, edge_mask, edge_index, apply_sigmoid=True) | ||||||
|  | 
 | ||||||
|  |         if self.model_config.task_level == ModelTaskLevel.node: | ||||||
|  |             _, hard_edge_mask = self._get_hard_masks( | ||||||
|  |                 model, index, edge_index, num_nodes=x.size(0) | ||||||
|  |             ) | ||||||
|  |             edge_mask = edge_mask[hard_edge_mask] | ||||||
|  | 
 | ||||||
|  |         y_hat, y = model(x, edge_index, **kwargs), target | ||||||
|  | 
 | ||||||
|  |         if index is not None: | ||||||
|  |             y_hat, y = y_hat[index], y[index] | ||||||
|  | 
 | ||||||
|  |         loss = self._loss(y_hat, y, edge_mask) | ||||||
|  |         loss.backward() | ||||||
|  |         self.optimizer.step() | ||||||
|  | 
 | ||||||
|  |         clear_masks(model) | ||||||
|  |         self._curr_epoch = epoch | ||||||
|  | 
 | ||||||
|  |         return float(loss) | ||||||
|  | 
 | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         *, | ||||||
|  |         target: Tensor, | ||||||
|  |         index: Optional[Union[int, Tensor]] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ) -> Explanation: | ||||||
|  |         if self._curr_epoch < self.epochs - 1:  # Safety check: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"'{self.__class__.__name__}' is not yet fully " | ||||||
|  |                 f"trained (got {self._curr_epoch + 1} epochs " | ||||||
|  |                 f"from {self.epochs} epochs). Please first train " | ||||||
|  |                 f"the underlying explainer model by running " | ||||||
|  |                 f"`explainer.algorithm.train(...)`." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         hard_edge_mask = None | ||||||
|  |         _, hard_edge_mask = self._get_hard_masks( | ||||||
|  |             model, index, edge_index, num_nodes=x.size(0) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         for epoch in range(self.epochs): | ||||||
|  |             loss = self.train(epoch, model, x, edge_index, target=target, index=index) | ||||||
|  |         z = get_message_passing_embeddings(model, x, edge_index, **kwargs)[-1] | ||||||
|  | 
 | ||||||
|  |         inputs = self._get_inputs(z, edge_index, index) | ||||||
|  |         logits = self.mlp(inputs).view(-1) | ||||||
|  | 
 | ||||||
|  |         edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True) | ||||||
|  | 
 | ||||||
|  |         return Explanation(edge_mask=edge_mask) | ||||||
|  | 
 | ||||||
|  |     ########################################################################### | ||||||
|  | 
 | ||||||
|  |     def _get_inputs( | ||||||
|  |         self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None | ||||||
|  |     ) -> Tensor: | ||||||
|  |         zs = [embedding[edge_index[0]], embedding[edge_index[1]]] | ||||||
|  |         if self.model_config.task_level == ModelTaskLevel.node: | ||||||
|  |             assert index is not None | ||||||
|  |             zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1)) | ||||||
|  |         return torch.cat(zs, dim=-1) | ||||||
|  | 
 | ||||||
|  |     def _get_temperature(self, epoch: int) -> float: | ||||||
|  |         temp = self.coeffs["temp"] | ||||||
|  |         return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs) | ||||||
|  | 
 | ||||||
|  |     def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor: | ||||||
|  |         bias = self.coeffs["bias"] | ||||||
|  |         eps = (1 - 2 * bias) * torch.rand_like(logits) + bias | ||||||
|  |         return (eps.log() - (1 - eps).log() + logits) / temperature | ||||||
|  | 
 | ||||||
|  |     def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor: | ||||||
|  |         if self.model_config.mode == ModelMode.binary_classification: | ||||||
|  |             loss = self._loss_binary_classification(y_hat, y) | ||||||
|  |         elif self.model_config.mode == ModelMode.multiclass_classification: | ||||||
|  |             loss = self._loss_multiclass_classification(y_hat, y) | ||||||
|  |         elif self.model_config.mode == ModelMode.regression: | ||||||
|  |             loss = self._loss_regression(y_hat, y) | ||||||
|  | 
 | ||||||
|  |         # Regularization loss: | ||||||
|  |         mask = edge_mask.sigmoid() | ||||||
|  |         size_loss = mask.sum() * self.coeffs["edge_size"] | ||||||
|  |         mask = 0.99 * mask + 0.005 | ||||||
|  |         mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log() | ||||||
|  |         mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"] | ||||||
|  | 
 | ||||||
|  |         return loss + size_loss + mask_ent_loss | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GNNExplainer(ExplainerAlgorithm): | ||||||
|  |     r"""The GNN-Explainer model from the `"GNNExplainer: Generating | ||||||
|  |     Explanations for Graph Neural Networks" | ||||||
|  |     <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph | ||||||
|  |     structures and node features that play a crucial role in the predictions | ||||||
|  |     made by a GNN. | ||||||
|  | 
 | ||||||
|  |     .. note:: | ||||||
|  | 
 | ||||||
|  |         For an example of using :class:`GNNExplainer`, see | ||||||
|  |         `examples/gnn_explainer.py <https://github.com/pyg-team/ | ||||||
|  |         pytorch_geometric/blob/master/examples/gnn_explainer.py>`_ and | ||||||
|  |         `examples/gnn_explainer_ba_shapes.py <https://github.com/pyg-team/ | ||||||
|  |         pytorch_geometric/blob/master/examples/gnn_explainer_ba_shapes.py>`_. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         epochs (int, optional): The number of epochs to train. | ||||||
|  |             (default: :obj:`100`) | ||||||
|  |         lr (float, optional): The learning rate to apply. | ||||||
|  |             (default: :obj:`0.01`) | ||||||
|  |         **kwargs (optional): Additional hyper-parameters to override default | ||||||
|  |             settings in | ||||||
|  |             :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     coeffs = { | ||||||
|  |         "edge_size": 0.005, | ||||||
|  |         "edge_reduction": "sum", | ||||||
|  |         "node_feat_size": 1.0, | ||||||
|  |         "node_feat_reduction": "mean", | ||||||
|  |         "edge_ent": 1.0, | ||||||
|  |         "node_feat_ent": 0.1, | ||||||
|  |         "EPS": 1e-15, | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  |         self.epochs = epochs | ||||||
|  |         self.lr = lr | ||||||
|  |         self.coeffs.update(kwargs) | ||||||
|  | 
 | ||||||
|  |         self.node_mask = self.edge_mask = None | ||||||
|  | 
 | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         *, | ||||||
|  |         target: Tensor, | ||||||
|  |         index: Optional[Union[int, Tensor]] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ) -> Explanation: | ||||||
|  | 
 | ||||||
|  |         hard_node_mask = hard_edge_mask = None | ||||||
|  |         hard_node_mask, hard_edge_mask = self._get_hard_masks( | ||||||
|  |             model, index, edge_index, num_nodes=x.size(0) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         self._train(model, x, edge_index, target=target, index=index, **kwargs) | ||||||
|  | 
 | ||||||
|  |         node_mask = self._post_process_mask( | ||||||
|  |             self.node_mask, hard_node_mask, apply_sigmoid=True | ||||||
|  |         ) | ||||||
|  |         edge_mask = self._post_process_mask( | ||||||
|  |             self.edge_mask, hard_edge_mask, apply_sigmoid=True | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return Explanation(node_mask=node_mask, edge_mask=edge_mask) | ||||||
|  | 
 | ||||||
|  |     def _train( | ||||||
|  |         self, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         *, | ||||||
|  |         target: Tensor, | ||||||
|  |         index: Optional[Union[int, Tensor]] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         self._initialize_masks(x, edge_index) | ||||||
|  | 
 | ||||||
|  |         parameters = [self.node_mask]  # We always learn a node mask. | ||||||
|  |         if self.explainer_config.edge_mask_type is not None: | ||||||
|  |             set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True) | ||||||
|  |             parameters.append(self.edge_mask) | ||||||
|  | 
 | ||||||
|  |         optimizer = torch.optim.Adam(parameters, lr=self.lr) | ||||||
|  | 
 | ||||||
|  |         for _ in range(self.epochs): | ||||||
|  |             optimizer.zero_grad() | ||||||
|  | 
 | ||||||
|  |             h = x * self.node_mask.sigmoid() | ||||||
|  |             y_hat, y = model(h, edge_index, **kwargs), target | ||||||
|  | 
 | ||||||
|  |             if index is not None: | ||||||
|  |                 y_hat, y = y_hat[index], y[index] | ||||||
|  | 
 | ||||||
|  |             loss = self._loss(y_hat, y) | ||||||
|  | 
 | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  | 
 | ||||||
|  |     def _initialize_masks(self, x: Tensor, edge_index: Tensor): | ||||||
|  |         node_mask_type = self.explainer_config.node_mask_type | ||||||
|  |         edge_mask_type = self.explainer_config.edge_mask_type | ||||||
|  | 
 | ||||||
|  |         device = x.device | ||||||
|  |         (N, F), E = x.size(), edge_index.size(1) | ||||||
|  | 
 | ||||||
|  |         std = 0.1 | ||||||
|  |         if node_mask_type == MaskType.object: | ||||||
|  |             self.node_mask = Parameter(torch.randn(N, 1, device=device) * std) | ||||||
|  |         elif node_mask_type == MaskType.attributes: | ||||||
|  |             self.node_mask = Parameter(torch.randn(N, F, device=device) * std) | ||||||
|  |         elif node_mask_type == MaskType.common_attributes: | ||||||
|  |             self.node_mask = Parameter(torch.randn(1, F, device=device) * std) | ||||||
|  |         else: | ||||||
|  |             assert False | ||||||
|  | 
 | ||||||
|  |         if edge_mask_type == MaskType.object: | ||||||
|  |             std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N)) | ||||||
|  |             self.edge_mask = Parameter(torch.randn(E, device=device) * std) | ||||||
|  |         elif edge_mask_type is not None: | ||||||
|  |             assert False | ||||||
|  | 
 | ||||||
|  |     def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor: | ||||||
|  |         if self.model_config.mode == ModelMode.binary_classification: | ||||||
|  |             loss = self._loss_binary_classification(y_hat, y) | ||||||
|  |         elif self.model_config.mode == ModelMode.multiclass_classification: | ||||||
|  |             loss = self._loss_multiclass_classification(y_hat, y) | ||||||
|  |         elif self.model_config.mode == ModelMode.regression: | ||||||
|  |             loss = self._loss_regression(y_hat, y) | ||||||
|  |         else: | ||||||
|  |             assert False | ||||||
|  | 
 | ||||||
|  |         if self.explainer_config.edge_mask_type is not None: | ||||||
|  |             m = self.edge_mask.sigmoid() | ||||||
|  |             edge_reduce = getattr(torch, self.coeffs["edge_reduction"]) | ||||||
|  |             loss = loss + self.coeffs["edge_size"] * edge_reduce(m) | ||||||
|  |             ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log( | ||||||
|  |                 1 - m + self.coeffs["EPS"] | ||||||
|  |             ) | ||||||
|  |             loss = loss + self.coeffs["edge_ent"] * ent.mean() | ||||||
|  | 
 | ||||||
|  |         m = self.node_mask.sigmoid() | ||||||
|  |         node_feat_reduce = getattr(torch, self.coeffs["node_feat_reduction"]) | ||||||
|  |         loss = loss + self.coeffs["node_feat_size"] * node_feat_reduce(m) | ||||||
|  |         ent = -m * torch.log(m + self.coeffs["EPS"]) - (1 - m) * torch.log( | ||||||
|  |             1 - m + self.coeffs["EPS"] | ||||||
|  |         ) | ||||||
|  |         loss = loss + self.coeffs["node_feat_ent"] * ent.mean() | ||||||
|  | 
 | ||||||
|  |         return loss | ||||||
|  | 
 | ||||||
|  |     def _clean_model(self, model): | ||||||
|  |         clear_masks(model) | ||||||
|  |         self.node_mask = None | ||||||
|  |         self.edge_mask = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_GNNExplainer(): | ||||||
|  |     return GNNExplainer() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_PGExplainer(): | ||||||
|  |     return PGExplainer(epochs=30) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PYGWrapper(ExplainerAlgorithm): | ||||||
|  |     def __init__(self, name, **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  |         self.name = name | ||||||
|  | 
 | ||||||
|  |     def supports(self) -> bool: | ||||||
|  |         task_level = self.model_config.task_level | ||||||
|  |         if self.name == "PGExplainer": | ||||||
|  |             explanation_type = self.explainer_config.explanation_type | ||||||
|  |             if explanation_type != ExplanationType.phenomenon: | ||||||
|  |                 logging.error( | ||||||
|  |                     f"'{self.__class__.__name__}' only supports " | ||||||
|  |                     f"phenomenon explanations " | ||||||
|  |                     f"got (`explanation_type={explanation_type.value}`)" | ||||||
|  |                 ) | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             task_level = self.model_config.task_level | ||||||
|  |             if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}: | ||||||
|  |                 logging.error( | ||||||
|  |                     f"'{self.__class__.__name__}' only supports " | ||||||
|  |                     f"node-level or graph-level explanations " | ||||||
|  |                     f"got (`task_level={task_level.value}`)" | ||||||
|  |                 ) | ||||||
|  |                 return False | ||||||
|  | 
 | ||||||
|  |             node_mask_type = self.explainer_config.node_mask_type | ||||||
|  |             if node_mask_type is not None: | ||||||
|  |                 logging.error( | ||||||
|  |                     f"'{self.__class__.__name__}' does not support " | ||||||
|  |                     f"explaining input node features " | ||||||
|  |                     f"got (`node_mask_type={node_mask_type.value}`)" | ||||||
|  |                 ) | ||||||
|  |                 return False | ||||||
|  |         elif self.name == "GNNExplainer": | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |         return True | ||||||
|  | 
 | ||||||
|  |     def _load_pyg_method(self): | ||||||
|  |         if self.name == "GNNExplainer": | ||||||
|  |             return load_GNNExplainer() | ||||||
|  |         if self.name == "PGExplainer": | ||||||
|  |             return load_PGExplainer() | ||||||
|  | 
 | ||||||
|  |     def _parse_attr(self, attr): | ||||||
|  |         data = attr.to_dict() | ||||||
|  |         node_mask = data.get("node_mask") | ||||||
|  |         edge_mask = data.get("edge_mask") | ||||||
|  |         node_feat_mask = data.get("node_feat_mask") | ||||||
|  |         edge_feat_mask = data.get("edge_feat_mask") | ||||||
|  |         return node_mask, edge_mask, node_feat_mask, edge_feat_mask | ||||||
|  | 
 | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         target: Tensor, | ||||||
|  |         index: Optional[Union[int, Tensor]] = None, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         self.pyg_method = self._load_pyg_method() | ||||||
|  |         attr = self.pyg_method.forward( | ||||||
|  |             model=model, | ||||||
|  |             x=x, | ||||||
|  |             edge_index=edge_index, | ||||||
|  |             index=index, | ||||||
|  |             target=target, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr) | ||||||
|  |         return Explanation( | ||||||
|  |             x=x, | ||||||
|  |             edge_index=edge_index, | ||||||
|  |             y=target, | ||||||
|  |             edge_mask=edge_mask, | ||||||
|  |             node_mask=node_mask, | ||||||
|  |             node_feat_mask=node_feat_mask, | ||||||
|  |             edge_feat_mask=edge_feat_mask, | ||||||
|  |         ) | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison