Fixing
This commit is contained in:
parent
d0faea89db
commit
02d994b68b
@ -12,4 +12,10 @@ def parse_args() -> argparse.Namespace:
|
||||
required=True,
|
||||
help="The explaining configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
help="GPU ID if cuda available",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
@ -6,20 +6,6 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||
obj_config_to_str, read_json,
|
||||
set_printing, write_json,
|
||||
write_yaml)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
@ -122,7 +121,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
||||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg_path: str):
|
||||
def __init__(self, explaining_cfg_path: str, gpu_id: int = 0):
|
||||
self.gpu_id = gpu_id
|
||||
self.explaining_cfg_path = explaining_cfg_path
|
||||
self.explaining_cfg = None
|
||||
self.explainer_cfg_path = None
|
||||
@ -254,7 +254,7 @@ class ExplainingOutline(object):
|
||||
def load_model(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
auto_select_device()
|
||||
auto_select_device(gpu_id=self.gpu_id)
|
||||
self.model = create_model()
|
||||
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
||||
if self.model is None:
|
||||
|
2
main.py
2
main.py
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
|
||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
CONFIG_DIR=$1
|
||||
MAX_JOBS=${3:-3}
|
||||
MAX_JOBS=${2:-3}
|
||||
GPU=${3:-0}
|
||||
SLEEP=${4:-1}
|
||||
MAIN=${5:-main}
|
||||
GPU=${6:-0}
|
||||
|
||||
(
|
||||
trap 'kill 0' SIGINT
|
||||
@ -11,7 +11,7 @@ GPU=${6:-0}
|
||||
if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
|
||||
((CUR_JOBS >= MAX_JOBS)) && wait -n
|
||||
export CUDA_VISIBLE_DEVICES=$GPU
|
||||
python3 $MAIN.py --explaining_cfg $CONFIG &
|
||||
python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU &
|
||||
echo $CONFIG
|
||||
sleep $SLEEP
|
||||
((++CUR_JOBS))
|
||||
|
Loading…
Reference in New Issue
Block a user