New features

This commit is contained in:
araison 2022-12-27 11:28:55 +01:00
parent d52bc702a2
commit e1b2a64bd6
1 changed files with 100 additions and 0 deletions

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import json
import logging
import os
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"
def _load_ckpt(
model: torch.nn.Module,
ckpt_path: str,
) -> torch.nn.Module:
r"""Loads the model at given checkpoint."""
if not osp.exists(path):
return None
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt[MODEL_STATE])
return model
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
acc = []
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
stats = json_to_dict_list(path)
for stat in stats:
acc.append(
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
)
return acc
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
paths = []
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
file = read_yaml(path)
dataset_name_ = file["dataset"]["name"]
if dataset_name == dataset_name_:
paths.append(os.path.dirname(path))
return paths
def best_xp_ckpt(paths, which: str = "best"):
acc = []
for path in paths:
accuracies = xp_stats(path)
acc.extend(accuracies)
acc = sorted(acc, key=lambda item: item["accuracy"])
if which == "best":
return acc[-1]
elif which == "worst":
return acc[0]
def stats_to_ckpt(parse):
paths = os.path.join(
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
)
ckpts = []
for path in glob.glob(paths):
feat = os.path.basename(path)
feat = feat.split("-")
for fe in feat:
if "epoch" in fe:
fe_ep = int(fe.split("=")[1])
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
return sorted(ckpts, key=lambda item: item["div"])[0]
def stats_to_cfg(parse):
path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
)
path = os.path.join(path, "config.yaml")
if os.path.exists(path):
return path
else:
raise FileNotFoundError(f"{path} does not exists")
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
best = best_xp_ckpt(
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
)
print(best)
print(stats_to_ckpt(best))
print(stats_to_cfg(best))