Fixing
This commit is contained in:
parent
d0faea89db
commit
02d994b68b
|
@ -12,4 +12,10 @@ def parse_args() -> argparse.Namespace:
|
||||||
required=True,
|
required=True,
|
||||||
help="The explaining configuration file path.",
|
help="The explaining configuration file path.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu_id",
|
||||||
|
type=int,
|
||||||
|
help="GPU ID if cuda available",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
|
@ -6,20 +6,6 @@ import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric import seed_everything
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.data.makedirs import makedirs
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
|
||||||
from torch_geometric.explain.explanation import Explanation
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
from yacs.config import CfgNode as CN
|
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -45,6 +31,19 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
obj_config_to_str, read_json,
|
obj_config_to_str, read_json,
|
||||||
set_printing, write_json,
|
set_printing, write_json,
|
||||||
write_yaml)
|
write_yaml)
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -122,7 +121,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
||||||
|
|
||||||
|
|
||||||
class ExplainingOutline(object):
|
class ExplainingOutline(object):
|
||||||
def __init__(self, explaining_cfg_path: str):
|
def __init__(self, explaining_cfg_path: str, gpu_id: int = 0):
|
||||||
|
self.gpu_id = gpu_id
|
||||||
self.explaining_cfg_path = explaining_cfg_path
|
self.explaining_cfg_path = explaining_cfg_path
|
||||||
self.explaining_cfg = None
|
self.explaining_cfg = None
|
||||||
self.explainer_cfg_path = None
|
self.explainer_cfg_path = None
|
||||||
|
@ -254,7 +254,7 @@ class ExplainingOutline(object):
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
auto_select_device()
|
auto_select_device(gpu_id=self.gpu_id)
|
||||||
self.model = create_model()
|
self.model = create_model()
|
||||||
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
|
|
2
main.py
2
main.py
|
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||||
|
|
||||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
CONFIG_DIR=$1
|
CONFIG_DIR=$1
|
||||||
MAX_JOBS=${3:-3}
|
MAX_JOBS=${2:-3}
|
||||||
|
GPU=${3:-0}
|
||||||
SLEEP=${4:-1}
|
SLEEP=${4:-1}
|
||||||
MAIN=${5:-main}
|
MAIN=${5:-main}
|
||||||
GPU=${6:-0}
|
|
||||||
|
|
||||||
(
|
(
|
||||||
trap 'kill 0' SIGINT
|
trap 'kill 0' SIGINT
|
||||||
|
@ -11,7 +11,7 @@ GPU=${6:-0}
|
||||||
if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
|
if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
|
||||||
((CUR_JOBS >= MAX_JOBS)) && wait -n
|
((CUR_JOBS >= MAX_JOBS)) && wait -n
|
||||||
export CUDA_VISIBLE_DEVICES=$GPU
|
export CUDA_VISIBLE_DEVICES=$GPU
|
||||||
python3 $MAIN.py --explaining_cfg $CONFIG &
|
python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU &
|
||||||
echo $CONFIG
|
echo $CONFIG
|
||||||
sleep $SLEEP
|
sleep $SLEEP
|
||||||
((++CUR_JOBS))
|
((++CUR_JOBS))
|
||||||
|
|
Loading…
Reference in New Issue