commit 5c6cc208700b52263b907c7a8d71ebe63a30da8f
Author: araison
Date:   Mon Mar 6 11:35:30 2023 +0100

    Github release

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/eixgnn/__init__.py b/eixgnn/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/eixgnn/eixgnn.py b/eixgnn/eixgnn.py
new file mode 100644
index 0000000..f1065fd
--- /dev/null
+++ b/eixgnn/eixgnn.py
@@ -0,0 +1,384 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy
+import torch
+import torch.nn.functional as F
+from networkx import from_numpy_array, pagerank
+from torch import Tensor
+from torch.nn import KLDivLoss, Softmax
+from torch_geometric.data import Batch, Data
+from torch_geometric.explain import Explanation
+from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
+# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+from torch_geometric.explain.config import (ExplainerConfig, MaskType,
+                                            ModelConfig, ModelMode,
+                                            ModelTaskLevel)
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import index_to_mask, k_hop_subgraph, subgraph
+
+from eixgnn.shapley import mc_shapley
+
+
+class EiXGNN(ExplainerAlgorithm):
+    r"""
+    The official EiX-GNN model from the `"EiX-GNN: Concept-level eigencentrality
+    explainer for graph neural networks"`_ paper for identifying useful patterns
+    of a GNN model, adapted to the user's background knowledge.
+
+    The following configurations are currently supported:
+
+    - :class:`torch_geometric.explain.config.ModelConfig`
+
+        - :attr:`task_level`: :obj:`"graph"`
+
+    - :class:`torch_geometric.explain.config.ExplainerConfig`
+
+        - :attr:`node_mask_type`: :obj:`"object"`, :obj:`"common_attributes"` or :obj:`"attributes"`
+        - :attr:`edge_mask_type`: :obj:`"object"` or :obj:`None`
+
+
+    Args:
+        L (int): The number of concepts; must be a positive integer.
+            (default: :obj:`30`)
+        p (float): The parameter in [0, 1] representing the concept assimilability constraint.
+            (default: :obj:`0.2`)
+        importance_sampling_strategy (str): The node ablation strategy used to
+            build the concept sampling prior, one of :obj:`"node"`,
+            :obj:`"neighborhood"` or :obj:`"no_prior"`. (default: :obj:`"node"`)
+        domain_similarity (str): The domain similarity metric; only
+            :obj:`"relative_edge_density"` is supported.
+            (default: :obj:`"relative_edge_density"`)
+        signal_similarity (str): The signal similarity metric, either
+            :obj:`"KL"` or :obj:`"KL_sym"`. (default: :obj:`"KL"`)
+        shap_val_approx (int): The number of Monte Carlo samples used to
+            approximate Shapley values. (default: :obj:`100`)
+        **kwargs (optional): Additional hyperparameters of the algorithm.
+    """
+
+    def __init__(
+        self,
+        L: int = 30,
+        p: float = 0.2,
+        importance_sampling_strategy: str = "node",
+        domain_similarity: str = "relative_edge_density",
+        signal_similarity: str = "KL",
+        shap_val_approx: int = 100,
+        **kwargs,
+    ):
+        super().__init__()
+        self.L = L
+        self.p = p
+        self.domain_similarity = domain_similarity
+        self.signal_similarity = signal_similarity
+        self.shap_val_approx = shap_val_approx
+        self.importance_sampling_strategy = importance_sampling_strategy
+        self.name = "EIXGNN"
+
+    def _domain_similarity(self, graph: Data) -> float:
+        if self.domain_similarity == "relative_edge_density":
+            if graph.num_edges != 0:
+                return graph.num_edges / (graph.num_nodes * (graph.num_nodes - 1))
+            else:
+                return 1 / (graph.num_nodes * (graph.num_nodes - 1))
+        else:
+            raise ValueError(f"{self.domain_similarity} is not supported yet")
+
+    def _signal_similarity(self, graph1: Tensor, graph2: Tensor) -> float:
+        if self.signal_similarity == "KL":
+            kldiv = KLDivLoss(reduction="batchmean")
+            graph_1 = F.log_softmax(graph1, dim=1)
+            graph_2 = F.softmax(graph2, dim=1)
+            return kldiv(graph_1, graph_2).item()
+        elif self.signal_similarity == "KL_sym":
+            kldiv = KLDivLoss(reduction="batchmean")
+            graph_11 = F.log_softmax(graph1, dim=1)
+            graph_12 = F.log_softmax(graph2, dim=1)
+            graph_21 = F.softmax(graph1, dim=1)
+            graph_22 = F.softmax(graph2, dim=1)
+            return (kldiv(graph_11, graph_22) + kldiv(graph_12, graph_21)).item()
+        else:
+            raise ValueError(f"{self.signal_similarity} is not supported yet")
+
+    def supports(self) -> bool:
+        task_level = self.model_config.task_level
+        if task_level not in [ModelTaskLevel.graph]:
+            logging.error(f"Task level '{task_level.value}' not supported")
+            return False
+
+        edge_mask_type = self.explainer_config.edge_mask_type
+        if edge_mask_type not in [MaskType.object, None]:
+            logging.error(f"Edge mask type '{edge_mask_type.value}' not supported")
+            return False
+
+        node_mask_type = self.explainer_config.node_mask_type
+        if node_mask_type not in [
+            MaskType.common_attributes,
+            MaskType.object,
+            MaskType.attributes,
+        ]:
+            logging.error(f"Node mask type '{node_mask_type.value}' not supported.")
+            return False
+
+        if self.importance_sampling_strategy not in [
+            "node",
+            "neighborhood",
+            "no_prior",
+        ]:
+            logging.error(
+                f"The node ablation strategy '{self.importance_sampling_strategy}' is not supported yet. No explanation provided."
+            )
+            return False
+        if self.domain_similarity not in ["relative_edge_density"]:
+            logging.error(
+                f"The domain similarity metric '{self.domain_similarity}' is not supported yet. No explanation provided."
+            )
+            return False
+        if self.signal_similarity not in ["KL", "KL_sym"]:
+            logging.error(
+                f"The signal similarity metric '{self.signal_similarity}' is not supported yet. No explanation provided."
+            )
+            return False
+        return True
+
+    def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor:
+        shap = mc_shapley(
+            coalition=subset_node,
+            data=data,
+            value_func=self.model,
+            sample_num=self.shap_val_approx,
+        )
+        return shap
+
+    def get_mc_shapley_concept(self, concept: Data) -> Tensor:
+        shap_val = []
+        for ind in range(concept.num_nodes):
+            coalition = torch.LongTensor([ind]).to(concept.x.device)
+            shap_val.append(self.get_mc_shapley(subset_node=coalition, data=concept))
+        return torch.FloatTensor(shap_val)
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Tensor,
+        edge_index: Tensor,
+        target: Tensor,
+        **kwargs,
+    ) -> Explanation:
+        if int(x.shape[0] * self.p) <= 1:
+            raise ValueError(
+                f"The provided graph with {x.shape[0]} nodes and parameter p={self.p} produces concepts of size {int(x.shape[0] * self.p)}, which is not suitable. Aborting."
+            )
+
+        self.model = model
+        input_graph = Data(x=x, edge_index=edge_index).to(x.device)
+
+        node_prior_distribution = self._compute_node_ablation_prior(x, edge_index)
+        node_prior_distribution = F.softmax(node_prior_distribution, dim=0)
+        node_prior_distribution = node_prior_distribution.detach().cpu().numpy()
+
+        concept_nodes_index = np.random.choice(
+            np.arange(x.shape[0]),
+            size=(self.L, int(self.p * x.shape[0])),
+            p=node_prior_distribution,
+        )
+        indexes = [
+            torch.LongTensor(concept_nodes).to(x.device)
+            for concept_nodes in concept_nodes_index
+        ]
+        concepts = [input_graph.subgraph(ind) for ind in indexes]
+
+        A = self._global_concept_similarity_matrix(concepts)
+        pr = self._adjacency_pr(A)
+        shap_val = [self.get_mc_shapley_concept(concept) for concept in concepts]
+        shap_val_ext = self.extend(
+            shap_val, indexes=concept_nodes_index, size=(self.L, x.shape[0])
+        )
+
+        explanation_map = torch.sum(
+            torch.FloatTensor(np.diag(pr) @ shap_val_ext), dim=0
+        ).to(x.device)
+
+        edge_mask = None
+        node_feat_mask = None
+        edge_feat_mask = None
+
+        exp = Explanation(
+            x=x,
+            edge_index=edge_index,
+            y=target,
+            node_mask=explanation_map,
+            edge_mask=edge_mask,
+            node_feat_mask=node_feat_mask,
+            edge_feat_mask=edge_feat_mask,
+            shap=torch.FloatTensor(shap_val_ext).to(x.device),
+            indexes=torch.LongTensor(concept_nodes_index).to(x.device),
+            pr=torch.FloatTensor(pr).to(x.device),
+        )
+
+        return exp
+
+    def extend(self, shap_vals: list, indexes: np.ndarray, size: tuple):
+        extended_map = np.zeros(size)
+        for i in range(indexes.shape[0]):
+            for j in range(indexes.shape[1]):
+                extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
+        return extended_map
+
+    # def _shapley_value_extended(self, concepts, shap_val, size=None):
+    #     extended_shap_val_matrix = np.zeros(size)
+    #     nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts]
+    #     for i in range(extended_shap_val_matrix.shape[0]):
+    #         concept_shap_val = shap_val[i]
+    #         node_list = nodes_lists[i]
+    #         for j, val in zip(node_list, concept_shap_val):
+    #             extended_shap_val_matrix[i, j] = val
+    #     return extended_shap_val_matrix
+
+    def _global_concept_similarity_matrix(self, concepts):
+        A = np.zeros((len(concepts), len(concepts)))
+        concepts_pred = self.model(Batch.from_data_list(concepts))
+        concepts_prob = F.softmax(concepts_pred, dim=1)
+        for i, c1 in enumerate(concepts):
+            for j, c2 in enumerate(concepts):
+                if j >= i:
+                    continue
+                dm1 = self._domain_similarity(c1)
+                dm2 = self._domain_similarity(c2)
+                ss = self._signal_similarity(
+                    concepts_prob[i].unsqueeze(0), concepts_prob[j].unsqueeze(0)
+                )
+                A[i, j] = (dm1 / dm2) * ss
+                A[j, i] = (dm2 / dm1) * ss
+        return A
+
+    def _adjacency_pr(self, A):
+        G = from_numpy_array(A)
+        pr = list(pagerank(G).values())
+        return pr
+
+    # def _shapley_value_concept(
+    #     self,
+    #     concept: Data,
+    # ) -> np.ndarray:
+
+    #     g_plus_list = []
+    #     g_minus_list = []
+    #     val_shap = np.zeros(concept.x.shape[0])
+
+    #     for node_ind in range(concept.x.shape[0]):
+    #         for _ in range(self.shap_val_approx):
+    #             perm = torch.randperm(concept.x.shape[0])
+
+    #             x_perm = torch.clone(concept.x)
+    #             x_perm = x_perm[perm]
+
+    #             x_minus = torch.clone(concept.x)
+    #             x_plus = torch.clone(concept.x)
+
+    #             x_plus = torch.cat(
+    #                 (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0
+    #             )
+    #             if node_ind == 0:
+    #                 x_minus = torch.clone(x_perm)
+    #             else:
+    #                 x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0)
+
+    #             g_plus = concept.__copy__()
+    #             g_plus.x = x_plus
+    #             g_minus = concept.__copy__()
+    #             g_minus.x = x_minus
+
+    #             g_plus_list.append(g_plus)
+    #             g_minus_list.append(g_minus)
+
+    #     g_plus = Batch().from_data_list(g_plus_list)
+    #     g_minus = Batch().from_data_list(g_minus_list)
+    #     with torch.no_grad():
+    #         g_plus = g_plus.to(concept.x.device)
+    #         g_minus = g_minus.to(concept.x.device)
+
+    #         out_g_plus = self.model(g_plus)
+    #         out_g_minus = self.model(g_minus)
+
+    #         g_plus_score = F.softmax(out_g_plus, dim=1)
+    #         g_minus_score = F.softmax(out_g_minus, dim=1)
+
+    #     g_plus_score = g_plus_score.reshape(
+    #         (1, concept.x.shape[0], self.shap_val_approx, -1)
+    #     )
+    #     g_minus_score = g_minus_score.reshape(
+    #         (1, concept.x.shape[0], self.shap_val_approx, -1)
+    #     )
+
+    #     for node_ind in range(concept.x.shape[0]):
+    #         score = torch.mean(
+    #             torch.FloatTensor(
+    #                 [
+    #                     torch.norm(
+    #                         input=g_plus_score[0, node_ind, i]
+    #                         - g_minus_score[0, node_ind, i],
+    #                         p=1,
+    #                     )
+    #                     for i in range(self.shap_val_approx)
+    #                 ]
+    #             )
+    #         )
+    #         val_shap[node_ind] = score.item()
+    #     return val_shap
+
+    # def _subgraph_node_mapping(self, concept):
+    #     nodes_list = torch.unique(concept.edge_index)
+    #     return nodes_list
+
+    def _compute_node_ablation_prior(
+        self,
+        x: torch.Tensor,
+        edge_index: torch.Tensor,
+    ) -> torch.Tensor:
+        if self.importance_sampling_strategy == "no_prior":
+            node_importance = torch.ones(x.shape[0]) / x.shape[0]
+            return node_importance.to(x.device)
+
+        pred = self.model(x=x, edge_index=edge_index)
+        node_importance = torch.zeros(x.shape[0])
+        for node_index in range(x.shape[0]):
+            if self.importance_sampling_strategy == "node":
+                mask = index_to_mask(torch.LongTensor([node_index]), size=x.shape[0])
+            elif self.importance_sampling_strategy == "neighborhood":
+                neighborhood_index, _, _, _ = k_hop_subgraph(
+                    node_idx=node_index, num_hops=1, edge_index=edge_index
+                )
+                mask = index_to_mask(neighborhood_index, size=x.shape[0])
+            mask = mask <= 0
+            node_mask = torch.arange(x.shape[0]).to(x.device)
+            node_mask = node_mask[mask]
+            edge_index_sub, _ = subgraph(node_mask, edge_index)
+            sub_pred = self.model(x=x, edge_index=edge_index_sub)
+            node_importance[node_index] = torch.norm(pred - sub_pred, p=1)
+        return node_importance.to(x.device)
+
+
+if __name__ == "__main__":
+    from torch_geometric.datasets import TUDataset
+
+    dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
+
+    import torch.nn.functional as F
+    from torch_geometric.nn import GCNConv, global_mean_pool
+
+    class GCN(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = GCNConv(dataset.num_node_features, 20)
+            self.conv2 = GCNConv(20, dataset.num_classes)
+
+        def forward(self, data):
+            x, edge_index, batch = data.x, data.edge_index, data.batch
+            x = self.conv1(x, edge_index)
+            x = F.relu(x)
+            x = self.conv2(x, edge_index)
+            x = global_mean_pool(x, batch)
+            x = F.softmax(x, dim=1)
+            return x
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = GCN().to(device)
+    model.eval()
+    data = dataset[0].to(device)
+    # optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+    # explainer = EiXGNN()
+    # explained = explainer.forward(model, data.x, data.edge_index, target=data.y)
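+    # Illustrative sketch (untested): wiring EiXGNN into PyG's generic
+    # Explainer interface, assuming torch_geometric >= 2.2. Note that EiXGNN
+    # calls the model both as `model(x=..., edge_index=...)` and on batched
+    # `Data` objects, so the wrapped model must support both calling
+    # conventions, unlike the toy GCN above which takes a single `data`
+    # argument.
+    #
+    # from torch_geometric.explain import Explainer
+    #
+    # explainer = Explainer(
+    #     model=model,
+    #     algorithm=EiXGNN(L=30, p=0.2),
+    #     explanation_type="model",
+    #     node_mask_type="object",
+    #     model_config=dict(
+    #         mode="multiclass_classification",
+    #         task_level="graph",
+    #         return_type="probs",
+    #     ),
+    # )
+    # explanation = explainer(data.x, data.edge_index)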
diff --git a/eixgnn/shapley.py b/eixgnn/shapley.py
new file mode 100644
index 0000000..65dec1f
--- /dev/null
+++ b/eixgnn/shapley.py
@@ -0,0 +1,345 @@
+import copy
+from itertools import combinations
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.special import comb
+from torch_geometric.data import Batch, Data, Dataset
+from torch_geometric.loader import DataLoader
+from torch_geometric.utils import to_networkx
+
+
+def GnnNetsGC2valueFunc(gnnNets, target_class):
+    def value_func(batch):
+        with torch.no_grad():
+            logits = gnnNets(data=batch)
+            probs = F.softmax(logits, dim=-1)
+            score = probs[:, target_class]
+        return score
+
+    return value_func
+
+
+def GnnNetsNC2valueFunc(gnnNets_NC, node_idx, target_class):
+    def value_func(data):
+        with torch.no_grad():
+            logits = gnnNets_NC(data=data)
+            probs = F.softmax(logits, dim=-1)
+            # select the corresponding node prob through the node idx on all the sampled graphs
+            batch_size = data.batch.max() + 1
+            probs = probs.reshape(batch_size, -1, probs.shape[-1])
+            score = probs[:, node_idx, target_class]
+        return score
+
+    return value_func
+
+
+def get_graph_build_func(build_method):
+    if build_method.lower() == "zero_filling":
+        return graph_build_zero_filling
+    elif build_method.lower() == "split":
+        return graph_build_split
+    else:
+        raise NotImplementedError
+
+
+class MarginalSubgraphDataset(Dataset):
+    """Dataset yielding (excluded, included) subgraph pairs for marginal-contribution estimation."""
+
+    def __init__(self, data, exclude_mask, include_mask, subgraph_build_func):
+        self.num_nodes = data.num_nodes
+        self.X = data.x
+        self.edge_index = data.edge_index
+        self.device = self.X.device
+
+        self.label = data.y
+        self.exclude_mask = (
+            torch.tensor(exclude_mask).type(torch.float32).to(self.device)
+        )
+        self.include_mask = (
+            torch.tensor(include_mask).type(torch.float32).to(self.device)
+        )
+        self.subgraph_build_func = subgraph_build_func
+
+    def __len__(self):
+        return self.exclude_mask.shape[0]
+
+    def __getitem__(self, idx):
+        exclude_graph_X, exclude_graph_edge_index = self.subgraph_build_func(
+            self.X, self.edge_index, self.exclude_mask[idx]
+        )
+        include_graph_X, include_graph_edge_index = self.subgraph_build_func(
+            self.X, self.edge_index, self.include_mask[idx]
+        )
+        exclude_data = Data(x=exclude_graph_X, edge_index=exclude_graph_edge_index)
+        include_data = Data(x=include_graph_X, edge_index=include_graph_edge_index)
+        return exclude_data, include_data
+
+
+def marginal_contribution(
+    data: Data,
+    exclude_mask: np.ndarray,
+    include_mask: np.ndarray,
+    value_func: Callable,
+    subgraph_build_func,
+):
+    """Calculate the marginal value for each pair.
+    Here exclude_mask and include_mask are node masks."""
+    marginal_subgraph_dataset = MarginalSubgraphDataset(
+        data, exclude_mask, include_mask, subgraph_build_func
+    )
+    dataloader = DataLoader(
+        marginal_subgraph_dataset, batch_size=256, shuffle=False, num_workers=0
+    )
+
+    marginal_contribution_list = []
+
+    for exclude_data, include_data in dataloader:
+        exclude_values = value_func(exclude_data)
+        include_values = value_func(include_data)
+        margin_values = include_values - exclude_values
+        marginal_contribution_list.append(margin_values)
+
+    marginal_contributions = torch.cat(marginal_contribution_list, dim=0)
+    return marginal_contributions
+
+
+def graph_build_zero_filling(X, edge_index, node_mask: torch.Tensor):
+    """Subgraph building through masking the unselected nodes with zero features."""
+    ret_X = X * node_mask.unsqueeze(1)
+    return ret_X, edge_index
+
+
+def graph_build_split(X, edge_index, node_mask: torch.Tensor):
+    """Subgraph building through splitting the selected nodes from the original graph."""
+    ret_X = X
+    row, col = edge_index
+    edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+    ret_edge_index = edge_index[:, edge_mask]
+    return ret_X, ret_edge_index
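+# Illustrative sketch of the two builders above on a toy 3-node path graph
+# 0-1-2 with node mask [1, 0, 1] (node 1 masked out):
+#
+#   X = torch.eye(3)
+#   edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
+#   mask = torch.tensor([1.0, 0.0, 1.0])
+#   graph_build_zero_filling(X, edge_index, mask)  # keeps all edges, zeroes
+#                                                  # out node 1's features
+#   graph_build_split(X, edge_index, mask)         # keeps all features, drops
+#                                                  # every edge touching node 1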
+def l_shapley(
+    coalition: list,
+    data: Data,
+    local_radius: int,
+    value_func: Callable,
+    subgraph_building_method="zero_filling",
+):
+    """Shapley value where players are local neighbor nodes."""
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neighborhood = []
+        for node in local_region:
+            k_neighborhood += list(graph.neighbors(node))
+        local_region += k_neighborhood
+        local_region = list(set(local_region))
+
+    set_exclude_masks = []
+    set_include_masks = []
+    nodes_around = [node for node in local_region if node not in coalition]
+    num_nodes_around = len(nodes_around)
+
+    for subset_len in range(0, num_nodes_around + 1):
+        node_exclude_subsets = combinations(nodes_around, subset_len)
+        for node_exclude_subset in node_exclude_subsets:
+            set_exclude_mask = np.ones(num_nodes)
+            set_exclude_mask[local_region] = 0.0
+            if node_exclude_subset:
+                set_exclude_mask[list(node_exclude_subset)] = 1.0
+            set_include_mask = set_exclude_mask.copy()
+            set_include_mask[coalition] = 1.0
+
+            set_exclude_masks.append(set_exclude_mask)
+            set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    num_players = len(nodes_around) + 1
+    num_player_in_set = (
+        num_players - 1 + len(coalition) - (1 - exclude_mask).sum(axis=1)
+    )
+    p = num_players
+    S = num_player_in_set
+    # Shapley weight 1 / (C(p, S) * (p - S)) = S! * (p - S - 1)! / p!,
+    # with 1e-6 guarding against division by zero.
+    coeffs = torch.tensor(1.0 / comb(p, S) / (p - S + 1e-6))
+
+    marginal_contributions = marginal_contribution(
+        data, exclude_mask, include_mask, value_func, subgraph_build_func
+    )
+
+    l_shapley_value = (marginal_contributions.squeeze().cpu() * coeffs).sum().item()
+    return l_shapley_value
+
+
+def mc_shapley(
+    coalition: list,
+    data: Data,
+    value_func: Callable,
+    subgraph_building_method="zero_filling",
+    sample_num=1000,
+) -> float:
+    """Monte Carlo sampling approximation of the Shapley value."""
+    subset_build_func = get_graph_build_func(subgraph_building_method)
+
+    num_nodes = data.num_nodes
+    node_indices = np.arange(num_nodes)
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in node_indices if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.zeros(num_nodes)
+        set_exclude_mask[selected_nodes] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = marginal_contribution(
+        data, exclude_mask, include_mask, value_func, subset_build_func
+    )
+    mc_shapley_value = marginal_contributions.mean().item()
+
+    return mc_shapley_value
+
+
+def mc_l_shapley(
+    coalition: list,
+    data: Data,
+    local_radius: int,
+    value_func: Callable,
+    subgraph_building_method="zero_filling",
+    sample_num=1000,
+) -> float:
+    """Monte Carlo sampling approximation of the l_shapley value."""
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neighborhood = []
+        for node in local_region:
+            k_neighborhood += list(graph.neighbors(node))
+        local_region += k_neighborhood
+        local_region = list(set(local_region))
+
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in local_region if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.ones(num_nodes)
+        set_exclude_mask[local_region] = 0.0
+        set_exclude_mask[selected_nodes] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = marginal_contribution(
+        data, exclude_mask, include_mask, value_func, subgraph_build_func
+    )
+
+    mc_l_shapley_value = marginal_contributions.mean().item()
+    return mc_l_shapley_value
+
+
+def gnn_score(
+    coalition: list,
+    data: Data,
+    value_func: Callable,
+    subgraph_building_method="zero_filling",
+) -> float:
+    """The value of the subgraph induced by the selected nodes."""
+    num_nodes = data.num_nodes
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+    mask = torch.zeros(num_nodes).type(torch.float32).to(data.x.device)
+    mask[coalition] = 1.0
+    ret_x, ret_edge_index = subgraph_build_func(data.x, data.edge_index, mask)
+    mask_data = Data(x=ret_x, edge_index=ret_edge_index)
+    mask_data = Batch.from_data_list([mask_data])
+    score = value_func(mask_data)
+    # get the score of the predicted class for the graph or a specific node idx
+    return score.item()
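+# Illustrative sketch (untested): estimating the Shapley value of a coalition
+# with mc_shapley. Here `model` stands for a hypothetical graph classifier and
+# GnnNetsGC2valueFunc turns it into the value function scoring class 0:
+#
+#   value_func = GnnNetsGC2valueFunc(model, target_class=0)
+#   shap = mc_shapley(
+#       coalition=[0, 1],
+#       data=data,
+#       value_func=value_func,
+#       subgraph_building_method="zero_filling",
+#       sample_num=100,
+#   )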
+def NC_mc_l_shapley(
+    coalition: list,
+    data: Data,
+    local_radius: int,
+    value_func: Callable,
+    node_idx: int = -1,
+    subgraph_building_method="zero_filling",
+    sample_num=1000,
+) -> float:
+    """Monte Carlo approximation of l_shapley where the target node is kept in
+    both subgraphs."""
+    graph = to_networkx(data)
+    num_nodes = graph.number_of_nodes()
+    subgraph_build_func = get_graph_build_func(subgraph_building_method)
+
+    local_region = copy.copy(coalition)
+    for k in range(local_radius - 1):
+        k_neighborhood = []
+        for node in local_region:
+            k_neighborhood += list(graph.neighbors(node))
+        local_region += k_neighborhood
+        local_region = list(set(local_region))
+
+    coalition_placeholder = num_nodes
+    set_exclude_masks = []
+    set_include_masks = []
+    for example_idx in range(sample_num):
+        subset_nodes_from = [node for node in local_region if node not in coalition]
+        random_nodes_permutation = np.array(subset_nodes_from + [coalition_placeholder])
+        random_nodes_permutation = np.random.permutation(random_nodes_permutation)
+        split_idx = np.where(random_nodes_permutation == coalition_placeholder)[0][0]
+        selected_nodes = random_nodes_permutation[:split_idx]
+        set_exclude_mask = np.ones(num_nodes)
+        set_exclude_mask[local_region] = 0.0
+        set_exclude_mask[selected_nodes] = 1.0
+        if node_idx != -1:
+            set_exclude_mask[node_idx] = 1.0
+        set_include_mask = set_exclude_mask.copy()
+        set_include_mask[coalition] = 1.0  # include the node_idx
+
+        set_exclude_masks.append(set_exclude_mask)
+        set_include_masks.append(set_include_mask)
+
+    exclude_mask = np.stack(set_exclude_masks, axis=0)
+    include_mask = np.stack(set_include_masks, axis=0)
+    marginal_contributions = marginal_contribution(
+        data, exclude_mask, include_mask, value_func, subgraph_build_func
+    )
+
+    mc_l_shapley_value = marginal_contributions.mean().item()
+    return mc_l_shapley_value
+
+
+def sparsity(coalition: list, data: Data, subgraph_building_method="zero_filling"):
+    if subgraph_building_method == "zero_filling":
+        return 1.0 - len(coalition) / data.num_nodes
+
+    elif subgraph_building_method == "split":
+        row, col = data.edge_index
+        node_mask = torch.zeros(data.x.shape[0])
+        node_mask[coalition] = 1.0
+        edge_mask = (node_mask[row] == 1) & (node_mask[col] == 1)
+        return 1.0 - (edge_mask.sum() / edge_mask.shape[0]).item()
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..0f8ed2a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(
+    name="eixgnn",
+    version="0.1",
+    description="Official implementation of the EiXGNN algorithm for explaining graph neural networks",
+    packages=["eixgnn"],
+    zip_safe=False,
+)