Upload to github
Commit b05e44ba79
Here is an example of using the ScoreCAM GNN explainer from the [ScoreCAM GNN: a generalization of an optimal local post-hoc explaining method to any geometric deep learning models](https://arxiv.org/abs/2207.12748) paper:

```python
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, Sequential, global_mean_pool

from scgnn.scgnn import SCGNN

dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")

# SCGNN invokes the model as model(x=x, edge_index=edge_index) internally,
# so the model is defined with an "x, edge_index" signature; pooling over a
# single (unbatched) graph uses batch=None:
model = Sequential(
    "x, edge_index",
    [
        (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
        (GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
        (lambda x: global_mean_pool(x, None), "x -> x"),
    ],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = dataset[0].to(device)

model.eval()
out = model(x=data.x, edge_index=data.edge_index)

# The normalization flags are constructor arguments of SCGNN:
explainer = SCGNN(interest_map_norm=True, score_map_norm=True)
explained = explainer.forward(
    model,
    data.x,
    data.edge_index,
    target=2,
)
```
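The returned `Explanation` holds the score map as `node_mask`, a 1-D tensor with one importance value per node. A minimal sketch of inspecting it, reusing the `explained` object from above:

```python
# Rank nodes by their ScoreCAM importance for the explained class
# (higher score = more relevant):
scores = explained.node_mask
top = scores.argsort(descending=True)[:5]
print("Top-5 most important nodes:", top.tolist())
print("Their scores:", scores[top].tolist())
```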
The explainer is implemented in `scgnn/scgnn.py`:
```python
import logging

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelTaskLevel
from torch_geometric.nn import Sequential

from .utils.embedding import get_message_passing_embeddings


class Hook:
    """Registers a forward hook on a module and stores its last input/output."""

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()


class SCGNN(ExplainerAlgorithm):
    r"""The official implementation of ScoreCAM with a GNN flavour, from the
    `"ScoreCAM GNN: a generalization of an optimal local post-hoc explaining
    method to any geometric deep learning models"
    <https://arxiv.org/abs/2207.12748>`_ paper.

    Args:
        depth (str, optional): Which message-passing layers contribute to the
            score map: :obj:`"last"` uses only the last layer, :obj:`"all"`
            additionally returns score maps for every earlier layer.
            (default: :obj:`"last"`)
        interest_map_norm (bool, optional): If :obj:`True`, min-max normalizes
            each feature's interest map before masking the input.
            (default: :obj:`True`)
        score_map_norm (bool, optional): If :obj:`True`, min-max normalizes
            the final score map. (default: :obj:`True`)
        target_baseline (str, optional): If :obj:`"inference"`, explains the
            class predicted by the model; if :obj:`None`, explains the given
            :obj:`target`. (default: :obj:`"inference"`)
    """

    def __init__(
        self,
        depth: str = "last",
        interest_map_norm: bool = True,
        score_map_norm: bool = True,
        target_baseline="inference",
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.interest_map_norm = interest_map_norm
        self.score_map_norm = score_map_norm
        self.target_baseline = target_baseline
        self.name = "SCGNN"

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
        ]:
            logging.error(f"Node mask type '{node_mask_type.value}' not supported.")
            return False

        return True

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target,
        **kwargs,
    ) -> Explanation:
        # Collect the node embeddings produced by every MessagePassing layer:
        embedding = get_message_passing_embeddings(
            model=model, x=x, edge_index=edge_index
        )

        out = model(x=x, edge_index=edge_index)

        # Select the class to explain: the given target by default, or the
        # model's own prediction when target_baseline == "inference":
        c = target
        if self.target_baseline == "inference":
            c = out.argmax(dim=1).item()

        if self.depth == "last":
            score_map = self.get_score_map(
                model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
            )
            extra_score_map = None
        elif self.depth == "all":
            score_map = self.get_score_map(
                model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
            )
            extra_score_map = torch.cat(
                [
                    self.get_score_map(
                        model=model, x=x, edge_index=edge_index, emb=emb, c=c
                    )
                    for emb in embedding[:-1]
                ],
                dim=0,
            )
        else:
            raise ValueError(f"Depth={self.depth} not implemented yet")

        node_mask = score_map
        edge_mask = None
        node_feat_mask = None
        edge_feat_mask = None

        exp = Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
            extra_score_map=extra_score_map,
        )
        return exp

    def get_score_map(
        self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor, c: int
    ) -> Tensor:
        interest_map = emb.clone()
        n_nodes, n_features = interest_map.size()
        score_map = torch.zeros(n_nodes).to(x.device)
        for k in range(n_features):
            _x = x.clone()
            feat = interest_map[:, k]
            # Skip constant feature channels: they carry no spatial signal.
            if feat.min() == feat.max():
                continue
            mask = feat.clone()
            if self.interest_map_norm:
                mask = (mask - mask.min()).div(mask.max() - mask.min())
            mask = mask.reshape((-1, 1))
            # Mask the input node features with this channel's interest map
            # and re-run the model to measure the effect on class c:
            _x = _x * mask
            _out = model(x=_x, edge_index=edge_index)
            _out = F.softmax(_out, dim=1)
            _out = _out.squeeze()
            val = float(_out[c])
            # Weight the interest map by the resulting class score:
            score_map = score_map + val * feat

        score_map = F.relu(score_map)

        if self.score_map_norm and score_map.min() != score_map.max():
            score_map = (score_map - score_map.min()).div(
                score_map.max() - score_map.min()
            )
        return score_map


if __name__ == "__main__":
    from torch_geometric.datasets import TUDataset
    from torch_geometric.nn import GCNConv, global_mean_pool

    dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")

    # SCGNN invokes the model as model(x=x, edge_index=edge_index), so the
    # model is defined with an "x, edge_index" signature; pooling over a
    # single (unbatched) graph uses batch=None:
    model = Sequential(
        "x, edge_index",
        [
            (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
            (GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
            (lambda x: global_mean_pool(x, None), "x -> x"),
        ],
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    data = dataset[0].to(device)

    model.eval()
    out = model(x=data.x, edge_index=data.edge_index)

    explainer = SCGNN(interest_map_norm=True, score_map_norm=True)
    explained = explainer.forward(
        model,
        data.x,
        data.edge_index,
        target=2,
    )
```
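With `depth="all"`, `forward` also concatenates the score maps of every earlier message-passing layer into an `extra_score_map` attribute on the returned `Explanation`. A short sketch, reusing the `model` and `data` from the `__main__` block above:

```python
# Explain using score maps from all message-passing layers, not just the last:
explainer_all = SCGNN(depth="all")
explanation = explainer_all.forward(model, data.x, data.edge_index, target=2)

# node_mask comes from the last layer; extra_score_map concatenates the maps
# of all earlier layers (one GCNConv here, hence num_nodes extra values):
print(explanation.node_mask.shape)
print(explanation.extra_score_map.shape)
```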
The embedding helper used above is implemented in `scgnn/utils/embedding.py`:
```python
import warnings
from typing import Any, List

import torch
from torch import Tensor


def get_message_passing_embeddings(
    model: torch.nn.Module,
    *args,
    **kwargs,
) -> List[Tensor]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in :obj:`model`.

    Internally, this method registers forward hooks on all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers of a
    :obj:`model`, and runs the forward pass of the :obj:`model` by calling
    :obj:`model(*args, **kwargs)`.

    Args:
        model (torch.nn.Module): The message passing model.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the model.
    """
    from torch_geometric.nn import MessagePassing

    embeddings: List[Tensor] = []

    def hook(model: torch.nn.Module, inputs: Any, outputs: Any):
        # Clone output in case it will be later modified in-place:
        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
        assert isinstance(outputs, Tensor)
        embeddings.append(outputs.clone())

    hook_handles = []
    for module in model.modules():  # Register forward hooks:
        if isinstance(module, MessagePassing):
            hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn("The 'model' does not have any 'MessagePassing' layers")

    training = model.training
    model.eval()
    with torch.no_grad():
        model(*args, **kwargs)
    model.train(training)

    for handle in hook_handles:  # Remove hooks:
        handle.remove()

    return embeddings
```
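For reference, a minimal standalone sketch of this helper on a toy two-layer GCN (the graph and layer sizes below are made up for illustration):

```python
import torch
from torch_geometric.nn import GCNConv, Sequential

gcn = Sequential("x, edge_index", [
    (GCNConv(3, 16), "x, edge_index -> x"),
    (GCNConv(16, 4), "x, edge_index -> x"),
])
x = torch.randn(10, 3)                       # 10 nodes, 3 features each
edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2

# One embedding per MessagePassing layer, in forward order:
embs = get_message_passing_embeddings(gcn, x=x, edge_index=edge_index)
print([e.shape for e in embs])  # [torch.Size([10, 16]), torch.Size([10, 4])]
```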