Update README.md
This commit is contained in:
parent
5c6cc20870
commit
3109a36e6e
33
README.md
33
README.md
|
@ -0,0 +1,33 @@
|
|||
Here is the an example code for using EiXGNN from the `"EiX-GNN: Concept-level eigencentrality explainer for graph neural
|
||||
networks"<https://arxiv.org/abs/2206.03491>`_ paper
|
||||
|
||||
```python
|
||||
from torch_geometric.datasets import TUDataset
|
||||
|
||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
||||
data = dataset[0]
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
|
||||
class GCN(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = GCNConv(dataset.num_node_features, 20)
|
||||
self.conv2 = GCNConv(20, dataset.num_classes)
|
||||
|
||||
def forward(self, data):
|
||||
x, edge_index, batch = data.x, data.edge_index, data.batch
|
||||
x = self.conv1(x, edge_index)
|
||||
x = F.relu(x)
|
||||
x = global_mean_pool(x, batch)
|
||||
x = F.softmax(x, dim=1)
|
||||
return x
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = GCN().to(device)
|
||||
model.eval()
|
||||
data = dataset[0].to(device)
|
||||
explainer = EiXGNN()
|
||||
explained = explainer.forward(model, data.x, data.edge_index)
|
||||
```
|
115
eixgnn/eixgnn.py
115
eixgnn/eixgnn.py
|
@ -11,7 +11,6 @@ from torch.nn import KLDivLoss, Softmax
|
|||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explanation
|
||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
|
||||
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||
ModelConfig, ModelMode,
|
||||
ModelTaskLevel)
|
||||
|
@ -129,7 +128,7 @@ class EiXGNN(ExplainerAlgorithm):
|
|||
f"This signal similarity metric : {signal_similarity['metric']} is not supported yet. No explanation provided."
|
||||
)
|
||||
return False
|
||||
# ADD OTHER CASE
|
||||
# TODO ADD OTHER CASE
|
||||
return True
|
||||
|
||||
def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor:
|
||||
|
@ -216,16 +215,6 @@ class EiXGNN(ExplainerAlgorithm):
|
|||
extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
|
||||
return extended_map
|
||||
|
||||
# def _shapley_value_extended(self, concepts, shap_val, size=None):
|
||||
# extended_shap_val_matrix = np.zeros(size)
|
||||
# nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts]
|
||||
# for i in range(extended_shap_val_matrix.shape[0]):
|
||||
# concept_shap_val = shap_val[i]
|
||||
# node_list = nodes_lists[i]
|
||||
# for j, val in zip(node_list, concept_shap_val):
|
||||
# extended_shap_val_matrix[i, j] = val
|
||||
# return extended_shap_val_matrix
|
||||
|
||||
def _global_concept_similarity_matrix(self, concepts):
|
||||
A = np.zeros((len(concepts), len(concepts)))
|
||||
concepts_pred = self.model(Batch().from_data_list(concepts))
|
||||
|
@ -249,80 +238,6 @@ class EiXGNN(ExplainerAlgorithm):
|
|||
pr = list(pagerank(G).values())
|
||||
return pr
|
||||
|
||||
# def _shapley_value_concept(
|
||||
# self,
|
||||
# concept: Data,
|
||||
# ) -> np.ndarray:
|
||||
|
||||
# g_plus_list = []
|
||||
# g_minus_list = []
|
||||
# val_shap = np.zeros(concept.x.shape[0])
|
||||
|
||||
# for node_ind in range(concept.x.shape[0]):
|
||||
# for _ in range(self.shap_val_approx):
|
||||
# perm = torch.randperm(concept.x.shape[0])
|
||||
|
||||
# x_perm = torch.clone(concept.x)
|
||||
# x_perm = x_perm[perm]
|
||||
|
||||
# x_minus = torch.clone(concept.x)
|
||||
# x_plus = torch.clone(concept.x)
|
||||
|
||||
# x_plus = torch.cat(
|
||||
# (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0
|
||||
# )
|
||||
# if node_ind == 0:
|
||||
# x_minus = torch.clone(x_perm)
|
||||
# else:
|
||||
# x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0)
|
||||
|
||||
# g_plus = concept.__copy__()
|
||||
# g_plus.x = x_plus
|
||||
# g_minus = concept.__copy__()
|
||||
# g_minus.x = x_minus
|
||||
|
||||
# g_plus_list.append(g_plus)
|
||||
# g_minus_list.append(g_minus)
|
||||
|
||||
# g_plus = Batch().from_data_list(g_plus_list)
|
||||
# g_minus = Batch().from_data_list(g_minus_list)
|
||||
# with torch.no_grad():
|
||||
# g_plus = g_plus.to(concept.x.device)
|
||||
# g_minus = g_minus.to(concept.x.device)
|
||||
|
||||
# out_g_plus = self.model(g_plus)
|
||||
# out_g_minus = self.model(g_minus)
|
||||
|
||||
# g_plus_score = F.softmax(out_g_plus, dim=1)
|
||||
# g_minus_score = F.softmax(out_g_minus, dim=1)
|
||||
|
||||
# g_plus_score = g_plus_score.reshape(
|
||||
# (1, concept.x.shape[0], self.shap_val_approx, -1)
|
||||
# )
|
||||
# g_minus_score = g_minus_score.reshape(
|
||||
# (1, concept.x.shape[0], self.shap_val_approx, -1)
|
||||
# )
|
||||
|
||||
# for node_ind in range(concept.x.shape[0]):
|
||||
# score = torch.mean(
|
||||
# torch.FloatTensor(
|
||||
# [
|
||||
# torch.norm(
|
||||
# input=g_plus_score[0, node_ind, i]
|
||||
# - g_minus_score[0, node_ind, i],
|
||||
# p=1,
|
||||
# )
|
||||
# for i in range(self.shap_val_approx)
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
# val_shap[node_ind] = score.item()
|
||||
# return val_shap
|
||||
|
||||
# def _subgraph_node_mapping(self, concept):
|
||||
# nodes_list = torch.unique(concept.edge_index)
|
||||
# return nodes_list
|
||||
|
||||
def _compute_node_ablation_prior(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
@ -353,32 +268,4 @@ class EiXGNN(ExplainerAlgorithm):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch_geometric.datasets import TUDataset
|
||||
|
||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
||||
data = dataset[0]
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
|
||||
class GCN(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = GCNConv(dataset.num_node_features, 20)
|
||||
self.conv2 = GCNConv(20, dataset.num_classes)
|
||||
|
||||
def forward(self, data):
|
||||
x, edge_index, batch = data.x, data.edge_index, data.batch
|
||||
x = self.conv1(x, edge_index)
|
||||
x = F.relu(x)
|
||||
x = global_mean_pool(x, batch)
|
||||
x = F.softmax(x, dim=1)
|
||||
return x
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = GCN().to(device)
|
||||
model.eval()
|
||||
data = dataset[0].to(device)
|
||||
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
||||
# explainer = EiXGNN()
|
||||
# explained = explainer.forward(model, data.x, data.edge_index)
|
||||
|
|
Loading…
Reference in New Issue