📚 Official implementation of ScoreCAM GNN - torch-geometric implementation

ScoreCAM GNN

Official implementation of the ScoreCAM GNN algorithm. For further information, see the paper "ScoreCAM GNN: a generalization of an optimal local post-hoc explaining method to any geometric deep learning models".

Run an example

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, Sequential, global_mean_pool

from scgnn.scgnn import SCGNN

# Load the ENZYMES dataset and take its first graph.
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
data = dataset[0]

# Two GCN layers followed by mean pooling, giving one score vector per graph.
model = Sequential(
    "data",
    [
        # Unpack the Data object into the tensors the layers consume.
        (
            lambda data: (data.x, data.edge_index, data.batch),
            "data -> x, edge_index, batch",
        ),
        (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
        (GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
        (global_mean_pool, "x, batch -> x"),
    ],
)
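device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
data = dataset[0].to(device)
# Optimizer for a training loop, which this example omits.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Switch to evaluation mode and run a forward pass on the single graph.
model.eval()
out = model(data)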
explainer = SCGNN()
explained = explainer.forward(
    model,
    data.x,
    data.edge_index,
    target=2,
    interest_map_norm=True,
    score_map_norm=True,
)
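
The example above creates an optimizer but never trains the model, so the explanation is computed for randomly initialized weights. Below is a minimal sketch of one training epoch that could run before `model.eval()`. The helper name `train_one_epoch` is hypothetical and not part of `scgnn`. It assumes that each batch carries integer class labels in `batch.y` and that `model(batch)` returns one row of raw class scores per graph, as the model above does after `global_mean_pool`.

```python
import torch
import torch.nn.functional as F


def train_one_epoch(model, loader, optimizer, device):
    """Run one pass over `loader` and return the mean cross-entropy loss per graph.

    Hypothetical helper (not part of scgnn). Assumes `batch.y` holds integer
    class labels and `model(batch)` returns logits of shape
    [graphs_in_batch, num_classes].
    """
    model.train()
    total_loss, num_graphs = 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch)
        loss = F.cross_entropy(logits, batch.y)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its number of graphs to get a per-graph mean.
        total_loss += loss.item() * batch.y.size(0)
        num_graphs += batch.y.size(0)
    return total_loss / num_graphs
```

With the objects from the example, this could be called for a few epochs as `train_one_epoch(model, loader, optimizer, device)`, where `loader` is a `torch_geometric.loader.DataLoader` built over `dataset` (for instance with `batch_size=32, shuffle=True`; these values are illustrative, not taken from the project).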