diff --git a/explaining_framework/utils/explaining/load_model.py b/explaining_framework/utils/explaining/load_model.py index 8f506c9..24c957a 100644 --- a/explaining_framework/utils/explaining/load_model.py +++ b/explaining_framework/utils/explaining/load_model.py @@ -1,14 +1,16 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import glob +import json import logging import os -import json -import glob import torch +from explaining_framework.utils.io import read_yaml from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.train import GraphGymDataModule +from torch_geometric.graphgym.utils.io import json_to_dict_list MODEL_STATE = "model_state" OPTIMIZER_STATE = "optimizer_state" @@ -30,12 +32,31 @@ def load_ckpt( return model -def load_best_given_exp(path_to_xp:str, wrt_metric:str:'val')->str: - path = os.path.normpath(path) - path.split(os.sep) - for path in glob.glob(os.path.join(path_to_xp,'[0-9]'*10,wrt_metric,'stats.json')): - print(path) +PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/" + +FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean" +def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> str: + acc = [] + for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")): + stats = json_to_dict_list(path) + for stat in stats: + acc.append( + {"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]} + ) + return acc + # return sorted(acc, key=lambda item: item["accuracy"]) +def best_ckpt_path(dataset_name: str, models_dir_path) -> str: + paths = [] + for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")): + file = read_yaml(path) + dataset_name_ = file["dataset"]["name"] + if dataset_name == dataset_name_: + paths.append(os.path.dirname(path)) + return paths + + +print(load_best_given_exp(PATH))