From b1093eb528c5f2d5dddebe984b412a075da2b4ba Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 9 Jan 2023 15:04:36 +0100 Subject: [PATCH] Improving config generation --- explaining_framework/utils/config_gen.py | 64 ++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/explaining_framework/utils/config_gen.py b/explaining_framework/utils/config_gen.py index 4dd0275..0fdd6a8 100644 --- a/explaining_framework/utils/config_gen.py +++ b/explaining_framework/utils/config_gen.py @@ -2,9 +2,10 @@ import os from explaining_framework.utils.io import write_yaml from torch_geometric.data.makedirs import makedirs +from torch_geometric.graphgym.loader import load_pyg_dataset -def divide_chunks(l, n): +def chunkizing_list(l, n): for i in range(0, len(l), n): yield l[i : i + n] @@ -19,7 +20,60 @@ if "__main__" == __name__: explainer_folder = os.path.join(config_folder, "explaining") makedirs(explainer_folder) - DATASET = ["CIFAR10"] + DATASET = [ + "CIFAR10", + "TRIANGLES", + "COLORS-3", + "REDDIT-BINARY", + "REDDIT-MULTI-5K", + "REDDIT-MULTI-12K", + "COLLAB", + "DBLP_v1", + "COIL-DEL", + "COIL-RAG", + "Fingerprint", + "Letter-high", + "Letter-low", + "Letter-med", + "MSRC_9", + "MSRC_21", + "MSRC_21C", + "DD", + "ENZYMES", + "PROTEINS", + "QM9", + "MUTAG", + "Mutagenicity", + "AIDS", + "PATTERN", + "CLUSTER", + "MNIST", + "CIFAR10", + "TSP", + "CSL", + "KarateClub", + "CS", + "Physics", + "BBBP", + "Tox21", + "HIV", + "PCBA", + "MUV", + "BACE", + "SIDER", + "ClinTox", + "AIFB", + "AM", + "MUTAG", + "BGS", + "FAUST", + "DynamicFAUST", + "ShapeNet", + "ModelNet10", + "ModelNet40", + "PascalVOC-SP", + "COCO-SP", + ] EXPLAINER = [ "CAM", "GradCAM", @@ -27,7 +81,7 @@ if "__main__" == __name__: "GradExplainer", "GuidedBackPropagation", "IntegratedGradients", - "PGExplainer", + # "PGExplainer", "PGMExplainer", "RandomExplainer", "SubgraphX", @@ -38,11 +92,11 @@ if "__main__" == __name__: ] for dataset_name in DATASET: - for chunk in divide_chunks(list(range(10000)), 500): + dataset = load_pyg_dataset(name=dataset_name, dataset_dir="/tmp/") + for chunk in chunkizing_list(list(range(len(dataset))), 300): for model_kind in ["best", "worst"]: for explainer_name in EXPLAINER: explaining_cfg = {} - # explaining_cfg['adjust']['strategy']= 'rpns' # explaining_cfg['attack']['name']= 'all' explaining_cfg[