commit b05e44ba795715587442a7f8f9bdbbd2582fc534 Author: araison Date: Wed Mar 8 18:59:37 2023 +0100 Upload to github diff --git a/README.md b/README.md new file mode 100644 index 0000000..a3242e6 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +Here is the an example code for using ScoreCAM GNN from the [ScoreCAM GNN : a generalization of an optimal local post-hoc explaining method to any geometric deep learning models](https://arxiv.org/abs/2207.12748) paper + +```python +from torch_geometric.datasets import TUDataset + + dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") + data = dataset[0] + from scgnn.scgnn import SCGNN + + import torch.nn.functional as F + from torch_geometric.nn import GCNConv, global_mean_pool + + + model = Sequential( + "data", + [ + ( + lambda data: (data.x, data.edge_index, data.batch), + "data -> x, edge_index, batch", + ), + (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"), + (GCNConv(64, dataset.num_classes), "x, edge_index -> x"), + (global_mean_pool, "x, batch -> x"), + ], + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + data = dataset[0].to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + model.eval() + out = model(data) + explainer = SCGNN() + explained = explainer.forward( + model, + data.x, + data.edge_index, + target=2, + interest_map_norm=True, + score_map_norm=True, + ) +``` diff --git a/scgnn/__init__.py b/scgnn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scgnn/__pycache__/__init__.cpython-310.pyc b/scgnn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..316468d Binary files /dev/null and b/scgnn/__pycache__/__init__.cpython-310.pyc differ diff --git a/scgnn/__pycache__/scgnn.cpython-310.pyc b/scgnn/__pycache__/scgnn.cpython-310.pyc new file mode 100644 index 0000000..98e9833 Binary files /dev/null and b/scgnn/__pycache__/scgnn.cpython-310.pyc differ diff --git a/scgnn/scgnn.py b/scgnn/scgnn.py new file mode 100644 index 0000000..a0f07ef --- /dev/null +++ b/scgnn/scgnn.py @@ -0,0 +1,213 @@ +import collections +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module, ModuleList, Softmax +from torch_geometric.data import Batch, Data +from torch_geometric.explain import Explanation +from torch_geometric.explain.algorithm.base import ExplainerAlgorithm +# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks +from torch_geometric.explain.config import (ExplainerConfig, MaskType, + ModelConfig, ModelMode, + ModelTaskLevel) +from torch_geometric.nn import MessagePassing, Sequential +from torch_geometric.utils import index_to_mask, k_hop_subgraph, subgraph + +from .utils.embedding import get_message_passing_embeddings + + +class Hook: + def __init__(self, module): + self.hook = module.register_forward_hook(self.hook_fn) + + def hook_fn(self, module, input, output): + self.input = input + self.output = output + + def close(self): + self.hook.remove() + + +class SCGNN(ExplainerAlgorithm): + r""" + The official implementation of ScoreCAM with GNN flavour [...] + + + + + Args: + + """ + + def __init__( + self, + depth: str = "last", + interest_map_norm: bool = True, + score_map_norm: bool = True, + target_baseline="inference", + **kwargs, + ): + super().__init__() + self.depth = depth + self.interest_map_norm = interest_map_norm + self.score_map_norm = score_map_norm + self.target_baseline = target_baseline + self.name = "SCGNN" + + def supports(self) -> bool: + task_level = self.model_config.task_level + if task_level not in [ModelTaskLevel.graph]: + logging.error(f"Task level '{task_level.value}' not supported") + return False + + edge_mask_type = self.explainer_config.edge_mask_type + if edge_mask_type not in [MaskType.object, None]: + logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported") + return False + + node_mask_type = self.explainer_config.node_mask_type + if node_mask_type not in [ + MaskType.common_attributes, + MaskType.object, + MaskType.attributes, + ]: + logging.error(f"Node mask type '{node_mask_type.value}' not " f"supported.") + return False + + return True + + def forward( + self, + model: torch.nn.Module, + x: Tensor, + edge_index: Tensor, + target, + **kwargs, + ) -> Explanation: + embedding = get_message_passing_embeddings( + model=model, x=x, edge_index=edge_index + ) + + out = model(x=x, edge_index=edge_index) + + if self.target_baseline is None: + c = target + if self.target_baseline == "inference": + c = out.argmax(dim=1).item() + + if self.depth == "last": + score_map = self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c + ) + extra_score_map = None + elif self.depth == "all": + score_map = self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c + ) + extra_score_map = torch.cat( + [ + self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=emb, c=c + ) + for emb in embedding[:-1] + ], + dim=0, + ) + else: + raise ValueError(f"Depth={self.depth} not implemented yet") + + node_mask = score_map + edge_mask = None + node_feat_mask = None + edge_feat_mask = None + + exp = Explanation( + x=x, + edge_index=edge_index, + y=target, + edge_mask=edge_mask, + node_mask=node_mask, + node_feat_mask=node_feat_mask, + edge_feat_mask=edge_feat_mask, + extra_score_map=extra_score_map, + ) + return exp + + def get_score_map( + self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor, c: int + ) -> Tensor: + interest_map = emb.clone() + n_nodes, n_features = interest_map.size() + score_map = torch.zeros(n_nodes).to(x.device) + for k in range(n_features): + _x = x.clone() + feat = interest_map[:, k] + if feat.min() == feat.max(): + continue + mask = feat.clone() + if self.interest_map_norm: + mask = (mask - mask.min()).div(mask.max() - mask.min()) + mask = mask.reshape((-1, 1)) + _x = _x * mask + _out = model(x=_x, edge_index=edge_index) + _out = F.softmax(_out, dim=1) + _out = _out.squeeze() + val = float(_out[c]) + score_map = score_map + val * feat + + score_map = F.relu(score_map) + + if self.score_map_norm and score_map.min() != score_map.max(): + score_map = (score_map - score_map.min()).div( + score_map.max() - score_map.min() + ) + return score_map + + +if __name__ == "__main__": + from torch_geometric.datasets import TUDataset + + dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") + data = dataset[0] + + import torch.nn.functional as F + from torch_geometric.nn import GCNConv, global_mean_pool + + # model = torch.nn.ModuleDict( + # { + # "conv1": GCNConv(dataset.num_node_features, 64), + # "conv2": GCNConv(64, dataset.num_classes), + # "gmp": global_mean_pool, + # } + # ) + + model = Sequential( + "data", + [ + ( + lambda data: (data.x, data.edge_index, data.batch), + "data -> x, edge_index, batch", + ), + (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"), + (GCNConv(64, dataset.num_classes), "x, edge_index -> x"), + (global_mean_pool, "x, batch -> x"), + ], + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + data = dataset[0].to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + model.eval() + out = model(data) + explainer = SCGNN() + explained = explainer.forward( + model, + data.x, + data.edge_index, + target=2, + interest_map_norm=True, + score_map_norm=True, + ) diff --git a/scgnn/utils/__pycache__/embedding.cpython-310.pyc b/scgnn/utils/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000..49f00cc Binary files /dev/null and b/scgnn/utils/__pycache__/embedding.cpython-310.pyc differ diff --git a/scgnn/utils/embedding.py b/scgnn/utils/embedding.py new file mode 100644 index 0000000..a1222cd --- /dev/null +++ b/scgnn/utils/embedding.py @@ -0,0 +1,54 @@ +import warnings +from typing import Any, List + +import torch +from torch import Tensor + + +def get_message_passing_embeddings( + model: torch.nn.Module, + *args, + **kwargs, +) -> List[Tensor]: + """Returns the output embeddings of all + :class:`~torch_geometric.nn.conv.MessagePassing` layers in + :obj:`model`. + + Internally, this method registers forward hooks on all + :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`, + and runs the forward pass of the :obj:`model` by calling + :obj:`model(*args, **kwargs)`. + + Args: + model (torch.nn.Module): The message passing model. + *args: Arguments passed to the model. + **kwargs (optional): Additional keyword arguments passed to the model. + """ + from torch_geometric.nn import MessagePassing + + embeddings: List[Tensor] = [] + + def hook(model: torch.nn.Module, inputs: Any, outputs: Any): + # Clone output in case it will be later modified in-place: + outputs = outputs[0] if isinstance(outputs, tuple) else outputs + assert isinstance(outputs, Tensor) + embeddings.append(outputs.clone()) + + hook_handles = [] + for module in model.modules(): # Register forward hooks: + if isinstance(module, MessagePassing): + hook_handles.append(module.register_forward_hook(hook)) + + if len(hook_handles) == 0: + warnings.warn("The 'model' does not have any 'MessagePassing' layers") + + training = model.training + model.eval() + with torch.no_grad(): + model(*args, **kwargs) + model.train(training) + + for handle in hook_handles: # Remove hooks: + handle.remove() + + return embeddings diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..79bb91f --- /dev/null +++ b/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup( + name="scgnn", + version="0.1", + description="Official implementation of ScoreCAM GNN for explaining graph neural networks", + packages=["scgnn"], + zip_safe=False, +)