Upload to github

This commit is contained in:
araison 2023-03-08 18:59:37 +01:00
commit b05e44ba79
8 changed files with 317 additions and 0 deletions

41
README.md Normal file
View File

@ -0,0 +1,41 @@
Here is an example of how to use ScoreCAM GNN, the method introduced in the paper [ScoreCAM GNN : a generalization of an optimal local post-hoc explaining method to any geometric deep learning models](https://arxiv.org/abs/2207.12748):
```python
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, Sequential, global_mean_pool

from scgnn.scgnn import SCGNN

dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")

# SCGNN calls the model as `model(x=..., edge_index=...)`, so the network has
# to accept exactly these two arguments.
model = Sequential(
    "x, edge_index",
    [
        (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
        (GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
        # One graph is explained at a time: pool over all of its nodes.
        (lambda x: global_mean_pool(x, None), "x -> x"),
    ],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = dataset[0].to(device)
model.eval()

# Normalisation options are constructor arguments; keyword arguments passed
# to `forward` other than `target` are ignored.
explainer = SCGNN(interest_map_norm=True, score_map_norm=True)
explained = explainer.forward(
    model,
    data.x,
    data.edge_index,
    target=2,
)
```

0
scgnn/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

213
scgnn/scgnn.py Normal file
View File

@ -0,0 +1,213 @@
import collections
import logging
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, ModuleList, Softmax
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn import MessagePassing, Sequential
from torch_geometric.utils import index_to_mask, k_hop_subgraph, subgraph

from .utils.embedding import get_message_passing_embeddings
class Hook:
    """Record the most recent input and output of a module's forward pass.

    A forward hook is registered on ``module`` at construction time. After
    each forward pass of that module, ``input`` and ``output`` hold the values
    the hook received. Both are ``None`` until the module has been called.

    Args:
        module: Object exposing ``register_forward_hook`` (e.g. a
            ``torch.nn.Module``).
    """

    def __init__(self, module):
        # Initialise first so the attributes exist even if the module is
        # never called (previously reading them raised AttributeError).
        self.input = None
        self.output = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store the arguments and result of the latest forward pass."""
        self.input = input
        self.output = output

    def close(self):
        """Unregister the hook; the last recorded values are kept."""
        self.hook.remove()
class SCGNN(ExplainerAlgorithm):
    r"""
    The official implementation of ScoreCAM with GNN flavour [...]

    For each feature channel of a message-passing layer's output, the input
    node features are multiplied by that channel (one weight per node), the
    model is run again, and the softmax probability of the class of interest
    weights the channel's contribution to a per-node score.

    Args:
        depth (str): ``"last"`` scores only the output of the last
            message-passing layer. ``"all"`` additionally scores every
            earlier layer and returns those scores concatenated in
            ``extra_score_map``.
        interest_map_norm (bool): If ``True``, each channel is min-max
            normalised before being used as a mask on the node features.
        score_map_norm (bool): If ``True``, the final per-node score is
            min-max normalised (skipped when all scores are equal).
        target_baseline: ``"inference"`` explains the class predicted by the
            model; ``None`` explains the ``target`` given to :meth:`forward`.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        depth: str = "last",
        interest_map_norm: bool = True,
        score_map_norm: bool = True,
        target_baseline="inference",
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.interest_map_norm = interest_map_norm
        self.score_map_norm = score_map_norm
        self.target_baseline = target_baseline
        self.name = "SCGNN"

    def supports(self) -> bool:
        """Return whether the configured task level and mask types are handled.

        Only graph-level tasks are supported. An error is logged for the
        first unsupported setting found.
        """
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
        ]:
            # `node_mask_type` can be None here, which has no `.value`.
            shown = getattr(node_mask_type, "value", node_mask_type)
            logging.error(f"Node mask type '{shown}' not supported.")
            return False

        return True

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target,
        **kwargs,
    ) -> Explanation:
        """Compute the node-level explanation of ``model`` on one graph.

        Args:
            model: Model to explain; it is called as
                ``model(x=..., edge_index=...)``.
            x: Node feature matrix.
            edge_index: Edge indices of the graph.
            target: Stored as ``y`` in the explanation; also the explained
                class index when ``target_baseline`` is ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            An :class:`Explanation` whose ``node_mask`` is the score map of
            the last message-passing layer. ``extra_score_map`` holds the
            concatenated score maps of the earlier layers when
            ``depth == "all"`` and such layers exist, otherwise ``None``.

        Raises:
            ValueError: If ``depth`` or ``target_baseline`` has an
                unsupported value, or the model produced no message-passing
                embedding.
        """
        # Validate the configuration before doing any expensive work.
        if self.depth not in ("last", "all"):
            raise ValueError(f"Depth={self.depth} not implemented yet")
        if self.target_baseline is not None and self.target_baseline != "inference":
            raise ValueError(
                f"target_baseline={self.target_baseline!r} not supported; "
                f"expected None or 'inference'"
            )

        embedding = get_message_passing_embeddings(
            model=model, x=x, edge_index=edge_index
        )
        if not embedding:
            raise ValueError(
                "No message passing embedding was collected from 'model'; "
                "nothing to explain"
            )

        if self.target_baseline is None:
            c = target
        else:
            # Only the predicted class is needed, so skip gradient tracking.
            with torch.no_grad():
                out = model(x=x, edge_index=edge_index)
            c = out.argmax(dim=1).item()

        score_map = self.get_score_map(
            model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
        )

        extra_score_map = None
        # `torch.cat` rejects an empty list, so require at least one earlier
        # layer before building the extra map.
        if self.depth == "all" and len(embedding) > 1:
            extra_score_map = torch.cat(
                [
                    self.get_score_map(
                        model=model, x=x, edge_index=edge_index, emb=emb, c=c
                    )
                    for emb in embedding[:-1]
                ],
                dim=0,
            )

        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=None,
            node_mask=score_map,
            node_feat_mask=None,
            edge_feat_mask=None,
            extra_score_map=extra_score_map,
        )

    def get_score_map(
        self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor, c: int
    ) -> Tensor:
        """Return one score per node derived from the embedding ``emb``.

        Args:
            model: Model to query; called as ``model(x=..., edge_index=...)``.
            x: Node feature matrix.
            edge_index: Edge indices of the graph.
            emb: Two-dimensional embedding (nodes x channels) whose columns
                are used as masks.
            c: Index of the class whose softmax probability weights each
                channel.

        Returns:
            A tensor with one non-negative score per node, min-max normalised
            when ``score_map_norm`` is set and the scores are not all equal.
        """
        n_nodes, n_features = emb.size()
        score_map = torch.zeros(n_nodes).to(x.device)

        # Only scalar probabilities are read back from the model, so no
        # autograd graph is needed for the repeated forward passes.
        with torch.no_grad():
            for k in range(n_features):
                feat = emb[:, k]
                lo, hi = feat.min(), feat.max()
                if lo == hi:
                    # A constant channel cannot discriminate between nodes.
                    continue

                mask = (feat - lo).div(hi - lo) if self.interest_map_norm else feat
                masked_x = x * mask.reshape((-1, 1))

                probs = F.softmax(model(x=masked_x, edge_index=edge_index), dim=1)
                weight = float(probs.squeeze()[c])
                score_map = score_map + weight * feat

        score_map = F.relu(score_map)
        if self.score_map_norm and score_map.min() != score_map.max():
            score_map = (score_map - score_map.min()).div(
                score_map.max() - score_map.min()
            )
        return score_map
if __name__ == "__main__":
    # Small demo: explain one ENZYMES graph with an untrained two-layer GCN.
    # This module uses a relative import at the top, so run the demo as a
    # module (e.g. ``python -m scgnn.scgnn``) rather than as a plain script.
    from torch_geometric.datasets import TUDataset
    from torch_geometric.nn import GCNConv, global_mean_pool

    dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")

    # SCGNN calls the model as ``model(x=..., edge_index=...)``, so the
    # network has to accept exactly these two arguments.
    model = Sequential(
        "x, edge_index",
        [
            (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
            (GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
            # One graph is explained at a time: pool over all of its nodes.
            (lambda x: global_mean_pool(x, None), "x -> x"),
        ],
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    data = dataset[0].to(device)
    model.eval()

    # Normalisation options belong to the constructor; ``forward`` ignores
    # extra keyword arguments.
    explainer = SCGNN(interest_map_norm=True, score_map_norm=True)
    explained = explainer.forward(
        model,
        data.x,
        data.edge_index,
        target=2,
    )

Binary file not shown.

54
scgnn/utils/embedding.py Normal file
View File

@ -0,0 +1,54 @@
import warnings
from typing import Any, List
import torch
from torch import Tensor
def get_message_passing_embeddings(
    model: torch.nn.Module,
    *args,
    **kwargs,
) -> List[Tensor]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in
    :obj:`model`.

    Internally, this method registers forward hooks on all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
    and runs the forward pass of the :obj:`model` by calling
    :obj:`model(*args, **kwargs)`. The forward pass runs in evaluation mode
    without gradient tracking; the hooks are removed and the previous
    training mode is restored afterwards, even if the forward pass raises.

    Args:
        model (torch.nn.Module): The message passing model.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the model.

    Returns:
        The cloned layer outputs, in the order the hooks fired. For layers
        returning a tuple, only the first element is kept.
    """
    from torch_geometric.nn import MessagePassing

    embeddings: List[Tensor] = []

    def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
        # Clone output in case it will be later modified in-place:
        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
        assert isinstance(outputs, Tensor)
        embeddings.append(outputs.clone())

    # Register forward hooks:
    hook_handles = [
        module.register_forward_hook(hook)
        for module in model.modules()
        if isinstance(module, MessagePassing)
    ]

    if not hook_handles:
        warnings.warn("The 'model' does not have any 'MessagePassing' layers")

    training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        # Always undo the temporary changes, so a failing forward pass does
        # not leave stray hooks or a model stuck in eval mode.
        model.train(training)
        for handle in hook_handles:  # Remove hooks:
            handle.remove()

    return embeddings

9
setup.py Normal file
View File

@ -0,0 +1,9 @@
from setuptools import setup
# Packaging metadata for the ScoreCAM GNN explainer.
setup(
    name="scgnn",
    version="0.1",
    description="Official implementation of ScoreCAM GNN for explaining graph neural networks",
    # `scgnn/scgnn.py` imports `scgnn/utils/embedding.py`, so the `utils`
    # sub-package has to be shipped as well.
    packages=["scgnn", "scgnn.utils"],
    zip_safe=False,
)