Improving config generation

This commit is contained in:
araison 2023-01-09 15:04:36 +01:00
parent 1e8cc292e1
commit b1093eb528

View File

@ -2,9 +2,10 @@ import os
from explaining_framework.utils.io import write_yaml
from torch_geometric.data.makedirs import makedirs
from torch_geometric.graphgym.loader import load_pyg_dataset
def divide_chunks(l, n):
def chunkizing_list(l, n):
for i in range(0, len(l), n):
yield l[i : i + n]
@ -19,7 +20,60 @@ if "__main__" == __name__:
explainer_folder = os.path.join(config_folder, "explaining")
makedirs(explainer_folder)
DATASET = ["CIFAR10"]
DATASET = [
"CIFAR10",
"TRIANGLES",
"COLORS-3",
"REDDIT-BINARY",
"REDDIT-MULTI-5K",
"REDDIT-MULTI-12K",
"COLLAB",
"DBLP_v1",
"COIL-DEL",
"COIL-RAG",
"Fingerprint",
"Letter-high",
"Letter-low",
"Letter-med",
"MSRC_9",
"MSRC_21",
"MSRC_21C",
"DD",
"ENZYMES",
"PROTEINS",
"QM9",
"MUTAG",
"Mutagenicity",
"AIDS",
"PATTERN",
"CLUSTER",
"MNIST",
"CIFAR10",
"TSP",
"CSL",
"KarateClub",
"CS",
"Physics",
"BBBP",
"Tox21",
"HIV",
"PCBA",
"MUV",
"BACE",
"SIDER",
"ClinTox",
"AIFB",
"AM",
"MUTAG",
"BGS",
"FAUST",
"DynamicFAUST",
"ShapeNet",
"ModelNet10",
"ModelNet40",
"PascalVOC-SP",
"COCO-SP",
]
EXPLAINER = [
"CAM",
"GradCAM",
@ -27,7 +81,7 @@ if "__main__" == __name__:
"GradExplainer",
"GuidedBackPropagation",
"IntegratedGradients",
"PGExplainer",
# "PGExplainer",
"PGMExplainer",
"RandomExplainer",
"SubgraphX",
@ -38,11 +92,11 @@ if "__main__" == __name__:
]
for dataset_name in DATASET:
for chunk in divide_chunks(list(range(10000)), 500):
dataset = load_pyg_dataset(name=dataset_name, dataset_dir="/tmp/")
for chunk in chunkizing_list(list(range(len(dataset))), 300):
for model_kind in ["best", "worst"]:
for explainer_name in EXPLAINER:
explaining_cfg = {}
# explaining_cfg['adjust']['strategy']= 'rpns'
# explaining_cfg['attack']['name']= 'all'
explaining_cfg[