GitHub release

commit 5c6cc20870

@@ -0,0 +1,384 @@
import logging
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import torch.nn.functional as F
from networkx import from_numpy_array, pagerank
from torch import Tensor
from torch.nn import KLDivLoss, Softmax
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm

# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
                                            ModelConfig, ModelMode,
                                            ModelTaskLevel)
from torch_geometric.loader import DataLoader
from torch_geometric.utils import index_to_mask, k_hop_subgraph, subgraph

from eixgnn.shapley import mc_shapley


class EiXGNN(ExplainerAlgorithm):
    r"""The official EiX-GNN model from the `"EiX-GNN: Concept-level
    Eigencentrality Explainer for Graph Neural Networks"
    <https://arxiv.org/abs/2206.03491>`_ paper, which identifies the patterns
    a GNN model finds useful, adapted to the user's background knowledge.

    The following configurations are currently supported:

    - :class:`torch_geometric.explain.config.ModelConfig`

        - :attr:`task_level`: :obj:`"graph"`

    - :class:`torch_geometric.explain.config.ExplainerConfig`

        - :attr:`node_mask_type`: :obj:`"object"`, :obj:`"common_attributes"`
          or :obj:`"attributes"`
        - :attr:`edge_mask_type`: :obj:`"object"` or :obj:`None`

    Args:
        L (int): The number of concepts; must be a positive integer.
            (default: :obj:`30`)
        p (float): The concept assimilability constraint, a value in
            :math:`[0, 1]`. (default: :obj:`0.2`)
        **kwargs (optional): Additional settings, such as the importance
            sampling strategy, the domain and signal similarity metrics, or
            the number of Shapley value samples.
    """

    def __init__(
        self,
        L: int = 30,
        p: float = 0.2,
        importance_sampling_strategy: str = "node",
        domain_similarity: str = "relative_edge_density",
        signal_similarity: str = "KL",
        shap_val_approx: int = 100,
        **kwargs,
    ):
        super().__init__()
        self.L = L
        self.p = p
        self.domain_similarity = domain_similarity
        self.signal_similarity = signal_similarity
        self.shap_val_approx = shap_val_approx
        self.importance_sampling_strategy = importance_sampling_strategy
        self.name = "EIXGNN"

    def _domain_similarity(self, graph: Data) -> float:
        if self.domain_similarity == "relative_edge_density":
            if graph.num_edges != 0:
                return graph.num_edges / (graph.num_nodes * (graph.num_nodes - 1))
            else:
                return 1 / (graph.num_nodes * (graph.num_nodes - 1))
        else:
            raise ValueError(f"{self.domain_similarity} is not supported yet")

    def _signal_similarity(self, graph1: Tensor, graph2: Tensor) -> float:
        if self.signal_similarity == "KL":
            kldiv = KLDivLoss(reduction="batchmean")
            graph_1 = F.log_softmax(graph1, dim=1)
            graph_2 = F.softmax(graph2, dim=1)
            return kldiv(graph_1, graph_2).item()
        elif self.signal_similarity == "KL_sym":
            kldiv = KLDivLoss(reduction="batchmean")
            graph_11 = F.log_softmax(graph1, dim=1)
            graph_12 = F.log_softmax(graph2, dim=1)
            graph_21 = F.softmax(graph1, dim=1)
            graph_22 = F.softmax(graph2, dim=1)
            return (kldiv(graph_11, graph_22) + kldiv(graph_12, graph_21)).item()
        else:
            raise ValueError(f"{self.signal_similarity} is not supported yet")

    def supports(self) -> bool:
        task_level = self.model_config.task_level
        if task_level not in [ModelTaskLevel.graph]:
            logging.error(f"Task level '{task_level.value}' not supported")
            return False

        edge_mask_type = self.explainer_config.edge_mask_type
        if edge_mask_type not in [MaskType.object, None]:
            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type not in [
            MaskType.common_attributes,
            MaskType.object,
            MaskType.attributes,
        ]:
            logging.error(f"Node mask type '{node_mask_type.value}' not supported.")
            return False

        if self.importance_sampling_strategy not in [
            "node",
            "neighborhood",
            "no_prior",
        ]:
            logging.error(
                f"The node ablation strategy '{self.importance_sampling_strategy}' is not supported yet. No explanation provided."
            )
            return False
        if self.domain_similarity not in ["relative_edge_density"]:
            logging.error(
                f"The domain similarity metric '{self.domain_similarity}' is not supported yet. No explanation provided."
            )
            return False
        if self.signal_similarity not in ["KL", "KL_sym"]:
            logging.error(
                f"The signal similarity metric '{self.signal_similarity}' is not supported yet. No explanation provided."
            )
            return False
        # TODO: add other cases.
        return True

    def get_mc_shapley(self, subset_node: Tensor, data: Data) -> float:
        shap = mc_shapley(
            coalition=subset_node,
            data=data,
            value_func=self.model,
            sample_num=self.shap_val_approx,
        )
        return shap

    def get_mc_shapley_concept(self, concept: Data) -> Tensor:
        shap_val = []
        for ind in range(concept.num_nodes):
            coalition = torch.LongTensor([ind]).to(concept.x.device)
            shap_val.append(self.get_mc_shapley(subset_node=coalition, data=concept))
        return torch.FloatTensor(shap_val)

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        target,
        **kwargs,
    ) -> Explanation:
        if int(x.shape[0] * self.p) <= 1:
            raise ValueError(
                f"The provided graph with {x.shape[0]} nodes and parameter "
                f"p={self.p} produces concepts of size "
                f"{int(x.shape[0] * self.p)}, which is not suitable. Aborting."
            )

        self.model = model
        input_graph = Data(x=x, edge_index=edge_index).to(x.device)

        node_prior_distribution = self._compute_node_ablation_prior(x, edge_index)
        node_prior_distribution = F.softmax(node_prior_distribution, dim=0)
        node_prior_distribution = node_prior_distribution.detach().cpu().numpy()

        concept_nodes_index = np.random.choice(
            np.arange(x.shape[0]),
            size=(self.L, int(self.p * x.shape[0])),
            p=node_prior_distribution,
        )
        indexes = [
            torch.LongTensor(concept_nodes).to(x.device)
            for concept_nodes in concept_nodes_index
        ]
        concepts = [input_graph.subgraph(ind) for ind in indexes]

        A = self._global_concept_similarity_matrix(concepts)
        pr = self._adjacency_pr(A)
        shap_val = [self.get_mc_shapley_concept(concept) for concept in concepts]
        shap_val_ext = self.extend(
            shap_val, indexes=concept_nodes_index, size=(self.L, x.shape[0])
        )

        explanation_map = torch.sum(
            torch.FloatTensor(np.diag(pr) @ shap_val_ext), dim=0
        ).to(x.device)

        edge_mask = None
        node_feat_mask = None
        edge_feat_mask = None

        exp = Explanation(
            x=x,
            edge_index=edge_index,
            y=target,
            node_mask=explanation_map,
            edge_mask=edge_mask,
            node_feat_mask=node_feat_mask,
            edge_feat_mask=edge_feat_mask,
            shap=torch.FloatTensor(shap_val_ext).to(x.device),
            indexes=torch.LongTensor(concept_nodes_index).to(x.device),
            pr=torch.FloatTensor(pr).to(x.device),
        )

        return exp

    def extend(self, shap_vals: list, indexes: np.ndarray, size: tuple) -> np.ndarray:
        extended_map = np.zeros(size)
        for i in range(indexes.shape[0]):
            for j in range(indexes.shape[1]):
                extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
        return extended_map

    # def _shapley_value_extended(self, concepts, shap_val, size=None):
    #     extended_shap_val_matrix = np.zeros(size)
    #     nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts]
    #     for i in range(extended_shap_val_matrix.shape[0]):
    #         concept_shap_val = shap_val[i]
    #         node_list = nodes_lists[i]
    #         for j, val in zip(node_list, concept_shap_val):
    #             extended_shap_val_matrix[i, j] = val
    #     return extended_shap_val_matrix

    def _global_concept_similarity_matrix(self, concepts):
        A = np.zeros((len(concepts), len(concepts)))
        concepts_pred = self.model(Batch.from_data_list(concepts))
        concepts_prob = F.softmax(concepts_pred, dim=1)
        for i, c1 in enumerate(concepts):
            for j, c2 in enumerate(concepts):
                if j >= i:
                    continue
                dm1 = self._domain_similarity(c1)
                dm2 = self._domain_similarity(c2)
                ss = self._signal_similarity(
                    concepts_prob[i].unsqueeze(0), concepts_prob[j].unsqueeze(0)
                )
                A[i, j] = (dm1 / dm2) * ss
                A[j, i] = (dm2 / dm1) * ss
        return A

    def _adjacency_pr(self, A):
        G = from_numpy_array(A)
        pr = list(pagerank(G).values())
        return pr
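
    # The similarity matrix A is treated as a weighted graph whose PageRank
    # scores serve as the (eigen)centrality of each concept; `forward` uses
    # these scores to weight the per-concept Shapley values.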

    # def _shapley_value_concept(
    #     self,
    #     concept: Data,
    # ) -> np.ndarray:

    #     g_plus_list = []
    #     g_minus_list = []
    #     val_shap = np.zeros(concept.x.shape[0])

    #     for node_ind in range(concept.x.shape[0]):
    #         for _ in range(self.shap_val_approx):
    #             perm = torch.randperm(concept.x.shape[0])

    #             x_perm = torch.clone(concept.x)
    #             x_perm = x_perm[perm]

    #             x_minus = torch.clone(concept.x)
    #             x_plus = torch.clone(concept.x)

    #             x_plus = torch.cat(
    #                 (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0
    #             )
    #             if node_ind == 0:
    #                 x_minus = torch.clone(x_perm)
    #             else:
    #                 x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0)

    #             g_plus = concept.__copy__()
    #             g_plus.x = x_plus
    #             g_minus = concept.__copy__()
    #             g_minus.x = x_minus

    #             g_plus_list.append(g_plus)
    #             g_minus_list.append(g_minus)

    #     g_plus = Batch.from_data_list(g_plus_list)
    #     g_minus = Batch.from_data_list(g_minus_list)
    #     with torch.no_grad():
    #         g_plus = g_plus.to(concept.x.device)
    #         g_minus = g_minus.to(concept.x.device)

    #         out_g_plus = self.model(g_plus)
    #         out_g_minus = self.model(g_minus)

    #     g_plus_score = F.softmax(out_g_plus, dim=1)
    #     g_minus_score = F.softmax(out_g_minus, dim=1)

    #     g_plus_score = g_plus_score.reshape(
    #         (1, concept.x.shape[0], self.shap_val_approx, -1)
    #     )
    #     g_minus_score = g_minus_score.reshape(
    #         (1, concept.x.shape[0], self.shap_val_approx, -1)
    #     )

    #     for node_ind in range(concept.x.shape[0]):
    #         score = torch.mean(
    #             torch.FloatTensor(
    #                 [
    #                     torch.norm(
    #                         input=g_plus_score[0, node_ind, i]
    #                         - g_minus_score[0, node_ind, i],
    #                         p=1,
    #                     )
    #                     for i in range(self.shap_val_approx)
    #                 ]
    #             )
    #         )
    #         val_shap[node_ind] = score.item()
    #     return val_shap

    # def _subgraph_node_mapping(self, concept):
    #     nodes_list = torch.unique(concept.edge_index)
    #     return nodes_list

    def _compute_node_ablation_prior(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        if self.importance_sampling_strategy == "no_prior":
            node_importance = torch.ones(x.shape[0]) / x.shape[0]
            return node_importance.to(x.device)

        pred = self.model(x=x, edge_index=edge_index)
        node_importance = torch.zeros(x.shape[0])
        for node_index in range(x.shape[0]):
            # Select the nodes to ablate: either the node itself or its
            # 1-hop neighborhood.
            if self.importance_sampling_strategy == "node":
                mask = index_to_mask(torch.LongTensor([node_index]), size=x.shape[0])
            if self.importance_sampling_strategy == "neighborhood":
                neighborhood_index, _, _, _ = k_hop_subgraph(
                    node_idx=node_index, num_hops=1, edge_index=edge_index
                )
                mask = index_to_mask(neighborhood_index, size=x.shape[0])
            # Keep only the edges among the remaining (non-ablated) nodes and
            # score the node by the L1 shift it induces in the prediction.
            mask = mask <= 0
            node_mask = torch.arange(x.shape[0]).to(x.device)
            node_mask = node_mask[mask]
            edge_index_sub, _ = subgraph(node_mask, edge_index)
            sub_pred = self.model(x=x, edge_index=edge_index_sub)
            node_importance[node_index] = torch.norm(pred - sub_pred, p=1)
        return node_importance.to(x.device)


if __name__ == "__main__":
    from torch_geometric.datasets import TUDataset

    dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
    data = dataset[0]

    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv, global_mean_pool

    class GCN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_node_features, 20)
            self.conv2 = GCNConv(20, dataset.num_classes)

        def forward(self, data):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = global_mean_pool(x, batch)
            x = F.softmax(x, dim=1)
            return x

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GCN().to(device)
    model.eval()
    data = dataset[0].to(device)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # explainer = EiXGNN()
    # explained = explainer.forward(model, data.x, data.edge_index)
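
    # A minimal sketch of how the explainer could be run end-to-end. This
    # assumes a model whose forward pass accepts `x` and `edge_index` (the toy
    # GCN above takes a whole `data` object, so it would need adapting) and a
    # target prediction to explain:
    #
    #     with torch.no_grad():
    #         target = model(data).argmax(dim=-1)
    #     explainer = EiXGNN(L=30, p=0.2, shap_val_approx=100)
    #     explanation = explainer.forward(model, data.x, data.edge_index, target)
    #     print(explanation.node_mask)  # per-node importance scores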

@@ -0,0 +1,345 @@

import copy
from itertools import combinations

import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import comb
from torch_geometric.data import Batch, Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx


def GnnNetsGC2valueFunc(gnnNets, target_class):
    def value_func(batch):
        with torch.no_grad():
            logits = gnnNets(data=batch)
            probs = F.softmax(logits, dim=-1)
            score = probs[:, target_class]
        return score

    return value_func


def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class):
    def value_func(data):
        with torch.no_grad():
            logits = gnnNets_NC(data=data)
            probs = F.softmax(logits, dim=-1)
            # Select the probability of the target node index across all
            # sampled graphs in the batch.
            batch_size = data.batch.max() + 1
            probs = probs.reshape(batch_size, -1, probs.shape[-1])
            score = probs[:, node_idx, target_class]
        return score

    return value_func


def get_graph_build_func(build_method):
    if build_method.lower() == "zero_filling":
        return graph_build_zero_filling
    elif build_method.lower() == "split":
        return graph_build_split
    else:
        raise NotImplementedError
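
# Both builders returned above receive a float node mask: "zero_filling"
# zeroes the features of unselected nodes while keeping the topology intact,
# whereas "split" keeps all features but drops every edge with an unselected
# endpoint. They are defined further below.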


class MarginalSubgraphDataset(Dataset):
    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
        self.num_nodes = data.num_nodes
        self.X = data.x
        self.edge_index = data.edge_index
        self.device = self.X.device

        self.label = data.y
        self.exclude_mask = (
            torch.tensor(exclude_mask).type(torch.float32).to(self.device)
        )
        self.include_mask = (
            torch.tensor(include_mask).type(torch.float32).to(self.device)
        )
        self.subgraph_build_func = subgraph_build_func

    def __len__(self):
        return self.exclude_mask.shape[0]

    def __getitem__(self, idx):
        exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func(
            self.X, self.edge_index, self.exclude_mask[idx]
        )
        include_graph_X, include_graph_edge_index = self.subgraph_build_func(
            self.X, self.edge_index, self.include_mask[idx]
        )
        exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index)
        include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index)
        return exclude_data, include_data


def marginal_contribution(
    data: Data,
    exclude_mask: np.ndarray,
    include_mask: np.ndarray,
    value_func,
    subgraph_build_func,
):
    """Calculate the marginal value for each pair. Here, ``exclude_mask`` and
    ``include_mask`` are node masks."""
    marginal_subgraph_dataset = MarginalSubgraphDataset(
        data, exclude_mask, include_mask, subgraph_build_func
    )
    dataloader = DataLoader(
        marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0
    )

    marginal_contribution_list = []

    for exclude_data, include_data in dataloader:
        exclude_values = value_func(exclude_data)
        include_values = value_func(include_data)
        margin_values = include_values - exclude_values
        marginal_contribution_list.append(margin_values)

    marginal_contributions = torch.cat(marginal_contribution_list, dim=0)
    return marginal_contributions
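
# To illustrate the mask semantics above: with num_nodes = 4 and
# coalition = [2], an exclude mask of [1, 0, 0, 0] scores the graph restricted
# to node 0, while the matching include mask [1, 0, 1, 0] also keeps node 2;
# their difference is node 2's marginal contribution given {0}.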


def graph_build_zero_filling(X, edge_index, node_mask: torch.Tensor):
    """Subgraph building by masking the unselected nodes with zero features."""
    ret_X = X * node_mask.unsqueeze(1)
    return ret_X, edge_index


def graph_build_split(X, edge_index, node_mask: torch.Tensor):
    """Subgraph building by splitting the selected nodes from the original
    graph."""
    ret_X = X
    row, col = edge_index
    edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
    ret_edge_index = edge_index[:, edge_mask]
    return ret_X, ret_edge_index


def l_shapley(
    coalition: list,
    data: Data,
    local_radius: int,
    value_func,
    subgraph_building_method="zero_filling",
):
    """Shapley value where the players are the local neighbor nodes."""
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    # Grow the local region to the (local_radius - 1)-hop neighborhood of the
    # coalition.
    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    set_exclude_masks = []
    set_include_masks = []
    nodes_around = [node for node in local_region if node not in coalition]
    num_nodes_around = len(nodes_around)

    for subset_len in range(0, num_nodes_around + 1):
        node_exclude_subsets = combinations(nodes_around, subset_len)
        for node_exclude_subset in node_exclude_subsets:
            set_exclude_mask = np.ones(num_nodes)
            set_exclude_mask[local_region] = 0.0
            if node_exclude_subset:
                set_exclude_mask[list(node_exclude_subset)] = 1.0
            set_include_mask = set_exclude_mask.copy()
            set_include_mask[coalition] = 1.0

            set_exclude_masks.append(set_exclude_mask)
            set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    num_players = len(nodes_around) + 1
    num_player_in_set = (
        num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1)
    )
    p = num_players
    S = num_player_in_set
    coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6))

    marginal_contributions = marginal_contribution(
        data, exclude_mask, include_mask, value_func, subgraph_build_func
    )

    l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item()
    return l_shapley_value
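
# Note on the coefficients in l_shapley above: with p players and a set of
# size S, the factor 1 / (comb(p, S) * (p - S)) equals S! * (p - S - 1)! / p!,
# which is the standard Shapley weight for a player's marginal contribution to
# a coalition of size S (the 1e-6 term is only a numerical safeguard).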


def mc_shapley(
    coalition: list,
    data: Data,
    value_func,
    subgraph_building_method="zero_filling",
    sample_num=1000,
) -> float:
    """Monte Carlo sampling approximation of the Shapley value."""
    subset_build_func = get_graph_build_func(subgraph_building_method)

    num_nodes = data.num_nodes
    node_indices = np.arange(num_nodes)
    coalition_placeholder = num_nodes
    set_exclude_masks = []
    set_include_masks = []

    for _ in range(sample_num):
        subset_nodes_from = [node for node in node_indices if node not in coalition]
        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.zeros(num_nodes)
        set_exclude_mask[selected_nodes] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = marginal_contribution(
        data, exclude_mask, include_mask, value_func, subset_build_func
    )
    mc_shapley_value = marginal_contributions.mean().item()

    return mc_shapley_value
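
# A minimal usage sketch (the model and graph names here are hypothetical,
# shown for illustration only): wrap a trained graph classifier with
# GnnNetsGC2valueFunc, then estimate how much a coalition of nodes contributes
# to the probability of a target class:
#
#     value_func = GnnNetsGC2valueFunc(trained_model, target_class=1)
#     score = mc_shapley(coalition=[0, 3], data=some_graph,
#                        value_func=value_func, sample_num=500)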


def mc_l_shapley(
    coalition: list,
    data: Data,
    local_radius: int,
    value_func,
    subgraph_building_method="zero_filling",
    sample_num=1000,
) -> float:
    """Monte Carlo sampling approximation of the l_shapley value."""
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    coalition_placeholder = num_nodes
    set_exclude_masks = []
    set_include_masks = []
    for _ in range(sample_num):
        subset_nodes_from = [node for node in local_region if node not in coalition]
        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.ones(num_nodes)
        set_exclude_mask[local_region] = 0.0
        set_exclude_mask[selected_nodes] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = marginal_contribution(
        data, exclude_mask, include_mask, value_func, subgraph_build_func
    )

    mc_l_shapley_value = marginal_contributions.mean().item()
    return mc_l_shapley_value


def gnn_score(
    coalition: list,
    data: Data,
    value_func,
    subgraph_building_method="zero_filling",
) -> float:
    """The value of the subgraph induced by the selected nodes."""
    num_nodes = data.num_nodes
    subgraph_build_func = get_graph_build_func(subgraph_building_method)
    mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device)
    mask[coalition] = 1.0
    ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask)
    mask_data = Data(x=ret_x, edge_index=ret_edge_index)
    mask_data = Batch.from_data_list([mask_data])
    score = value_func(mask_data)
    # Get the score of the predicted class, for the whole graph or for a
    # specific node index.
    return score.item()


def NC_mc_l_shapley(
    coalition: list,
    data: Data,
    local_radius: int,
    value_func,
    node_idx: int = -1,
    subgraph_building_method="zero_filling",
    sample_num=1000,
) -> float:
    """Monte Carlo approximation of l_shapley where the target node is kept in
    both subgraphs."""
    graph = to_networkx(data)
    num_nodes = graph.number_of_nodes()
    subgraph_build_func = get_graph_build_func(subgraph_building_method)

    local_region = copy.copy(coalition)
    for _ in range(local_radius - 1):
        k_neighborhood = []
        for node in local_region:
            k_neighborhood += list(graph.neighbors(node))
        local_region += k_neighborhood
        local_region = list(set(local_region))

    coalition_placeholder = num_nodes
    set_exclude_masks = []
    set_include_masks = []
    for _ in range(sample_num):
        subset_nodes_from = [node for node in local_region if node not in coalition]
        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
        selected_nodes = random_nodes_permutation[:split_idx]
        set_exclude_mask = np.ones(num_nodes)
        set_exclude_mask[local_region] = 0.0
        set_exclude_mask[selected_nodes] = 1.0
        if node_idx != -1:
            set_exclude_mask[node_idx] = 1.0
        set_include_mask = set_exclude_mask.copy()
        set_include_mask[coalition] = 1.0  # include the node_idx

        set_exclude_masks.append(set_exclude_mask)
        set_include_masks.append(set_include_mask)

    exclude_mask = np.stack(set_exclude_masks, axis=0)
    include_mask = np.stack(set_include_masks, axis=0)
    marginal_contributions = marginal_contribution(
        data, exclude_mask, include_mask, value_func, subgraph_build_func
    )

    mc_l_shapley_value = marginal_contributions.mean().item()
    return mc_l_shapley_value


def sparsity(coalition: list, data: Data, subgraph_building_method="zero_filling"):
    if subgraph_building_method == "zero_filling":
        return 1.0 - len(coalition) / data.num_nodes

    elif subgraph_building_method == "split":
        row, col = data.edge_index
        node_mask = torch.zeros(data.x.shape[0])
        node_mask[coalition] = 1.0
        edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
        return 1.0 - (edge_mask.sum() / edge_mask.shape[0]).item()
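
# For example, with "zero_filling": a coalition of 5 nodes in a 20-node graph
# has sparsity 1 - 5/20 = 0.75; higher values mean sparser explanations.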