Update
This commit is contained in:
parent
352edb7df6
commit
9d4a132b25
|
@ -159,55 +159,8 @@ class SCGNN(ExplainerAlgorithm):
|
|||
score_map = score_map + val * feat
|
||||
|
||||
score_map = F.relu(score_map)
|
||||
|
||||
if self.score_map_norm and score_map.min() != score_map.max():
|
||||
score_map = (score_map - score_map.min()).div(
|
||||
score_map.max() - score_map.min()
|
||||
)
|
||||
return score_map
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch_geometric.datasets import TUDataset
|
||||
|
||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
||||
data = dataset[0]
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
|
||||
# model = torch.nn.ModuleDict(
|
||||
# {
|
||||
# "conv1": GCNConv(dataset.num_node_features, 64),
|
||||
# "conv2": GCNConv(64, dataset.num_classes),
|
||||
# "gmp": global_mean_pool,
|
||||
# }
|
||||
# )
|
||||
|
||||
model = Sequential(
|
||||
"data",
|
||||
[
|
||||
(
|
||||
lambda data: (data.x, data.edge_index, data.batch),
|
||||
"data -> x, edge_index, batch",
|
||||
),
|
||||
(GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
|
||||
(GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
|
||||
(global_mean_pool, "x, batch -> x"),
|
||||
],
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
data = dataset[0].to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
||||
model.eval()
|
||||
out = model(data)
|
||||
explainer = SCGNN()
|
||||
explained = explainer.forward(
|
||||
model,
|
||||
data.x,
|
||||
data.edge_index,
|
||||
target=2,
|
||||
interest_map_norm=True,
|
||||
score_map_norm=True,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue