Fixing bugs and adding new features
This commit is contained in:
parent
e1b2a64bd6
commit
7c4f5ca196
|
@ -66,17 +66,18 @@ set_eixgnn_cfg(eixgnn_cfg)
|
||||||
# eixgnn_cfg.dump(stream=f)
|
# eixgnn_cfg.dump(stream=f)
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# def load_eixgnn_cfg(eixgnn_cfg, args):
|
def load_eixgnn_cfg(eixgnn_cfg, args):
|
||||||
# r"""
|
r"""
|
||||||
# Load configurations from file system and command line
|
Load configurations from file system and command line
|
||||||
# Args:
|
Args:
|
||||||
# eixgnn_cfg (CfgNode): Configuration node
|
eixgnn_cfg (CfgNode): Configuration node
|
||||||
# args (ArgumentParser): Command argument parser
|
args (ArgumentParser): Command argument parser
|
||||||
# """
|
"""
|
||||||
# eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
||||||
# eixgnn_cfg.merge_from_list(args.opts)
|
eixgnn_cfg.merge_from_list(args.opts)
|
||||||
# assert_eixgnn_cfg(eixgnn_cfg)
|
assert_eixgnn_cfg(eixgnn_cfg)
|
||||||
#
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# def makedirs_rm_exist(dir):
|
# def makedirs_rm_exist(dir):
|
||||||
# if os.path.isdir(dir):
|
# if os.path.isdir(dir):
|
||||||
|
|
|
@ -36,27 +36,25 @@ def set_scgnn_cfg(scgnn_cfg):
|
||||||
if scgnn_cfg is None:
|
if scgnn_cfg is None:
|
||||||
return scgnn_cfg
|
return scgnn_cfg
|
||||||
|
|
||||||
scgnn_cfg.depth = 'all'
|
scgnn_cfg.depth = "all"
|
||||||
scgnn_cfg.interest_map_norm = True
|
scgnn_cfg.interest_map_norm = True
|
||||||
scgnn_cfg.score_map_norm = True
|
scgnn_cfg.score_map_norm = True
|
||||||
|
|
||||||
|
|
||||||
|
def assert_cfg(scgnn_cfg):
|
||||||
|
|
||||||
def assert_scgnn_cfg(scgnn_cfg):
|
|
||||||
r"""Checks config values, do necessary post processing to the configs
|
r"""Checks config values, do necessary post processing to the configs
|
||||||
TODO
|
TODO
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
# if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
"Task {} not supported, must be one of node, "
|
# "Task {} not supported, must be one of node, "
|
||||||
"edge, graph, link_pred".format(scgnn_cfg.dataset.task)
|
# "edge, graph, link_pred".format(scgnn_cfg.dataset.task)
|
||||||
)
|
# )
|
||||||
scgnn_cfg.run_dir = scgnn_cfg.out_dir
|
# scgnn_cfg.run_dir = scgnn_cfg.out_dir
|
||||||
|
|
||||||
|
|
||||||
def dump_scgnn_cfg(scgnn_cfg,path):
|
def dump_cfg(scgnn_cfg, path):
|
||||||
r"""
|
r"""
|
||||||
TODO
|
TODO
|
||||||
Dumps the config to the output directory specified in
|
Dumps the config to the output directory specified in
|
||||||
|
@ -65,14 +63,12 @@ def dump_scgnn_cfg(scgnn_cfg,path):
|
||||||
scgnn_cfg (CfgNode): Configuration node
|
scgnn_cfg (CfgNode): Configuration node
|
||||||
"""
|
"""
|
||||||
makedirs(scgnn_cfg.out_dir)
|
makedirs(scgnn_cfg.out_dir)
|
||||||
scgnn_cfg_file = os.path.join(
|
scgnn_cfg_file = os.path.join(scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest)
|
||||||
scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest
|
|
||||||
)
|
|
||||||
with open(scgnn_cfg_file, "w") as f:
|
with open(scgnn_cfg_file, "w") as f:
|
||||||
scgnn_cfg.dump(stream=f)
|
scgnn_cfg.dump(stream=f)
|
||||||
|
|
||||||
|
|
||||||
def load_scgnn_cfg(scgnn_cfg, args):
|
def load_cfg(scgnn_cfg, args):
|
||||||
r"""
|
r"""
|
||||||
Load configurations from file system and command line
|
Load configurations from file system and command line
|
||||||
Args:
|
Args:
|
||||||
|
@ -116,8 +112,6 @@ def set_out_dir(out_dir, fname):
|
||||||
# Make output directory
|
# Make output directory
|
||||||
if scgnn_cfg.train.auto_resume:
|
if scgnn_cfg.train.auto_resume:
|
||||||
os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
|
os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
|
||||||
else:
|
|
||||||
makedirs_rm_exist(scgnn_cfg.out_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def set_run_dir(out_dir):
|
def set_run_dir(out_dir):
|
||||||
|
@ -131,8 +125,6 @@ def set_run_dir(out_dir):
|
||||||
# Make output directory
|
# Make output directory
|
||||||
if scgnn_cfg.train.auto_resume:
|
if scgnn_cfg.train.auto_resume:
|
||||||
os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
|
os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
|
||||||
else:
|
|
||||||
makedirs_rm_exist(scgnn_cfg.run_dir)
|
|
||||||
|
|
||||||
|
|
||||||
set_scgnn_cfg(scgnn_cfg)
|
set_scgnn_cfg(scgnn_cfg)
|
||||||
|
|
|
@ -122,7 +122,8 @@ def set_cfg(explaining_cfg):
|
||||||
explaining_cfg.accelerator = "auto"
|
explaining_cfg.accelerator = "auto"
|
||||||
|
|
||||||
# which objectives metrics to computes, either all or one in particular if implemented
|
# which objectives metrics to computes, either all or one in particular if implemented
|
||||||
explaining_cfg.metrics = "all"
|
explaining_cfg.metrics = CN()
|
||||||
|
explaining_cfg.metrics.type = "all"
|
||||||
|
|
||||||
# Whether or not recomputing metrics if they already exist
|
# Whether or not recomputing metrics if they already exist
|
||||||
explaining_cfg.metrics.force = False
|
explaining_cfg.metrics.force = False
|
||||||
|
|
|
@ -86,7 +86,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
if criterion == "mse":
|
if criterion == "mse":
|
||||||
loss = MSELoss()
|
loss = MSELoss()
|
||||||
return loss
|
return loss
|
||||||
elif criterion == "cross-entropy":
|
elif criterion == "cross_entropy":
|
||||||
loss = CrossEntropyLoss()
|
loss = CrossEntropyLoss()
|
||||||
return loss
|
return loss
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -5,13 +5,6 @@ def parse_args() -> argparse.Namespace:
|
||||||
r"""Parses the command line arguments."""
|
r"""Parses the command line arguments."""
|
||||||
parser = argparse.ArgumentParser(description="ExplainingFramework")
|
parser = argparse.ArgumentParser(description="ExplainingFramework")
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--cfg",
|
|
||||||
dest="cfg_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The configuration file path.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--explaining_cfg",
|
"--explaining_cfg",
|
||||||
dest="explaining_cfg_file",
|
dest="explaining_cfg_file",
|
||||||
|
|
|
@ -1,14 +0,0 @@
|
||||||
class ExplainingOutline(object):
|
|
||||||
def __init__(self, explaining_cfg: CN, explainer_cfg: CN = None):
|
|
||||||
self.explaining_cfg = explaining_cfg
|
|
||||||
self.explainer_cfg = explainer_cfg
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(self):
|
|
||||||
if self
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,174 @@
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from eixgnn.eixgnn import EiXGNN
|
||||||
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
|
eixgnn_cfg
|
||||||
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
|
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||||
|
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||||
|
GraphXAIWrapper
|
||||||
|
from explaining_framework.metric.accuracy import Accuracy
|
||||||
|
from explaining_framework.metric.fidelity import Fidelity
|
||||||
|
from explaining_framework.metric.robust import Attack
|
||||||
|
from explaining_framework.metric.sparsity import Sparsity
|
||||||
|
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||||
|
_load_ckpt)
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
|
||||||
|
all__captum = [
|
||||||
|
"LRP",
|
||||||
|
"DeepLift",
|
||||||
|
"DeepLiftShap",
|
||||||
|
"FeatureAblation",
|
||||||
|
"FeaturePermutation",
|
||||||
|
"GradientShap",
|
||||||
|
"GuidedBackprop",
|
||||||
|
"GuidedGradCam",
|
||||||
|
"InputXGradient",
|
||||||
|
"IntegratedGradients",
|
||||||
|
"Lime",
|
||||||
|
"Occlusion",
|
||||||
|
"Saliency",
|
||||||
|
]
|
||||||
|
|
||||||
|
all__graphxai = [
|
||||||
|
"CAM",
|
||||||
|
"GradCAM",
|
||||||
|
"GNN_LRP",
|
||||||
|
"GradExplainer",
|
||||||
|
"GuidedBackPropagation",
|
||||||
|
"IntegratedGradients",
|
||||||
|
"PGExplainer",
|
||||||
|
"PGMExplainer",
|
||||||
|
"RandomExplainer",
|
||||||
|
"SubgraphX",
|
||||||
|
"GraphMASK",
|
||||||
|
]
|
||||||
|
|
||||||
|
all__own = ["EIXGNN", "SCGNN"]
|
||||||
|
|
||||||
|
|
||||||
|
class ExplainingOutline(object):
|
||||||
|
def __init__(self, explaining_cfg_path: str):
|
||||||
|
self.explaining_cfg_path = explaining_cfg_path
|
||||||
|
self.explaining_cfg = None
|
||||||
|
self.explainer_cfg_path = None
|
||||||
|
self.explainer_cfg = None
|
||||||
|
self.explaining_algorithm = None
|
||||||
|
self.cfg = None
|
||||||
|
self.model = None
|
||||||
|
self.dataset = None
|
||||||
|
self.model_info = None
|
||||||
|
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
self.load_model_info()
|
||||||
|
self.load_cfg()
|
||||||
|
self.load_dataset()
|
||||||
|
self.load_model()
|
||||||
|
self.load_explainer_cfg()
|
||||||
|
self.load_explainer()
|
||||||
|
|
||||||
|
def load_model_info(self):
|
||||||
|
info = LoadModelInfo(
|
||||||
|
dataset_name=self.explaining_cfg.dataset.name,
|
||||||
|
model_dir=self.explaining_cfg.model.path,
|
||||||
|
which=self.explaining_cfg.model.ckpt,
|
||||||
|
)
|
||||||
|
self.model_info = info.set_info()
|
||||||
|
|
||||||
|
def load_cfg(self):
|
||||||
|
cfg.set_new_allowed(True)
|
||||||
|
cfg.merge_from_file(self.model_info["cfg_path"])
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def load_explaining_cfg(self):
|
||||||
|
explaining_cfg.set_new_allowed(True)
|
||||||
|
explaining_cfg.merge_from_file(self.explaining_cfg_path)
|
||||||
|
self.explaining_cfg = explaining_cfg
|
||||||
|
|
||||||
|
def load_explainer_cfg(self):
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.explaining_cfg()
|
||||||
|
else:
|
||||||
|
if self.explaining_cfg.explainer.cfg == "default":
|
||||||
|
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||||
|
self.explainer_cfg = copy.copy(eixgnn_cfg)
|
||||||
|
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||||
|
self.explainer_cfg = copy.copy(scgnn_cfg)
|
||||||
|
else:
|
||||||
|
self.explainer_cfg = None
|
||||||
|
else:
|
||||||
|
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||||
|
eixgnn_cfg.set_new_allowed(True)
|
||||||
|
eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
|
self.explainer_cfg = eixgnn_cfg
|
||||||
|
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||||
|
scgnn_cfg.set_new_allowed(True)
|
||||||
|
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
|
self.explainer_cfg = scgnn_cfg
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
auto_select_device()
|
||||||
|
self.model = create_model()
|
||||||
|
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
||||||
|
if self.model is None:
|
||||||
|
raise ValueError("Model ckpt has not been loaded, ckpt file not found")
|
||||||
|
|
||||||
|
def load_dataset(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.explaining_cfg()
|
||||||
|
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||||
|
)
|
||||||
|
self.dataset = create_dataset()
|
||||||
|
|
||||||
|
def load_explainer(self):
|
||||||
|
self.load_explainer_cfg()
|
||||||
|
if self.model is None:
|
||||||
|
self.load_model()
|
||||||
|
if self.dataset is None:
|
||||||
|
self.load_dataset()
|
||||||
|
|
||||||
|
name = self.explaining_cfg.explainer.name
|
||||||
|
if name in all__captum:
|
||||||
|
explaining_algorithm = CaptumWrapper(name)
|
||||||
|
elif name in all__graphxai:
|
||||||
|
explaining_algorithm = GraphXAIWrapper(
|
||||||
|
name,
|
||||||
|
in_channels=self.dataset.num_classes,
|
||||||
|
criterion=self.cfg.model.loss_fun,
|
||||||
|
)
|
||||||
|
elif name in all__own:
|
||||||
|
if name == "EIXGNN":
|
||||||
|
explaining_algorithm = EiXGNN(
|
||||||
|
L=self.explainer_cfg.L,
|
||||||
|
p=self.explainer_cfg.p,
|
||||||
|
importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
|
||||||
|
domain_similarity=self.explainer_cfg.domain_similarity,
|
||||||
|
signal_similarity=self.explainer_cfg.signal_similarity,
|
||||||
|
shap_val_approx=self.explainer_cfg.shapley_value_approx,
|
||||||
|
)
|
||||||
|
elif name == "SCGNN":
|
||||||
|
explaining_algorithm = SCGNN(
|
||||||
|
depth=self.explainer_cfg.depth,
|
||||||
|
interest_map_norm=self.explainer_cfg.interest_map_norm,
|
||||||
|
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||||
|
)
|
||||||
|
self.explaining_algorithm = explaining_algorithm
|
||||||
|
print(self.explaining_algorithm.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
PATH = "config_exp.yaml"
|
||||||
|
test = ExplainingOutline(explaining_cfg_path=PATH)
|
|
@ -23,78 +23,104 @@ def _load_ckpt(
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
r"""Loads the model at given checkpoint."""
|
r"""Loads the model at given checkpoint."""
|
||||||
|
|
||||||
if not osp.exists(path):
|
if not os.path.exists(ckpt_path):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ckpt = torch.load(ckpt_path)
|
ckpt = torch.load(ckpt_path)
|
||||||
model.load_state_dict(ckpt[MODEL_STATE])
|
model.load_state_dict(ckpt["state_dict"])
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
|
class LoadModelInfo(object):
|
||||||
acc = []
|
def __init__(
|
||||||
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
self, dataset_name: str, model_dir: str, which: str, wrt_metric: str = "val"
|
||||||
stats = json_to_dict_list(path)
|
):
|
||||||
for stat in stats:
|
self.dataset_name = dataset_name
|
||||||
acc.append(
|
self.model_dir = model_dir
|
||||||
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
self.which = which
|
||||||
|
self.wrt_metric = wrt_metric
|
||||||
|
self.info = None
|
||||||
|
|
||||||
|
def list_stats(self, path) -> list:
|
||||||
|
info = []
|
||||||
|
for path in glob.glob(
|
||||||
|
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
|
||||||
|
):
|
||||||
|
stats = json_to_dict_list(path)
|
||||||
|
for stat in stats:
|
||||||
|
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
|
||||||
|
ckpt_dir_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(path)), "ckpt"
|
||||||
|
)
|
||||||
|
cfg_path = os.path.join(xp_dir_path, "config.yaml")
|
||||||
|
epoch = stat["epoch"]
|
||||||
|
accuracy = stat["accuracy"]
|
||||||
|
loss = stat["loss"]
|
||||||
|
lr = stat["lr"]
|
||||||
|
params = stat["params"]
|
||||||
|
time_iter = stat["time_iter"]
|
||||||
|
ckpt_path = self.get_ckpt_path(epoch=epoch, ckpt_dir_path=ckpt_dir_path)
|
||||||
|
info.append(
|
||||||
|
{
|
||||||
|
"xp_dir_path": xp_dir_path,
|
||||||
|
"ckpt_path": self.get_ckpt_path(
|
||||||
|
epoch=epoch, ckpt_dir_path=ckpt_dir_path
|
||||||
|
),
|
||||||
|
"cfg_path": cfg_path,
|
||||||
|
"epoch": epoch,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"loss": loss,
|
||||||
|
"lr": lr,
|
||||||
|
"params": params,
|
||||||
|
"time_iter": time_iter,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
|
def list_xp(self):
|
||||||
|
paths = []
|
||||||
|
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
|
||||||
|
file = self.load_cfg(path)
|
||||||
|
dataset_name_ = file["dataset"]["name"]
|
||||||
|
if self.dataset_name == dataset_name_:
|
||||||
|
paths.append(os.path.dirname(path))
|
||||||
|
return paths
|
||||||
|
|
||||||
|
def load_cfg(self, config_path):
|
||||||
|
return read_yaml(config_path)
|
||||||
|
|
||||||
|
def set_info(self):
|
||||||
|
if self.which in ["best", "worst"] or isinstance(self.which, int):
|
||||||
|
paths = self.list_xp()
|
||||||
|
infos = []
|
||||||
|
for path in paths:
|
||||||
|
info = self.list_stats(path)
|
||||||
|
infos.extend(info)
|
||||||
|
infos = sorted(infos, key=lambda item: item["accuracy"])
|
||||||
|
if self.which == "best":
|
||||||
|
self.info = infos[-1]
|
||||||
|
elif self.which == "worst":
|
||||||
|
self.info = infos[0]
|
||||||
|
elif isinstance(self.which, int):
|
||||||
|
self.info = infos[self.which]
|
||||||
|
else:
|
||||||
|
specific_xp_dir_path = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.dirname(self.which))
|
||||||
)
|
)
|
||||||
return acc
|
stats = self.list_stats(specific_xp_dir_path)
|
||||||
|
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
|
||||||
|
return self.info
|
||||||
|
|
||||||
|
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
|
||||||
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
|
paths = os.path.join(ckpt_dir_path, "*.ckpt")
|
||||||
paths = []
|
ckpts = []
|
||||||
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
for path in glob.glob(paths):
|
||||||
file = read_yaml(path)
|
feat = os.path.basename(path)
|
||||||
dataset_name_ = file["dataset"]["name"]
|
feat = feat.split("-")
|
||||||
if dataset_name == dataset_name_:
|
for fe in feat:
|
||||||
paths.append(os.path.dirname(path))
|
if "epoch" in fe:
|
||||||
return paths
|
fe_ep = int(fe.split("=")[1])
|
||||||
|
ckpts.append({"path": path, "div": abs(fe_ep - epoch)})
|
||||||
|
ckpt = sorted(ckpts, key=lambda item: item["div"])[0]
|
||||||
def best_xp_ckpt(paths, which: str = "best"):
|
return ckpt["path"]
|
||||||
acc = []
|
|
||||||
for path in paths:
|
|
||||||
accuracies = xp_stats(path)
|
|
||||||
acc.extend(accuracies)
|
|
||||||
acc = sorted(acc, key=lambda item: item["accuracy"])
|
|
||||||
if which == "best":
|
|
||||||
return acc[-1]
|
|
||||||
elif which == "worst":
|
|
||||||
return acc[0]
|
|
||||||
|
|
||||||
|
|
||||||
def stats_to_ckpt(parse):
|
|
||||||
paths = os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
|
|
||||||
)
|
|
||||||
ckpts = []
|
|
||||||
for path in glob.glob(paths):
|
|
||||||
feat = os.path.basename(path)
|
|
||||||
feat = feat.split("-")
|
|
||||||
for fe in feat:
|
|
||||||
if "epoch" in fe:
|
|
||||||
fe_ep = int(fe.split("=")[1])
|
|
||||||
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
|
|
||||||
return sorted(ckpts, key=lambda item: item["div"])[0]
|
|
||||||
|
|
||||||
|
|
||||||
def stats_to_cfg(parse):
|
|
||||||
path = os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
|
|
||||||
)
|
|
||||||
path = os.path.join(path, "config.yaml")
|
|
||||||
if os.path.exists(path):
|
|
||||||
return path
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"{path} does not exists")
|
|
||||||
|
|
||||||
|
|
||||||
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
|
|
||||||
best = best_xp_ckpt(
|
|
||||||
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
|
|
||||||
)
|
|
||||||
print(best)
|
|
||||||
print(stats_to_ckpt(best))
|
|
||||||
print(stats_to_cfg(best))
|
|
||||||
|
|
Loading…
Reference in New Issue