Raising exception when any model exists
This commit is contained in:
parent
7da28de955
commit
dbf34d1679
|
@ -7,12 +7,11 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
from explaining_framework.utils.io import read_yaml
|
|
||||||
|
|
||||||
MODEL_STATE = "model_state"
|
MODEL_STATE = "model_state"
|
||||||
OPTIMIZER_STATE = "optimizer_state"
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
SCHEDULER_STATE = "scheduler_state"
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
@ -91,7 +90,13 @@ class LoadModelInfo(object):
|
||||||
dataset_name_ = file["dataset"]["name"]
|
dataset_name_ = file["dataset"]["name"]
|
||||||
if self.dataset_name == dataset_name_:
|
if self.dataset_name == dataset_name_:
|
||||||
paths.append(os.path.dirname(path))
|
paths.append(os.path.dirname(path))
|
||||||
return paths
|
if len(paths) == 0:
|
||||||
|
logging.warning(
|
||||||
|
f"It does not exist any model trained for the dataset {self.dataset_name}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return paths
|
||||||
|
|
||||||
def load_cfg(self, config_path):
|
def load_cfg(self, config_path):
|
||||||
return read_yaml(config_path)
|
return read_yaml(config_path)
|
||||||
|
|
|
@ -6,20 +6,6 @@ import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric import seed_everything
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.data.makedirs import makedirs
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
|
||||||
from torch_geometric.explain.explanation import Explanation
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
from yacs.config import CfgNode as CN
|
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
obj_config_to_str, read_json,
|
obj_config_to_str, read_json,
|
||||||
set_printing, write_json,
|
set_printing, write_json,
|
||||||
write_yaml)
|
write_yaml)
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -211,6 +210,9 @@ class ExplainingOutline(object):
|
||||||
model_dir=self.explaining_cfg.model.path,
|
model_dir=self.explaining_cfg.model.path,
|
||||||
which=self.explaining_cfg.model.ckpt,
|
which=self.explaining_cfg.model.ckpt,
|
||||||
)
|
)
|
||||||
|
if info.list_xp() is None:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
self.model_info = info.set_info()
|
self.model_info = info.set_info()
|
||||||
self.model_signature = info.get_model_signature()
|
self.model_signature = info.get_model_signature()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue