Raising exception when any model exists
This commit is contained in:
parent
7da28de955
commit
dbf34d1679
|
@ -7,12 +7,11 @@ import logging
|
|||
import os
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
|
@ -91,6 +90,12 @@ class LoadModelInfo(object):
|
|||
dataset_name_ = file["dataset"]["name"]
|
||||
if self.dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
if len(paths) == 0:
|
||||
logging.warning(
|
||||
f"It does not exist any model trained for the dataset {self.dataset_name}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
return paths
|
||||
|
||||
def load_cfg(self, config_path):
|
||||
|
|
|
@ -6,20 +6,6 @@ import os
|
|||
from typing import Any
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
|
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
|||
obj_config_to_str, read_json,
|
||||
set_printing, write_json,
|
||||
write_yaml)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
|
@ -211,6 +210,9 @@ class ExplainingOutline(object):
|
|||
model_dir=self.explaining_cfg.model.path,
|
||||
which=self.explaining_cfg.model.ckpt,
|
||||
)
|
||||
if info.list_xp() is None:
|
||||
raise ValueError
|
||||
|
||||
self.model_info = info.set_info()
|
||||
self.model_signature = info.get_model_signature()
|
||||
|
||||
|
|
Loading…
Reference in New Issue