This commit is contained in:
araison 2023-01-10 11:27:54 +01:00
parent d0faea89db
commit 02d994b68b
4 changed files with 26 additions and 20 deletions

View File

@ -12,4 +12,10 @@ def parse_args() -> argparse.Namespace:
required=True,
help="The explaining configuration file path.",
)
parser.add_argument(
"--gpu_id",
type=int,
help="GPU ID if cuda available",
)
return parser.parse_args()

View File

@ -6,20 +6,6 @@ import os
from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from yacs.config import CfgNode as CN
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
obj_config_to_str, read_json,
set_printing, write_json,
write_yaml)
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from yacs.config import CfgNode as CN
all__captum = [
"LRP",
@ -122,7 +121,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str):
def __init__(self, explaining_cfg_path: str, gpu_id: int = 0):
self.gpu_id = gpu_id
self.explaining_cfg_path = explaining_cfg_path
self.explaining_cfg = None
self.explainer_cfg_path = None
@ -254,7 +254,7 @@ class ExplainingOutline(object):
def load_model(self):
if self.cfg is None:
self.load_cfg()
auto_select_device()
auto_select_device(gpu_id=self.gpu_id)
self.model = create_model()
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
if self.model is None:

View File

@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file)
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))

View File

@ -1,8 +1,8 @@
CONFIG_DIR=$1
MAX_JOBS=${3:-3}
MAX_JOBS=${2:-3}
GPU=${3:-0}
SLEEP=${4:-1}
MAIN=${5:-main}
GPU=${6:-0}
(
trap 'kill 0' SIGINT
@ -11,7 +11,7 @@ GPU=${6:-0}
if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
((CUR_JOBS >= MAX_JOBS)) && wait -n
export CUDA_VISIBLE_DEVICES=$GPU
python3 $MAIN.py --explaining_cfg $CONFIG &
python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU &
echo $CONFIG
sleep $SLEEP
((++CUR_JOBS))