New features
This commit is contained in:
parent
495cde4c70
commit
5224207466
1 changed files with 28 additions and 7 deletions
|
@ -1,14 +1,16 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import glob
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
MODEL_STATE = "model_state"
|
MODEL_STATE = "model_state"
|
||||||
OPTIMIZER_STATE = "optimizer_state"
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
|
@ -30,12 +32,31 @@ def load_ckpt(
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_best_given_exp(path_to_xp:str, wrt_metric:str:'val')->str:
|
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/"
|
||||||
path = os.path.normpath(path)
|
|
||||||
path.split(os.sep)
|
FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean"
|
||||||
for path in glob.glob(os.path.join(path_to_xp,'[0-9]'*10,wrt_metric,'stats.json')):
|
|
||||||
print(path)
|
|
||||||
|
|
||||||
|
|
||||||
|
def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||||
|
acc = []
|
||||||
|
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||||
|
stats = json_to_dict_list(path)
|
||||||
|
for stat in stats:
|
||||||
|
acc.append(
|
||||||
|
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||||
|
)
|
||||||
|
return acc
|
||||||
|
# return sorted(acc, key=lambda item: item["accuracy"])
|
||||||
|
|
||||||
|
|
||||||
|
def best_ckpt_path(dataset_name: str, models_dir_path) -> str:
|
||||||
|
paths = []
|
||||||
|
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||||
|
file = read_yaml(path)
|
||||||
|
dataset_name_ = file["dataset"]["name"]
|
||||||
|
if dataset_name == dataset_name_:
|
||||||
|
paths.append(os.path.dirname(path))
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
print(load_best_given_exp(PATH))
|
||||||
|
|
Loading…
Add table
Reference in a new issue