Update README.md
This commit is contained in:
parent
5c6cc20870
commit
3109a36e6e
2 changed files with 34 additions and 114 deletions
33
README.md
33
README.md
|
@ -0,0 +1,33 @@
|
||||||
|
Here is the an example code for using EiXGNN from the `"EiX-GNN: Concept-level eigencentrality explainer for graph neural
|
||||||
|
networks"<https://arxiv.org/abs/2206.03491>`_ paper
|
||||||
|
|
||||||
|
```python
|
||||||
|
from torch_geometric.datasets import TUDataset
|
||||||
|
|
||||||
|
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
||||||
|
data = dataset[0]
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||||
|
|
||||||
|
class GCN(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = GCNConv(dataset.num_node_features, 20)
|
||||||
|
self.conv2 = GCNConv(20, dataset.num_classes)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
x, edge_index, batch = data.x, data.edge_index, data.batch
|
||||||
|
x = self.conv1(x, edge_index)
|
||||||
|
x = F.relu(x)
|
||||||
|
x = global_mean_pool(x, batch)
|
||||||
|
x = F.softmax(x, dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = GCN().to(device)
|
||||||
|
model.eval()
|
||||||
|
data = dataset[0].to(device)
|
||||||
|
explainer = EiXGNN()
|
||||||
|
explained = explainer.forward(model, data.x, data.edge_index)
|
||||||
|
```
|
115
eixgnn/eixgnn.py
115
eixgnn/eixgnn.py
|
@ -11,7 +11,6 @@ from torch.nn import KLDivLoss, Softmax
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
from torch_geometric.explain import Explanation
|
from torch_geometric.explain import Explanation
|
||||||
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
from torch_geometric.explain.algorithm.base import ExplainerAlgorithm
|
||||||
# from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
|
|
||||||
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
from torch_geometric.explain.config import (ExplainerConfig, MaskType,
|
||||||
ModelConfig, ModelMode,
|
ModelConfig, ModelMode,
|
||||||
ModelTaskLevel)
|
ModelTaskLevel)
|
||||||
|
@ -129,7 +128,7 @@ class EiXGNN(ExplainerAlgorithm):
|
||||||
f"This signal similarity metric : {signal_similarity['metric']} is not supported yet. No explanation provided."
|
f"This signal similarity metric : {signal_similarity['metric']} is not supported yet. No explanation provided."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
# ADD OTHER CASE
|
# TODO ADD OTHER CASE
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor:
|
def get_mc_shapley(self, subset_node: list, data: Data) -> Tensor:
|
||||||
|
@ -216,16 +215,6 @@ class EiXGNN(ExplainerAlgorithm):
|
||||||
extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
|
extended_map[i, indexes[i][j]] += abs(shap_vals[i][j])
|
||||||
return extended_map
|
return extended_map
|
||||||
|
|
||||||
# def _shapley_value_extended(self, concepts, shap_val, size=None):
|
|
||||||
# extended_shap_val_matrix = np.zeros(size)
|
|
||||||
# nodes_lists = [self._subgraph_node_mapping(concept) for concept in concepts]
|
|
||||||
# for i in range(extended_shap_val_matrix.shape[0]):
|
|
||||||
# concept_shap_val = shap_val[i]
|
|
||||||
# node_list = nodes_lists[i]
|
|
||||||
# for j, val in zip(node_list, concept_shap_val):
|
|
||||||
# extended_shap_val_matrix[i, j] = val
|
|
||||||
# return extended_shap_val_matrix
|
|
||||||
|
|
||||||
def _global_concept_similarity_matrix(self, concepts):
|
def _global_concept_similarity_matrix(self, concepts):
|
||||||
A = np.zeros((len(concepts), len(concepts)))
|
A = np.zeros((len(concepts), len(concepts)))
|
||||||
concepts_pred = self.model(Batch().from_data_list(concepts))
|
concepts_pred = self.model(Batch().from_data_list(concepts))
|
||||||
|
@ -249,80 +238,6 @@ class EiXGNN(ExplainerAlgorithm):
|
||||||
pr = list(pagerank(G).values())
|
pr = list(pagerank(G).values())
|
||||||
return pr
|
return pr
|
||||||
|
|
||||||
# def _shapley_value_concept(
|
|
||||||
# self,
|
|
||||||
# concept: Data,
|
|
||||||
# ) -> np.ndarray:
|
|
||||||
|
|
||||||
# g_plus_list = []
|
|
||||||
# g_minus_list = []
|
|
||||||
# val_shap = np.zeros(concept.x.shape[0])
|
|
||||||
|
|
||||||
# for node_ind in range(concept.x.shape[0]):
|
|
||||||
# for _ in range(self.shap_val_approx):
|
|
||||||
# perm = torch.randperm(concept.x.shape[0])
|
|
||||||
|
|
||||||
# x_perm = torch.clone(concept.x)
|
|
||||||
# x_perm = x_perm[perm]
|
|
||||||
|
|
||||||
# x_minus = torch.clone(concept.x)
|
|
||||||
# x_plus = torch.clone(concept.x)
|
|
||||||
|
|
||||||
# x_plus = torch.cat(
|
|
||||||
# (x_plus[: node_ind + 1], x_perm[node_ind + 1 :]), axis=0
|
|
||||||
# )
|
|
||||||
# if node_ind == 0:
|
|
||||||
# x_minus = torch.clone(x_perm)
|
|
||||||
# else:
|
|
||||||
# x_minus = torch.cat((x_minus[:node_ind], x_perm[node_ind:]), axis=0)
|
|
||||||
|
|
||||||
# g_plus = concept.__copy__()
|
|
||||||
# g_plus.x = x_plus
|
|
||||||
# g_minus = concept.__copy__()
|
|
||||||
# g_minus.x = x_minus
|
|
||||||
|
|
||||||
# g_plus_list.append(g_plus)
|
|
||||||
# g_minus_list.append(g_minus)
|
|
||||||
|
|
||||||
# g_plus = Batch().from_data_list(g_plus_list)
|
|
||||||
# g_minus = Batch().from_data_list(g_minus_list)
|
|
||||||
# with torch.no_grad():
|
|
||||||
# g_plus = g_plus.to(concept.x.device)
|
|
||||||
# g_minus = g_minus.to(concept.x.device)
|
|
||||||
|
|
||||||
# out_g_plus = self.model(g_plus)
|
|
||||||
# out_g_minus = self.model(g_minus)
|
|
||||||
|
|
||||||
# g_plus_score = F.softmax(out_g_plus, dim=1)
|
|
||||||
# g_minus_score = F.softmax(out_g_minus, dim=1)
|
|
||||||
|
|
||||||
# g_plus_score = g_plus_score.reshape(
|
|
||||||
# (1, concept.x.shape[0], self.shap_val_approx, -1)
|
|
||||||
# )
|
|
||||||
# g_minus_score = g_minus_score.reshape(
|
|
||||||
# (1, concept.x.shape[0], self.shap_val_approx, -1)
|
|
||||||
# )
|
|
||||||
|
|
||||||
# for node_ind in range(concept.x.shape[0]):
|
|
||||||
# score = torch.mean(
|
|
||||||
# torch.FloatTensor(
|
|
||||||
# [
|
|
||||||
# torch.norm(
|
|
||||||
# input=g_plus_score[0, node_ind, i]
|
|
||||||
# - g_minus_score[0, node_ind, i],
|
|
||||||
# p=1,
|
|
||||||
# )
|
|
||||||
# for i in range(self.shap_val_approx)
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
# val_shap[node_ind] = score.item()
|
|
||||||
# return val_shap
|
|
||||||
|
|
||||||
# def _subgraph_node_mapping(self, concept):
|
|
||||||
# nodes_list = torch.unique(concept.edge_index)
|
|
||||||
# return nodes_list
|
|
||||||
|
|
||||||
def _compute_node_ablation_prior(
|
def _compute_node_ablation_prior(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
@ -353,32 +268,4 @@ class EiXGNN(ExplainerAlgorithm):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch_geometric.datasets import TUDataset
|
|
||||||
|
|
||||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
|
||||||
data = dataset[0]
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
|
||||||
|
|
||||||
class GCN(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.conv1 = GCNConv(dataset.num_node_features, 20)
|
|
||||||
self.conv2 = GCNConv(20, dataset.num_classes)
|
|
||||||
|
|
||||||
def forward(self, data):
|
|
||||||
x, edge_index, batch = data.x, data.edge_index, data.batch
|
|
||||||
x = self.conv1(x, edge_index)
|
|
||||||
x = F.relu(x)
|
|
||||||
x = global_mean_pool(x, batch)
|
|
||||||
x = F.softmax(x, dim=1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
model = GCN().to(device)
|
|
||||||
model.eval()
|
|
||||||
data = dataset[0].to(device)
|
|
||||||
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
|
||||||
# explainer = EiXGNN()
|
|
||||||
# explained = explainer.forward(model, data.x, data.edge_index)
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue