From 3109a36e6e27aa4eb3c1b0ad6a581ab51af8fea5 Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 6 Mar 2023 11:46:59 +0100 Subject: [PATCH] Update README.md --- README.md | 33 ++++++++++++++ eixgnn/eixgnn.py | 115 +---------------------------------------------- 2 files changed, 34 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index e69de29..d84818b 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,33 @@ +Here is the an example code for using EiXGNN from the `"EiX-GNN: Concept-level eigencentrality explainer for graph neural + networks"`_ paper + +```python +from torch_geometric.datasets import TUDataset + +dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") +data = dataset[0] + +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, global_mean_pool + +class GCN(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(dataset.num_node_features, 20) + self.conv2 = GCNConv(20, dataset.num_classes) + + def forward(self, data): + x, edge_index, batch = data.x, data.edge_index, data.batch + x = self.conv1(x, edge_index) + x = F.relu(x) + x = global_mean_pool(x, batch) + x = F.softmax(x, dim=1) + return x + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = GCN().to(device) +model.eval() +data = dataset[0].to(device) +explainer = EiXGNN() +explained = explainer.forward(model, data.x, data.edge_index) +``` diff --git a/eixgnn/eixgnn.py b/eixgnn/eixgnn.py index f1065fd..4a860b7 100644 --- a/eixgnn/eixgnn.py +++ b/eixgnn/eixgnn.py @@ -11,7 +11,6 @@ from torch.nn import KLDivLoss, Softmax from torch_geometric.data import Batch, Data from torch_geometric.explain import Explanation from torch_geometric.explain.algorithm.base import ExplainerAlgorithm -# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks from torch_geometric.explain.config import (ExplainerConfig, MaskType, ModelConfig, ModelMode, ModelTaskLevel) @@ -129,7 +128,7 @@ class EiXGNN(ExplainerAlgorithm): f"This signal similarity metric : {signal_similarity['metric']} is not supported yet. No explanation provided." ) return False - # ADD OTHER CASE + # TODO ADD OTHER CASE return True def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor: @@ -216,16 +215,6 @@ class EiXGNN(ExplainerAlgorithm): extended_map[i, indexes[i][j]] += abs(shap_vals[i][j]) return extended_map - # def _shapley_value_extended(self, concepts, shap_val, size=None): - # extended_shap_val_matrix = np.zeros(size) - # nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts] - # for i in range(extended_shap_val_matrix.shape[0]): - # concept_shap_val = shap_val[i] - # node_list = nodes_lists[i] - # for j, val in zip(node_list, concept_shap_val): - # extended_shap_val_matrix[i, j] = val - # return extended_shap_val_matrix - def _global_concept_similarity_matrix(self, concepts): A = np.zeros((len(concepts), len(concepts))) concepts_pred = self.model(Batch().from_data_list(concepts)) @@ -249,80 +238,6 @@ class EiXGNN(ExplainerAlgorithm): pr = list(pagerank(G).values()) return pr - # def _shapley_value_concept( - # self, - # concept: Data, - # ) -> np.ndarray: - - # g_plus_list = [] - # g_minus_list = [] - # val_shap = np.zeros(concept.x.shape[0]) - - # for node_ind in range(concept.x.shape[0]): - # for _ in range(self.shap_val_approx): - # perm = torch.randperm(concept.x.shape[0]) - - # x_perm = torch.clone(concept.x) - # x_perm = x_perm[perm] - - # x_minus = torch.clone(concept.x) - # x_plus = torch.clone(concept.x) - - # x_plus = torch.cat( - # (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0 - # ) - # if node_ind == 0: - # x_minus = torch.clone(x_perm) - # else: - # x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0) - - # g_plus = concept.__copy__() - # g_plus.x = x_plus - # g_minus = concept.__copy__() - # g_minus.x = x_minus - - # g_plus_list.append(g_plus) - # g_minus_list.append(g_minus) - - # g_plus = Batch().from_data_list(g_plus_list) - # g_minus = Batch().from_data_list(g_minus_list) - # with torch.no_grad(): - # g_plus = g_plus.to(concept.x.device) - # g_minus = g_minus.to(concept.x.device) - - # out_g_plus = self.model(g_plus) - # out_g_minus = self.model(g_minus) - - # g_plus_score = F.softmax(out_g_plus, dim=1) - # g_minus_score = F.softmax(out_g_minus, dim=1) - - # g_plus_score = g_plus_score.reshape( - # (1, concept.x.shape[0], self.shap_val_approx, -1) - # ) - # g_minus_score = g_minus_score.reshape( - # (1, concept.x.shape[0], self.shap_val_approx, -1) - # ) - - # for node_ind in range(concept.x.shape[0]): - # score = torch.mean( - # torch.FloatTensor( - # [ - # torch.norm( - # input=g_plus_score[0, node_ind, i] - # - g_minus_score[0, node_ind, i], - # p=1, - # ) - # for i in range(self.shap_val_approx) - # ] - # ) - # ) - # val_shap[node_ind] = score.item() - # return val_shap - - # def _subgraph_node_mapping(self, concept): - # nodes_list = torch.unique(concept.edge_index) - # return nodes_list - def _compute_node_ablation_prior( self, x: torch.Tensor, @@ -353,32 +268,4 @@ class EiXGNN(ExplainerAlgorithm): if __name__ == "__main__": - from torch_geometric.datasets import TUDataset - dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") - data = dataset[0] - - import torch.nn.functional as F - from torch_geometric.nn import GCNConv, global_mean_pool - - class GCN(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = GCNConv(dataset.num_node_features, 20) - self.conv2 = GCNConv(20, dataset.num_classes) - - def forward(self, data): - x, edge_index, batch = data.x, data.edge_index, data.batch - x = self.conv1(x, edge_index) - x = F.relu(x) - x = global_mean_pool(x, batch) - x = F.softmax(x, dim=1) - return x - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = GCN().to(device) - model.eval() - data = dataset[0].to(device) - # optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) - # explainer = EiXGNN() - # explained = explainer.forward(model, data.x, data.edge_index)