Upload to github
This commit is contained in:
parent
c4f9086fb0
commit
352edb7df6
30
README.md
30
README.md
|
@ -3,15 +3,15 @@ Here is the an example code for using ScoreCAM GNN from the [ScoreCAM GNN : a ge
|
||||||
```python
|
```python
|
||||||
from torch_geometric.datasets import TUDataset
|
from torch_geometric.datasets import TUDataset
|
||||||
|
|
||||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
||||||
data = dataset[0]
|
data = dataset[0]
|
||||||
from scgnn.scgnn import SCGNN
|
from scgnn.scgnn import SCGNN
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||||
|
|
||||||
|
|
||||||
model = Sequential(
|
model = Sequential(
|
||||||
"data",
|
"data",
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
|
@ -22,15 +22,15 @@ from torch_geometric.datasets import TUDataset
|
||||||
(GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
|
(GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
|
||||||
(global_mean_pool, "x, batch -> x"),
|
(global_mean_pool, "x, batch -> x"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
data = dataset[0].to(device)
|
data = dataset[0].to(device)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
||||||
model.eval()
|
model.eval()
|
||||||
out = model(data)
|
out = model(data)
|
||||||
explainer = SCGNN()
|
explainer = SCGNN()
|
||||||
explained = explainer.forward(
|
explained = explainer.forward(
|
||||||
model,
|
model,
|
||||||
data.x,
|
data.x,
|
||||||
data.edge_index,
|
data.edge_index,
|
||||||
|
|
Loading…
Reference in New Issue