Update
This commit is contained in:
parent
352edb7df6
commit
9d4a132b25
|
@ -159,55 +159,8 @@ class SCGNN(ExplainerAlgorithm):
|
||||||
score_map = score_map + val * feat
|
score_map = score_map + val * feat
|
||||||
|
|
||||||
score_map = F.relu(score_map)
|
score_map = F.relu(score_map)
|
||||||
|
|
||||||
if self.score_map_norm and score_map.min() != score_map.max():
|
if self.score_map_norm and score_map.min() != score_map.max():
|
||||||
score_map = (score_map - score_map.min()).div(
|
score_map = (score_map - score_map.min()).div(
|
||||||
score_map.max() - score_map.min()
|
score_map.max() - score_map.min()
|
||||||
)
|
)
|
||||||
return score_map
|
return score_map
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from torch_geometric.datasets import TUDataset
|
|
||||||
|
|
||||||
dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES")
|
|
||||||
data = dataset[0]
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
|
||||||
|
|
||||||
# model = torch.nn.ModuleDict(
|
|
||||||
# {
|
|
||||||
# "conv1": GCNConv(dataset.num_node_features, 64),
|
|
||||||
# "conv2": GCNConv(64, dataset.num_classes),
|
|
||||||
# "gmp": global_mean_pool,
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
|
|
||||||
model = Sequential(
|
|
||||||
"data",
|
|
||||||
[
|
|
||||||
(
|
|
||||||
lambda data: (data.x, data.edge_index, data.batch),
|
|
||||||
"data -> x, edge_index, batch",
|
|
||||||
),
|
|
||||||
(GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"),
|
|
||||||
(GCNConv(64, dataset.num_classes), "x, edge_index -> x"),
|
|
||||||
(global_mean_pool, "x, batch -> x"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
model = model.to(device)
|
|
||||||
data = dataset[0].to(device)
|
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
|
||||||
model.eval()
|
|
||||||
out = model(data)
|
|
||||||
explainer = SCGNN()
|
|
||||||
explained = explainer.forward(
|
|
||||||
model,
|
|
||||||
data.x,
|
|
||||||
data.edge_index,
|
|
||||||
target=2,
|
|
||||||
interest_map_norm=True,
|
|
||||||
score_map_norm=True,
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue