From dbf34d1679605830483e0bfda3507af73069d89f Mon Sep 17 00:00:00 2001 From: araison Date: Mon, 9 Jan 2023 19:49:26 +0100 Subject: [PATCH] Raising exception when any model exists --- .../utils/explaining/load_ckpt.py | 11 +++++-- .../utils/explaining/outline.py | 30 ++++++++++--------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py index 25e00d9..3975819 100644 --- a/explaining_framework/utils/explaining/load_ckpt.py +++ b/explaining_framework/utils/explaining/load_ckpt.py @@ -7,12 +7,11 @@ import logging import os import torch +from explaining_framework.utils.io import read_yaml from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.utils.io import json_to_dict_list -from explaining_framework.utils.io import read_yaml - MODEL_STATE = "model_state" OPTIMIZER_STATE = "optimizer_state" SCHEDULER_STATE = "scheduler_state" @@ -91,7 +90,13 @@ class LoadModelInfo(object): dataset_name_ = file["dataset"]["name"] if self.dataset_name == dataset_name_: paths.append(os.path.dirname(path)) - return paths + if len(paths) == 0: + logging.warning( + f"It does not exist any model trained for the dataset {self.dataset_name}" + ) + return None + else: + return paths def load_cfg(self, config_path): return read_yaml(config_path) diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 0390ca6..f5af4c4 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -6,20 +6,6 @@ import os from typing import Any from eixgnn.eixgnn import EiXGNN -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.data.makedirs import makedirs -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.explain.explanation import Explanation -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader -from yacs.config import CfgNode as CN - from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, obj_config_to_str, read_json, set_printing, write_json, write_yaml) +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.data.makedirs import makedirs +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader +from yacs.config import CfgNode as CN all__captum = [ "LRP", @@ -211,6 +210,9 @@ class ExplainingOutline(object): model_dir=self.explaining_cfg.model.path, which=self.explaining_cfg.model.ckpt, ) + if info.list_xp() is None: + raise ValueError + self.model_info = info.set_info() self.model_signature = info.get_model_signature()