Adding Force FastForward method

This commit is contained in:
araison 2023-01-31 10:19:17 +01:00
parent 56a62df848
commit 148e54717a
4 changed files with 17 additions and 4 deletions

View File

@ -48,7 +48,7 @@ def explaining_conf(
explaining_cfg["explainer"] = {} explaining_cfg["explainer"] = {}
explaining_cfg["explainer"]["cfg"] = explainer_config explaining_cfg["explainer"]["cfg"] = explainer_config
explaining_cfg["explainer"]["name"] = explainer explaining_cfg["explainer"]["name"] = explainer
explaining_cfg["explainer"]["force"] = True explaining_cfg["explainer"]["force"] = False
explaining_cfg["explanation_type"] = "phenomenon" explaining_cfg["explanation_type"] = "phenomenon"
explaining_cfg["model"] = {} explaining_cfg["model"] = {}
explaining_cfg["model"]["ckpt"] = model_kind explaining_cfg["model"]["ckpt"] = model_kind

View File

@ -17,5 +17,11 @@ def parse_args() -> argparse.Namespace:
type=int, type=int,
help="GPU ID if cuda available", help="GPU ID if cuda available",
) )
parser.add_argument(
"--force_fastforward",
type=bool,
help="It does not load file for elaredy existing files",
default=False,
)
return parser.parse_args() return parser.parse_args()

View File

@ -38,6 +38,9 @@ if __name__ == "__main__":
) )
makedirs(attack_path) makedirs(attack_path)
data_attack_path = os.path.join(attack_path, f"{index}.json") data_attack_path = os.path.join(attack_path, f"{index}.json")
if args.force_fastforward:
if os.path.exists(data_attack_path):
continue
data_attack = outline.get_attack( data_attack = outline.get_attack(
attack=attack, item=item, path=data_attack_path attack=attack, item=item, path=data_attack_path
) )
@ -66,6 +69,9 @@ if __name__ == "__main__":
) )
makedirs(attack_path_) makedirs(attack_path_)
data_attack_path_ = os.path.join(attack_path_, f"{index}.json") data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
if args.force_fastforward:
if os.path.exists(data_attack_path_):
continue
attack_data = outline.get_attack( attack_data = outline.get_attack(
attack=attack, item=item, path=data_attack_path_ attack=attack, item=item, path=data_attack_path_
) )

View File

@ -1,8 +1,9 @@
CONFIG_DIR=$1 CONFIG_DIR=$1
MAX_JOBS=${2:-3} MAX_JOBS=${2:-3}
GPU=${3:-0} GPU=${3:-0}
SLEEP=${4:-1} FFF=${4:-False}
MAIN=${5:-main} SLEEP=${5:-1}
MAIN=${6:-main}
( (
trap 'kill 0' SIGINT trap 'kill 0' SIGINT
@ -11,7 +12,7 @@ MAIN=${5:-main}
if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then if [ "$CONFIG" != "$CONFIG_DIR/*.yaml" ]; then
((CUR_JOBS >= MAX_JOBS)) && wait -n ((CUR_JOBS >= MAX_JOBS)) && wait -n
export CUDA_VISIBLE_DEVICES=$GPU export CUDA_VISIBLE_DEVICES=$GPU
python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU & python3 $MAIN.py --explaining_cfg $CONFIG --gpu_id $GPU --force-fastforward $FFF &
echo $CONFIG echo $CONFIG
sleep $SLEEP sleep $SLEEP
((++CUR_JOBS)) ((++CUR_JOBS))