import logging
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
from torch_geometric.explain.config import MaskType, ModelTaskLevel

from .utils.embedding import get_message_passing_embeddings

class Hook:
    """Forward hook that records the wrapped module's latest input and output."""

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Stash the most recent activations flowing through the module.
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()

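
# A minimal, hedged usage sketch for ``Hook`` (the attributes come from the
# class above; the model/layer names are illustrative only):
#
#   hook = Hook(model.conv2)                 # attach to any nn.Module
#   _ = model(x=x, edge_index=edge_index)    # run a forward pass
#   emb = hook.output                        # activations captured by hook_fn
#   hook.close()                             # detach when done
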
class SCGNN(ExplainerAlgorithm):
    r"""The official implementation of ScoreCAM with a GNN flavour [...]

    Each channel of a message-passing embedding is (optionally min-max
    normalized and) used to mask the input features; the model's softmax
    probability for the target class then weights that channel's
    contribution to the final node importance map.

    Args:
        depth (str, optional): Which embeddings to score, either the
            :obj:`"last"` message-passing layer or :obj:`"all"` layers.
            (default: :obj:`"last"`)
        interest_map_norm (bool, optional): If :obj:`True`, min-max
            normalizes each embedding channel before masking.
            (default: :obj:`True`)
        score_map_norm (bool, optional): If :obj:`True`, min-max normalizes
            the final score map. (default: :obj:`True`)
        target_baseline (str, optional): If :obj:`"inference"`, explains the
            model's own prediction; if :obj:`None`, explains the provided
            target. (default: :obj:`"inference"`)
    """

    def __init__(
        self,
        depth: str = "last",
        interest_map_norm: bool = True,
        score_map_norm: bool = True,
        target_baseline: Optional[str] = "inference",
        **kwargs,
    ):
        super().__init__()
        self.depth = depth
        self.interest_map_norm = interest_map_norm
        self.score_map_norm = score_map_norm
        self.target_baseline = target_baseline
        self.name = "SCGNN"

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
        ]:
            logging.error(f"Node mask type '{node_mask_type.value}' not supported.")
            return False

        return True

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target: Tensor,
        **kwargs,
    ) -> Explanation:
        # Collect the intermediate embeddings of every message-passing layer.
        embedding = get_message_passing_embeddings(
            model=model, x=x, edge_index=edge_index
        )

        out = model(x=x, edge_index=edge_index)

        # Pick the class to explain: the provided target, or the model's own
        # prediction when `target_baseline="inference"`.
        if self.target_baseline is None:
            c = target
        elif self.target_baseline == "inference":
            c = out.argmax(dim=1).item()
        else:
            raise ValueError(
                f"target_baseline={self.target_baseline!r} not supported"
            )

        if self.depth == "last":
            score_map = self.get_score_map(
                model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
            )
            extra_score_map = None
        elif self.depth == "all":
            score_map = self.get_score_map(
                model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c
            )
            extra_score_map = torch.cat(
                [
                    self.get_score_map(
                        model=model, x=x, edge_index=edge_index, emb=emb, c=c
                    )
                    for emb in embedding[:-1]
                ],
                dim=0,
            )
        else:
            raise ValueError(f"Depth={self.depth} not implemented yet")

        node_mask = score_map
        edge_mask = None
        node_feat_mask = None
        edge_feat_mask = None

        exp = Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            edge_mask=edge_mask,
            node_mask=node_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
            extra_score_map=extra_score_map,
        )
        return exp
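
    # How the score map is computed (a sketch of the loop below):
    #
    #   score_map = ReLU( sum_k softmax(model(x * norm(A[:, k])))[c] * A[:, k] )
    #
    # where A = emb is the [n_nodes, n_features] embedding of the chosen
    # layer, norm is the optional min-max normalization, and c is the class
    # being explained.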
    def get_score_map(
        self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor,
        c: int,
    ) -> Tensor:
        interest_map = emb.clone()
        n_nodes, n_features = interest_map.size()
        score_map = torch.zeros(n_nodes, device=x.device)
        for k in range(n_features):
            _x = x.clone()
            feat = interest_map[:, k]
            # Constant channels carry no information; skip them.
            if feat.min() == feat.max():
                continue
            mask = feat.clone()
            if self.interest_map_norm:
                # Min-max normalize the channel to [0, 1].
                mask = (mask - mask.min()).div(mask.max() - mask.min())
            mask = mask.reshape((-1, 1))
            # Mask the input features by the channel's activation pattern.
            _x = _x * mask
            _out = model(x=_x, edge_index=edge_index)
            _out = F.softmax(_out, dim=1)
            _out = _out.squeeze()
            # Weight the channel by the class probability it preserves.
            val = float(_out[c])
            score_map = score_map + val * feat

        score_map = F.relu(score_map)
        if self.score_map_norm and score_map.min() != score_map.max():
            score_map = (score_map - score_map.min()).div(
                score_map.max() - score_map.min()
            )
        return score_map
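

# A hedged, end-to-end usage sketch (not part of the original file): wiring
# SCGNN into torch_geometric.explain.Explainer. ``TinyGNN`` is a hypothetical
# demo model; any graph classifier callable as model(x=x, edge_index=edge_index)
# returning class logits should work. Note that the relative import at the top
# of this file means the module must be run as part of its package
# (e.g. via ``python -m``).
if __name__ == "__main__":
    from torch_geometric.explain import Explainer
    from torch_geometric.nn import GCNConv, global_mean_pool

    class TinyGNN(torch.nn.Module):  # hypothetical demo model
        def __init__(self, in_dim: int = 3, hidden: int = 16, n_classes: int = 2):
            super().__init__()
            self.conv1 = GCNConv(in_dim, hidden)
            self.conv2 = GCNConv(hidden, hidden)
            self.lin = torch.nn.Linear(hidden, n_classes)

        def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
            h = self.conv1(x, edge_index).relu()
            h = self.conv2(h, edge_index).relu()
            # Single-graph pooling: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            return self.lin(global_mean_pool(h, batch))

    x = torch.randn(6, 3)
    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])

    explainer = Explainer(
        model=TinyGNN(),
        algorithm=SCGNN(depth="last"),
        explanation_type="model",
        node_mask_type="object",
        model_config=dict(
            mode="multiclass_classification",
            task_level="graph",
            return_type="raw",
        ),
    )
    explanation = explainer(x=x, edge_index=edge_index)
    print(explanation.node_mask)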