New features

This commit is contained in:
araison 2022-12-26 15:01:18 +01:00
parent 5224207466
commit 9397934825
8 changed files with 72 additions and 77 deletions

0
__init__.py Normal file
View File

View File

@ -57,6 +57,8 @@ def set_cfg(explaining_cfg):
explaining_cfg.dataset.name = "Cora"
explaining_cfg.dataset.specific_items = None
explaining_cfg.run_topological_stat = True
# ----------------------------------------------------------------------- #
@ -80,6 +82,9 @@ def set_cfg(explaining_cfg):
# Whether or not to provide specific explaining methods configuration or default configuration
explaining_cfg.explainer.cfg = "default"
# Whether or not to recompute explanations if they already exist
explaining_cfg.explainer.force = False
# ----------------------------------------------------------------------- #
# Explaining options
# ----------------------------------------------------------------------- #
@ -98,27 +103,26 @@ def set_cfg(explaining_cfg):
# Do not modify it; we always assume here that model outputs are 'raw'
explaining_cfg.model_config.return_type = "raw"
# ----------------------------------------------------------------------- #
# Thresholding options
# ----------------------------------------------------------------------- #
explaining_cfg.threshold_config = CN()
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = 0.5
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
# Set print destination: stdout / file / both
explaining_cfg.print = "both"
explaining_cfg.threshold_config.relu_and_normalize = True
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"
# Config name (in out_dir)
explaining_cfg.explaining_cfg_dest = "config.yaml"
explaining_cfg.seed = 0
explaining_cfg.dataset.dir = "./datasets"
explaining_cfg.relu_and_normalize = True
# Which objective metrics to compute: either all of them, or one in particular if implemented
explaining_cfg.metrics = "all"
# Whether or not to recompute metrics if they already exist
explaining_cfg.metrics.force = False
def assert_cfg(explaining_cfg):
@ -185,8 +189,6 @@ def set_out_dir(out_dir, fname):
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.out_dir)
def set_run_dir(out_dir):
@ -200,8 +202,6 @@ def set_run_dir(out_dir):
# Make output directory
if explaining_cfg.train.auto_resume:
os.makedirs(explaining_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(explaining_cfg.run_dir)
set_cfg(explaining_cfg)

View File

@ -0,0 +1,27 @@
import argparse
def parse_args() -> argparse.Namespace:
    r"""Parses the command line arguments.

    Returns:
        The populated namespace with ``cfg_file``, ``explaining_cfg_file``
        and ``mark_done`` attributes.
    """
    parser = argparse.ArgumentParser(description="GraphGym")
    # The two mandatory string options share every setting except the flag,
    # destination and help text, so register them from a small table.
    required_options = (
        ("--cfg", "cfg_file", "The configuration file path."),
        (
            "--explaining_cfg",
            "explaining_cfg_file",
            "The explaining configuration file path.",
        ),
    )
    for flag, dest, help_text in required_options:
        parser.add_argument(flag, dest=dest, type=str, required=True, help=help_text)
    parser.add_argument(
        "--mark_done",
        action="store_true",
        help="Mark yaml as done after a job has finished.",
    )
    return parser.parse_args()

View File

@ -0,0 +1,5 @@
class Explaining(object):
    """Bundle of the three configuration objects driving one explaining run.

    The configurations are stored verbatim — no copying or validation is
    performed here.
    """

    def __init__(self, cfg: dict, explaining_cfg: dict, explainer_cfg: dict = None):
        # Keep the references exactly as given; explainer_cfg may stay None.
        self.cfg, self.explaining_cfg, self.explainer_cfg = (
            cfg,
            explaining_cfg,
            explainer_cfg,
        )

View File

@ -1,62 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import json
import logging
import os
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
# Keys under which the training-state components are stored inside a
# GraphGym checkpoint dict (see load_ckpt below, which reads MODEL_STATE).
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"
def load_ckpt(
    model: torch.nn.Module,
    ckpt_path: str,
) -> torch.nn.Module:
    r"""Loads the model at the given checkpoint.

    Args:
        model: Module whose parameters are overwritten in place.
        ckpt_path: Path to a checkpoint file containing a ``MODEL_STATE`` key.

    Returns:
        The model with the checkpoint's state dict loaded, or ``None``
        when the checkpoint file does not exist.
    """
    # FIX: the original tested `osp.exists(path)` — `osp` was never imported
    # and `path` is not a parameter, so this line always raised NameError.
    if not os.path.exists(ckpt_path):
        return None
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt[MODEL_STATE])
    return model
# Hard-coded experiment locations used by the ad-hoc driver call at the
# bottom of this file.
# NOTE(review): machine-specific absolute path — presumably intended for
# local debugging only; confirm before reusing elsewhere.
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/test_cifar/"
FOLDER = "graph_classif_base-dataset=PCBA-l_mp=2-l_post_mp=3-dim_inner=64-layer_type=gatconv-graph_pooling=mean"
def xp_accuracies(path_to_xp: str, wrt_metric: str = "val") -> list:
    r"""Collects per-epoch accuracies for one experiment.

    Scans every ``<seed>/<wrt_metric>/stats.json`` file directly under
    ``path_to_xp`` (seed folders are single digits) and flattens the stats
    into one record per epoch.

    Args:
        path_to_xp: Root directory of the experiment.
        wrt_metric: Split sub-folder to read the stats from (e.g. "val").

    Returns:
        A list of ``{"path", "epoch", "accuracy"}`` dicts (may be empty).
    """
    # FIX: the original annotation claimed `-> str` although a list of
    # dicts is returned.
    acc = []
    for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
        stats = json_to_dict_list(path)
        acc.extend(
            {"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
            for stat in stats
        )
    return acc
    # return sorted(acc, key=lambda item: item["accuracy"])
def best_ckpt_path(dataset_name: str, models_dir_path) -> list:
    r"""Returns the run directories that were trained on ``dataset_name``.

    Walks every ``config.yaml`` below ``models_dir_path`` and keeps the
    parent directory of each config whose ``dataset.name`` entry matches.

    Args:
        dataset_name: Dataset name to match against each run's config.
        models_dir_path: Root directory containing the run sub-directories.

    Returns:
        A list of matching run-directory paths (may be empty).
    """
    # FIXES: the original annotation claimed `-> str` although a list is
    # returned, and the `**` pattern silently behaved like `*` because
    # glob.glob was called without recursive=True.
    paths = []
    pattern = os.path.join(models_dir_path, "**", "config.yaml")
    for path in glob.glob(pattern, recursive=True):
        file = read_yaml(path)
        if dataset_name == file["dataset"]["name"]:
            paths.append(os.path.dirname(path))
    return paths
# NOTE(review): `load_best_given_exp` is not defined anywhere in this file,
# so importing/executing this module raises NameError here. It looks like a
# stale call left over from a rename (best_ckpt_path? xp_accuracies?) —
# confirm the intended entry point before re-enabling.
print(load_best_given_exp(PATH))

View File

@ -0,0 +1,25 @@
import json
import yaml
def read_json(path: str) -> dict:
    """Load and return the JSON document stored at ``path``."""
    with open(path, "r") as handle:
        return json.load(handle)
def write_json(data: dict, path: str) -> None:
    """Serialize ``data`` as JSON into the file at ``path``.

    Overwrites any existing file; returns nothing.
    """
    # FIX: the original rebound `data` to json.dump's return value, which is
    # always None — a pointless assignment, now dropped.
    with open(path, "w") as f:
        json.dump(data, f)
def read_yaml(path: str) -> dict:
    """Parse the YAML document stored at ``path`` and return it.

    Uses ``yaml.safe_load``, so arbitrary Python objects are never
    constructed from the file.
    """
    with open(path, "r") as handle:
        return yaml.safe_load(handle)
def write_yaml(data: dict, path: str) -> None:
    """Serialize ``data`` as YAML into the file at ``path``.

    Overwrites any existing file; returns nothing.
    """
    # FIX: the original rebound `data` to yaml.dump's return value, which is
    # None when a stream is supplied — a pointless assignment, now dropped.
    with open(path, "w") as f:
        yaml.dump(data, f)

0
main.py Normal file
View File