New features
This commit is contained in:
parent
d52bc702a2
commit
e1b2a64bd6
1 changed files with 100 additions and 0 deletions
100
explaining_framework/utils/explaining/load_ckpt.py
Normal file
100
explaining_framework/utils/explaining/load_ckpt.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
|
MODEL_STATE = "model_state"
|
||||||
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ckpt(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
ckpt_path: str,
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
r"""Loads the model at given checkpoint."""
|
||||||
|
|
||||||
|
if not osp.exists(path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
ckpt = torch.load(ckpt_path)
|
||||||
|
model.load_state_dict(ckpt[MODEL_STATE])
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||||
|
acc = []
|
||||||
|
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||||
|
stats = json_to_dict_list(path)
|
||||||
|
for stat in stats:
|
||||||
|
acc.append(
|
||||||
|
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||||
|
)
|
||||||
|
return acc
|
||||||
|
|
||||||
|
|
||||||
|
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
|
||||||
|
paths = []
|
||||||
|
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||||
|
file = read_yaml(path)
|
||||||
|
dataset_name_ = file["dataset"]["name"]
|
||||||
|
if dataset_name == dataset_name_:
|
||||||
|
paths.append(os.path.dirname(path))
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def best_xp_ckpt(paths, which: str = "best"):
|
||||||
|
acc = []
|
||||||
|
for path in paths:
|
||||||
|
accuracies = xp_stats(path)
|
||||||
|
acc.extend(accuracies)
|
||||||
|
acc = sorted(acc, key=lambda item: item["accuracy"])
|
||||||
|
if which == "best":
|
||||||
|
return acc[-1]
|
||||||
|
elif which == "worst":
|
||||||
|
return acc[0]
|
||||||
|
|
||||||
|
|
||||||
|
def stats_to_ckpt(parse):
|
||||||
|
paths = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
|
||||||
|
)
|
||||||
|
ckpts = []
|
||||||
|
for path in glob.glob(paths):
|
||||||
|
feat = os.path.basename(path)
|
||||||
|
feat = feat.split("-")
|
||||||
|
for fe in feat:
|
||||||
|
if "epoch" in fe:
|
||||||
|
fe_ep = int(fe.split("=")[1])
|
||||||
|
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
|
||||||
|
return sorted(ckpts, key=lambda item: item["div"])[0]
|
||||||
|
|
||||||
|
|
||||||
|
def stats_to_cfg(parse):
|
||||||
|
path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
|
||||||
|
)
|
||||||
|
path = os.path.join(path, "config.yaml")
|
||||||
|
if os.path.exists(path):
|
||||||
|
return path
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"{path} does not exists")
|
||||||
|
|
||||||
|
|
||||||
|
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
|
||||||
|
best = best_xp_ckpt(
|
||||||
|
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
|
||||||
|
)
|
||||||
|
print(best)
|
||||||
|
print(stats_to_ckpt(best))
|
||||||
|
print(stats_to_cfg(best))
|
Loading…
Add table
Reference in a new issue