New features

This commit is contained in:
araison 2022-12-23 16:19:34 +01:00
parent 495cde4c70
commit 5224207466

View File

@ -1,14 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import json
import logging
import os
import json
import glob
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
@ -30,12 +32,31 @@ def load_ckpt(
return model
def load_best_given_exp(path_to_xp:str, wrt_metric:str:'val')->str:
path = os.path.normpath(path)
path.split(os.sep)
for path in glob.glob(os.path.join(path_to_xp,'[0-9]'*10,wrt_metric,'stats.json')):
print(path)
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/"
FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean"
def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> str:
acc = []
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
stats = json_to_dict_list(path)
for stat in stats:
acc.append(
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
)
return acc
# return sorted(acc, key=lambda item: item["accuracy"])
def best_ckpt_path(dataset_name: str, models_dir_path) -> str:
paths = []
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
file = read_yaml(path)
dataset_name_ = file["dataset"]["name"]
if dataset_name == dataset_name_:
paths.append(os.path.dirname(path))
return paths
print(load_best_given_exp(PATH))