New features
This commit is contained in:
parent
495cde4c70
commit
5224207466
@ -1,14 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
@ -30,12 +32,31 @@ def load_ckpt(
|
||||
return model
|
||||
|
||||
|
||||
def load_best_given_exp(path_to_xp:str, wrt_metric:str:'val')->str:
|
||||
path = os.path.normpath(path)
|
||||
path.split(os.sep)
|
||||
for path in glob.glob(os.path.join(path_to_xp,'[0-9]'*10,wrt_metric,'stats.json')):
|
||||
print(path)
|
||||
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/"
|
||||
|
||||
FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean"
|
||||
|
||||
|
||||
def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||
acc = []
|
||||
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||
stats = json_to_dict_list(path)
|
||||
for stat in stats:
|
||||
acc.append(
|
||||
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||
)
|
||||
return acc
|
||||
# return sorted(acc, key=lambda item: item["accuracy"])
|
||||
|
||||
|
||||
def best_ckpt_path(dataset_name: str, models_dir_path) -> str:
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||
file = read_yaml(path)
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
return paths
|
||||
|
||||
|
||||
print(load_best_given_exp(PATH))
|
||||
|
Loading…
Reference in New Issue
Block a user