scorecamgnn/scgnn/scgnn.py

167 lines
5.2 KiB
Python

import collections
import logging
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, ModuleList, Softmax
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.nn import MessagePassing, Sequential
from torch_geometric.utils import index_to_mask, k_hop_subgraph, subgraph

from .utils.embedding import get_message_passing_embeddings
class Hook:
    """Record the input and output of a module's most recent forward pass.

    A forward hook is registered on ``module`` at construction time. After
    each forward call of that module, ``self.input`` holds the positional
    inputs it received and ``self.output`` holds what it returned. Call
    :meth:`close` to unregister the hook.

    Args:
        module: Object exposing ``register_forward_hook`` (e.g. a
            ``torch.nn.Module``).
    """

    def __init__(self, module):
        # Define the attributes up front so that reading them before the
        # first forward pass yields None instead of raising AttributeError.
        self.input = None
        self.output = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember the latest input and output."""
        self.input = input
        self.output = output

    def close(self):
        """Remove the registered hook; stored values are left untouched."""
        self.hook.remove()
class SCGNN(ExplainerAlgorithm):
    r"""
    The official implementation of ScoreCAM with GNN flavour [...]

    Every channel of a message-passing embedding is used as a node-wise mask
    on the input features; the model's softmax probability for the target
    class on the masked input weights that channel, and the weighted channels
    are summed into a per-node score map.

    Args:
        depth: ``"last"`` scores only the last embedding returned by
            ``get_message_passing_embeddings``; ``"all"`` additionally
            scores every earlier embedding and returns those maps
            concatenated as ``extra_score_map``.
        interest_map_norm: If ``True``, min-max normalise each embedding
            channel before using it as a mask.
        score_map_norm: If ``True``, min-max normalise the final score map
            (skipped when the map is constant).
        target_baseline: ``"inference"`` explains the class predicted by the
            model (``argmax`` of its output); ``None`` explains the
            ``target`` passed to :meth:`forward`.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        depth: str = "last",
        interest_map_norm: bool = True,
        score_map_norm: bool = True,
        target_baseline: Optional[str] = "inference",
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.interest_map_norm = interest_map_norm
        self.score_map_norm = score_map_norm
        self.target_baseline = target_baseline
        self.name = "SCGNN"

    def supports(self) -> bool:
        """Return whether the configured task level and mask types are handled.

        Only graph-level tasks are supported. The edge mask type must be
        ``MaskType.object`` or ``None``; the node mask type must be one of
        ``common_attributes``, ``object`` or ``attributes``. Each rejection
        is logged as an error.
        """
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
        ]:
            # node_mask_type may be None here, and None has no `.value`;
            # fall back to the object itself so the rejection is logged
            # instead of crashing with AttributeError.
            shown = getattr(node_mask_type, "value", node_mask_type)
            logging.error(f"Node mask type '{shown}' not " f"supported.")
            return False

        return True

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target,
        **kwargs,
    ) -> Explanation:
        """Compute the ScoreCAM node mask for one graph.

        Args:
            model: Model to explain; called as
                ``model(x=..., edge_index=...)``.
            x: Node feature matrix.
            edge_index: Graph connectivity.
            target: Stored as ``y`` in the explanation and, when
                ``target_baseline`` is ``None``, used as the class to explain.
            **kwargs: Ignored; not forwarded to the model.

        Returns:
            An ``Explanation`` whose ``node_mask`` is the score map of the
            last embedding. ``extra_score_map`` holds the concatenated score
            maps of the earlier embeddings when ``depth == "all"`` (``None``
            if there are no earlier embeddings or ``depth == "last"``).

        Raises:
            ValueError: If ``depth`` or ``target_baseline`` has an
                unsupported value.
        """
        # Validate the configuration before doing any (expensive) model call.
        if self.depth not in ("last", "all"):
            raise ValueError(f"Depth={self.depth} not implemented yet")
        if self.target_baseline is not None and self.target_baseline != "inference":
            raise ValueError(
                f"target_baseline={self.target_baseline!r} not supported; "
                "expected None or 'inference'"
            )

        embedding = get_message_passing_embeddings(
            model=model, x=x, edge_index=edge_index
        )

        if self.target_baseline is None:
            c = target
        else:
            # Explain the model's own prediction. Only the argmax index is
            # needed, so no autograd graph has to be recorded.
            with torch.no_grad():
                out = model(x=x, edge_index=edge_index)
            c = out.argmax(dim=1).item()

        score_map = self.get_score_map(
            model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
        )

        extra_score_map = None
        if self.depth == "all":
            earlier_maps = [
                self.get_score_map(
                    model=model, x=x, edge_index=edge_index, emb=emb, c=c
                )
                for emb in embedding[:-1]
            ]
            # torch.cat rejects an empty list (single-layer model), so keep
            # None in that case.
            if earlier_maps:
                extra_score_map = torch.cat(earlier_maps, dim=0)

        return Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=None,
            node_mask=score_map,
            node_feat_mask=None,
            edge_feat_mask=None,
            extra_score_map=extra_score_map,
        )

    def get_score_map(
        self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor, c: int
    ) -> Tensor:
        """Build the per-node score map for a single embedding.

        For every non-constant channel of ``emb`` the channel (min-max
        normalised when ``interest_map_norm`` is set) multiplies ``x``
        row-wise, the model is run on the masked features, and the softmax
        probability of class ``c`` weights the raw channel in the running
        sum. The sum is passed through ReLU and optionally min-max
        normalised.

        Args:
            model: Model to query; called as ``model(x=..., edge_index=...)``.
            x: Node feature matrix.
            edge_index: Graph connectivity.
            emb: Two-dimensional embedding, unpacked as
                ``(n_nodes, n_features)``.
            c: Index of the class whose probability weights each channel.

        Returns:
            A one-dimensional tensor with one score per node.
        """
        n_nodes, n_features = emb.size()
        score_map = torch.zeros(n_nodes, device=x.device)

        for k in range(n_features):
            feat = emb[:, k]
            feat_min, feat_max = feat.min(), feat.max()
            # A constant channel carries no spatial information (and would
            # divide by zero when normalised), so it is skipped.
            if feat_min == feat_max:
                continue

            if self.interest_map_norm:
                mask = (feat - feat_min).div(feat_max - feat_min)
            else:
                mask = feat

            # Only a Python float is taken from this pass, so gradient
            # tracking is unnecessary.
            with torch.no_grad():
                masked_out = model(x=x * mask.reshape((-1, 1)), edge_index=edge_index)
                probs = F.softmax(masked_out, dim=1).squeeze()
            val = float(probs[c])

            score_map = score_map + val * feat

        score_map = F.relu(score_map)
        if self.score_map_norm and score_map.min() != score_map.max():
            score_map = (score_map - score_map.min()).div(
                score_map.max() - score_map.min()
            )
        return score_map