New features
This commit is contained in:
parent
d52bc702a2
commit
e1b2a64bd6
100
explaining_framework/utils/explaining/load_ckpt.py
Normal file
100
explaining_framework/utils/explaining/load_ckpt.py
Normal file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
||||
|
||||
def _load_ckpt(
|
||||
model: torch.nn.Module,
|
||||
ckpt_path: str,
|
||||
) -> torch.nn.Module:
|
||||
r"""Loads the model at given checkpoint."""
|
||||
|
||||
if not osp.exists(path):
|
||||
return None
|
||||
|
||||
ckpt = torch.load(ckpt_path)
|
||||
model.load_state_dict(ckpt[MODEL_STATE])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||
acc = []
|
||||
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||
stats = json_to_dict_list(path)
|
||||
for stat in stats:
|
||||
acc.append(
|
||||
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||
)
|
||||
return acc
|
||||
|
||||
|
||||
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||
file = read_yaml(path)
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
return paths
|
||||
|
||||
|
||||
def best_xp_ckpt(paths, which: str = "best"):
|
||||
acc = []
|
||||
for path in paths:
|
||||
accuracies = xp_stats(path)
|
||||
acc.extend(accuracies)
|
||||
acc = sorted(acc, key=lambda item: item["accuracy"])
|
||||
if which == "best":
|
||||
return acc[-1]
|
||||
elif which == "worst":
|
||||
return acc[0]
|
||||
|
||||
|
||||
def stats_to_ckpt(parse):
|
||||
paths = os.path.join(
|
||||
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
|
||||
)
|
||||
ckpts = []
|
||||
for path in glob.glob(paths):
|
||||
feat = os.path.basename(path)
|
||||
feat = feat.split("-")
|
||||
for fe in feat:
|
||||
if "epoch" in fe:
|
||||
fe_ep = int(fe.split("=")[1])
|
||||
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
|
||||
return sorted(ckpts, key=lambda item: item["div"])[0]
|
||||
|
||||
|
||||
def stats_to_cfg(parse):
|
||||
path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
|
||||
)
|
||||
path = os.path.join(path, "config.yaml")
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
else:
|
||||
raise FileNotFoundError(f"{path} does not exists")
|
||||
|
||||
|
||||
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
|
||||
best = best_xp_ckpt(
|
||||
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
|
||||
)
|
||||
print(best)
|
||||
print(stats_to_ckpt(best))
|
||||
print(stats_to_cfg(best))
|
Loading…
Reference in New Issue
Block a user