New features
This commit is contained in:
parent
5224207466
commit
9397934825
|
@ -57,6 +57,8 @@ def set_cfg(explaining_cfg):
|
|||
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
explaining_cfg.dataset.specific_items = None
|
||||
|
||||
explaining_cfg.run_topological_stat = True
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
@ -80,6 +82,9 @@ def set_cfg(explaining_cfg):
|
|||
# Whether or not to provide specific explaining methods configuration or default configuration
|
||||
explaining_cfg.explainer.cfg = "default"
|
||||
|
||||
# Whether or not recomputing explanation if they already exist
|
||||
explaining_cfg.explainer.force = False
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Explaining options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
@ -98,27 +103,26 @@ def set_cfg(explaining_cfg):
|
|||
# Do not modify it, we always assume here that model output are 'raw'
|
||||
explaining_cfg.model_config.return_type = "raw"
|
||||
|
||||
# ----------------------------------------------------------------------- #
|
||||
# Thresholding options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.threshold_config = CN()
|
||||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
|
||||
explaining_cfg.threshold_config.value = 0.5
|
||||
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
|
||||
|
||||
# Set print destination: stdout / file / both
|
||||
explaining_cfg.print = "both"
|
||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# Config name (in out_dir)
|
||||
explaining_cfg.explaining_cfg_dest = "config.yaml"
|
||||
|
||||
explaining_cfg.seed = 0
|
||||
|
||||
explaining_cfg.dataset.dir = "./datasets"
|
||||
|
||||
explaining_cfg.relu_and_normalize = True
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
explaining_cfg.metrics = "all"
|
||||
|
||||
# Whether or not recomputing metrics if they already exist
|
||||
explaining_cfg.metrics.force = False
|
||||
|
||||
|
||||
def assert_cfg(explaining_cfg):
|
||||
|
@ -185,8 +189,6 @@ def set_out_dir(out_dir, fname):
|
|||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
|
@ -200,8 +202,6 @@ def set_run_dir(out_dir):
|
|||
# Make output directory
|
||||
if explaining_cfg.train.auto_resume:
|
||||
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(explaining_cfg.run_dir)
|
||||
|
||||
|
||||
set_cfg(explaining_cfg)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
import argparse
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
r"""Parses the command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="GraphGym")
|
||||
|
||||
parser.add_argument(
|
||||
"--cfg",
|
||||
dest="cfg_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--explaining_cfg",
|
||||
dest="explaining_cfg_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The explaining configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mark_done",
|
||||
action="store_true",
|
||||
help="Mark yaml as done after a job has finished.",
|
||||
)
|
||||
return parser.parse_args()
|
|
@ -0,0 +1,5 @@
|
|||
class Explaining(object):
|
||||
def __init__(self, cfg: dict, explaining_cfg: dict, explainer_cfg: dict = None):
|
||||
self.cfg = cfg
|
||||
self.explaining_cfg = explaining_cfg
|
||||
self.explainer_cfg = explainer_cfg
|
|
@ -1,62 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
||||
|
||||
def load_ckpt(
|
||||
model: torch.nn.Module,
|
||||
ckpt_path: str,
|
||||
) -> torch.nn.Module:
|
||||
r"""Loads the model at given checkpoint."""
|
||||
|
||||
if not osp.exists(path):
|
||||
return None
|
||||
|
||||
ckpt = torch.load(ckpt_path)
|
||||
model.load_state_dict(ckpt[MODEL_STATE])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/"
|
||||
|
||||
FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean"
|
||||
|
||||
|
||||
def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||
acc = []
|
||||
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||
stats = json_to_dict_list(path)
|
||||
for stat in stats:
|
||||
acc.append(
|
||||
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||
)
|
||||
return acc
|
||||
# return sorted(acc, key=lambda item: item["accuracy"])
|
||||
|
||||
|
||||
def best_ckpt_path(dataset_name: str, models_dir_path) -> str:
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||
file = read_yaml(path)
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
return paths
|
||||
|
||||
|
||||
print(load_best_given_exp(PATH))
|
|
@ -0,0 +1,25 @@
|
|||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def read_json(path: str) -> dict:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def write_json(data: dict, path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
data = json.dump(data, f)
|
||||
|
||||
|
||||
def read_yaml(path: str) -> dict:
|
||||
with open(path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data
|
||||
|
||||
|
||||
def write_yaml(data: dict, path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
data = yaml.dump(data, f)
|
Loading…
Reference in New Issue