From 02d994b68b5c14824ac63353f748cb5a42aca77a Mon Sep 17 00:00:00 2001 From: araison Date: Tue, 10 Jan 2023 11:27:54 +0100 Subject: [PATCH] Fixing --- .../utils/explaining/cmd_args.py | 6 ++++ .../utils/explaining/outline.py | 32 +++++++++---------- main.py | 2 +- parallel.sh | 6 ++-- 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/explaining_framework/utils/explaining/cmd_args.py b/explaining_framework/utils/explaining/cmd_args.py index 6c2afc7..2174a72 100644 --- a/explaining_framework/utils/explaining/cmd_args.py +++ b/explaining_framework/utils/explaining/cmd_args.py @@ -12,4 +12,10 @@ def parse_args() -> argparse.Namespace: required=True, help="The explaining configuration file path.", ) + parser.add_argument( + "--gpu_id", + type=int, + help="GPU ID if cuda available", + ) + return parser.parse_args() diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 967c524..10413fc 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -6,20 +6,6 @@ import os from typing import Any from eixgnn.eixgnn import EiXGNN -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.data.makedirs import makedirs -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.explain.explanation import Explanation -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader -from yacs.config import CfgNode as CN - from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, obj_config_to_str, read_json, set_printing, write_json, write_yaml) +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.data.makedirs import makedirs +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader +from yacs.config import CfgNode as CN all__captum = [ "LRP", @@ -122,7 +121,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"] class ExplainingOutline(object): - def __init__(self, explaining_cfg_path: str): + def __init__(self, explaining_cfg_path: str, gpu_id: int = 0): + self.gpu_id = gpu_id self.explaining_cfg_path = explaining_cfg_path self.explaining_cfg = None self.explainer_cfg_path = None @@ -254,7 +254,7 @@ class ExplainingOutline(object): def load_model(self): if self.cfg is None: self.load_cfg() - auto_select_device() + auto_select_device(gpu_id=self.gpu_id) self.model = create_model() self.model = _load_ckpt(self.model, self.model_info["ckpt_path"]) if self.model is None: diff --git a/main.py b/main.py index bcdbda7..b361fe2 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, if __name__ == "__main__": args = parse_args() - outline = ExplainingOutline(args.explaining_cfg_file) + outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) diff --git a/parallel.sh b/parallel.sh index 8083a20..6fc7a92 100644 --- a/parallel.sh +++ b/parallel.sh @@ -1,8 +1,8 @@ CONFIG_DIR=$1 -MAX_JOBS=${3:-3} +MAX_JOBS=${2:-3} +GPU=${3:-0} SLEEP=${4:-1} MAIN=${5:-main} -GPU=${6:-0} ( trap 'kill 0' SIGINT @@ -11,7 +11,7 @@ GPU=${6:-0} if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then ((CUR_JOBS >= MAX_JOBS)) && wait -n export CUDA_VISIBLE_DEVICES=$GPU - python3 $MAIN.py --explaining_cfg $CONFIG & + python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU & echo $CONFIG sleep $SLEEP ((++CUR_JOBS))