Update README.md

araison 2023-03-06 11:46:59 +01:00
parent 5c6cc20870
commit 3109a36e6e
2 changed files with 34 additions and 114 deletions

View File

@@ -0,0 +1,33 @@
Here is an example of using `EiXGNN` from the paper ["EiX-GNN: Concept-level eigencentrality explainer for graph neural networks"](https://arxiv.org/abs/2206.03491):
```python
import torch
import torch.nn.functional as F

from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool

# The import path is assumed here; adjust it to where EiXGNN lives in your setup.
from eixgnn.eixgnn import EiXGNN

dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 20)
        self.conv2 = GCNConv(20, dataset.num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)  # was defined but never applied in forward
        x = global_mean_pool(x, batch)
        return F.softmax(x, dim=1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
model.eval()

data = dataset[0].to(device)

explainer = EiXGNN()
explained = explainer.forward(model, data.x, data.edge_index)
```
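If the call succeeds, `explained` should be a `torch_geometric.explain.Explanation` object (the module imports that class). As a minimal sketch, assuming the explainer populates a node-level mask (the exact fields depend on what EiXGNN stores), you can inspect the result like this:

```python
# Hypothetical inspection; attribute names depend on what EiXGNN populates.
print(explained)  # summarizes the attributes stored on the Explanation
if hasattr(explained, "node_mask"):
    print(explained.node_mask)  # node-level importance scores, if provided
```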

View File

@@ -11,7 +11,6 @@ from torch.nn import KLDivLoss, Softmax
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
ModelConfig, ModelMode,
ModelTaskLevel)
@@ -129,7 +128,7 @@ class EiXGNN(ExplainerAlgorithm):
f"This signal similarity metric : {signal_similarity['metric']} is not supported yet. No explanation provided."
)
return False
# ADD OTHER CASE
# TODO ADD OTHER CASE
return True
def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor:
@@ -216,16 +215,6 @@ class EiXGNN(ExplainerAlgorithm):
extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
return extended_map
# def _shapley_value_extended(self, concepts, shap_val, size=None):
# extended_shap_val_matrix = np.zeros(size)
# nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts]
# for i in range(extended_shap_val_matrix.shape[0]):
# concept_shap_val = shap_val[i]
# node_list = nodes_lists[i]
# for j, val in zip(node_list, concept_shap_val):
# extended_shap_val_matrix[i, j] = val
# return extended_shap_val_matrix
def _global_concept_similarity_matrix(self, concepts):
A = np.zeros((len(concepts), len(concepts)))
concepts_pred = self.model(Batch().from_data_list(concepts))
@@ -249,80 +238,6 @@ class EiXGNN(ExplainerAlgorithm):
pr = list(pagerank(G).values())
return pr
# def _shapley_value_concept(
# self,
# concept: Data,
# ) -> np.ndarray:
# g_plus_list = []
# g_minus_list = []
# val_shap = np.zeros(concept.x.shape[0])
# for node_ind in range(concept.x.shape[0]):
# for _ in range(self.shap_val_approx):
# perm = torch.randperm(concept.x.shape[0])
# x_perm = torch.clone(concept.x)
# x_perm = x_perm[perm]
# x_minus = torch.clone(concept.x)
# x_plus = torch.clone(concept.x)
# x_plus = torch.cat(
# (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0
# )
# if node_ind == 0:
# x_minus = torch.clone(x_perm)
# else:
# x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0)
# g_plus = concept.__copy__()
# g_plus.x = x_plus
# g_minus = concept.__copy__()
# g_minus.x = x_minus
# g_plus_list.append(g_plus)
# g_minus_list.append(g_minus)
# g_plus = Batch().from_data_list(g_plus_list)
# g_minus = Batch().from_data_list(g_minus_list)
# with torch.no_grad():
# g_plus = g_plus.to(concept.x.device)
# g_minus = g_minus.to(concept.x.device)
# out_g_plus = self.model(g_plus)
# out_g_minus = self.model(g_minus)
# g_plus_score = F.softmax(out_g_plus, dim=1)
# g_minus_score = F.softmax(out_g_minus, dim=1)
# g_plus_score = g_plus_score.reshape(
# (1, concept.x.shape[0], self.shap_val_approx, -1)
# )
# g_minus_score = g_minus_score.reshape(
# (1, concept.x.shape[0], self.shap_val_approx, -1)
# )
# for node_ind in range(concept.x.shape[0]):
# score = torch.mean(
# torch.FloatTensor(
# [
# torch.norm(
# input=g_plus_score[0, node_ind, i]
# - g_minus_score[0, node_ind, i],
# p=1,
# )
# for i in range(self.shap_val_approx)
# ]
# )
# )
# val_shap[node_ind] = score.item()
# return val_shap
# def _subgraph_node_mapping(self, concept):
# nodes_list = torch.unique(concept.edge_index)
# return nodes_list
def _compute_node_ablation_prior(
self,
x: torch.Tensor,
@@ -353,32 +268,4 @@ class EiXGNN(ExplainerAlgorithm):
if __name__ == "__main__":
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
data = dataset[0]
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(dataset.num_node_features, 20)
self.conv2 = GCNConv(20, dataset.num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = self.conv1(x, edge_index)
x = F.relu(x)
x = global_mean_pool(x, batch)
x = F.softmax(x, dim=1)
return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN().to(device)
model.eval()
data = dataset[0].to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# explainer = EiXGNN()
# explained = explainer.forward(model, data.x, data.edge_index)