From 9d4a132b251cfc28ed9e7420f63bff2fc3915ad8 Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 8 Mar 2023 20:01:07 +0100 Subject: [PATCH] Update --- scgnn/scgnn.py | 47 ----------------------------------------------- 1 file changed, 47 deletions(-) diff --git a/scgnn/scgnn.py b/scgnn/scgnn.py index a0f07ef..b7f2cd5 100644 --- a/scgnn/scgnn.py +++ b/scgnn/scgnn.py @@ -159,55 +159,8 @@ class SCGNN(ExplainerAlgorithm): score_map = score_map + val * feat score_map = F.relu(score_map) - if self.score_map_norm and score_map.min() != score_map.max(): score_map = (score_map - score_map.min()).div( score_map.max() - score_map.min() ) return score_map - - -if __name__ == "__main__": - from torch_geometric.datasets import TUDataset - - dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") - data = dataset[0] - - import torch.nn.functional as F - from torch_geometric.nn import GCNConv, global_mean_pool - - # model = torch.nn.ModuleDict( - # { - # "conv1": GCNConv(dataset.num_node_features, 64), - # "conv2": GCNConv(64, dataset.num_classes), - # "gmp": global_mean_pool, - # } - # ) - - model = Sequential( - "data", - [ - ( - lambda data: (data.x, data.edge_index, data.batch), - "data -> x, edge_index, batch", - ), - (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"), - (GCNConv(64, dataset.num_classes), "x, edge_index -> x"), - (global_mean_pool, "x, batch -> x"), - ], - ) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = model.to(device) - data = dataset[0].to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) - model.eval() - out = model(data) - explainer = SCGNN() - explained = explainer.forward( - model, - data.x, - data.edge_index, - target=2, - interest_map_norm=True, - score_map_norm=True, - )