Raising exception when any model exists

This commit is contained in:
araison 2023-01-09 19:49:26 +01:00
parent 7da28de955
commit dbf34d1679
2 changed files with 24 additions and 17 deletions

View File

@ -7,12 +7,11 @@ import logging
import os import os
import torch import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state" MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state" OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state" SCHEDULER_STATE = "scheduler_state"
@ -91,7 +90,13 @@ class LoadModelInfo(object):
dataset_name_ = file["dataset"]["name"] dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_: if self.dataset_name == dataset_name_:
paths.append(os.path.dirname(path)) paths.append(os.path.dirname(path))
return paths if len(paths) == 0:
logging.warning(
f"It does not exist any model trained for the dataset {self.dataset_name}"
)
return None
else:
return paths
def load_cfg(self, config_path): def load_cfg(self, config_path):
return read_yaml(config_path) return read_yaml(config_path)

View File

@ -6,20 +6,6 @@ import os
from typing import Any from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from yacs.config import CfgNode as CN
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
obj_config_to_str, read_json, obj_config_to_str, read_json,
set_printing, write_json, set_printing, write_json,
write_yaml) write_yaml)
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from yacs.config import CfgNode as CN
all__captum = [ all__captum = [
"LRP", "LRP",
@ -211,6 +210,9 @@ class ExplainingOutline(object):
model_dir=self.explaining_cfg.model.path, model_dir=self.explaining_cfg.model.path,
which=self.explaining_cfg.model.ckpt, which=self.explaining_cfg.model.ckpt,
) )
if info.list_xp() is None:
raise ValueError
self.model_info = info.set_info() self.model_info = info.set_info()
self.model_signature = info.get_model_signature() self.model_signature = info.get_model_signature()