Adding GPU selection features

This commit is contained in:
araison 2023-01-09 12:33:44 +01:00
parent 644c90a41c
commit 7afa6c53f6
3 changed files with 72 additions and 7 deletions

View File

@ -13,8 +13,9 @@ def parse_args() -> argparse.Namespace:
help="The explaining configuration file path.",
)
parser.add_argument(
"--mark_done",
action="store_true",
help="Mark yaml as done after a job has finished.",
"--gpu_id",
type=int,
help="If GPUs available, on which GPU to run experiment",
default=0,
)
return parser.parse_args()

View File

@ -122,7 +122,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str):
def __init__(self, explaining_cfg_path: str, gpu_id: int):
self.gpu_id = gpu_id
self.explaining_cfg_path = explaining_cfg_path
self.explaining_cfg = None
self.explainer_cfg_path = None
@ -167,7 +168,7 @@ class ExplainingOutline(object):
seed_everything(self.explaining_cfg.seed)
def load_model_to_hardware(self):
auto_select_device()
auto_select_device(gpu_id=self.gpu_id)
device = self.cfg.accelerator
self.model = self.model.to(device)
@ -745,13 +746,76 @@ class ExplainingOutline(object):
logging.info("Setting up experiment")
logging.info("Date and Time: %s", now)
logging.info("Save experiment to %s", self.out_dir)
logging.info(
"####################################################################################"
)
logging.info(
"############################### ORIGINAL CONFIG FILE ###############################"
)
logging.info(
"####################################################################################"
)
logging.info(self.cfg)
logging.info(
"####################################################################################"
)
logging.info(
"############################### EXPLAINING CONFIG FILE ###############################"
)
logging.info(
"####################################################################################"
)
logging.info(self.explaining_cfg)
logging.info(
"####################################################################################"
)
logging.info(
"############################### EXPLAINER CONFIG FILE ###############################"
)
logging.info(
"####################################################################################"
)
logging.info(self.explainer_cfg)
logging.info(
"####################################################################################"
)
logging.info(
"############################### GNN ARCHITECTURE ###############################"
)
logging.info(
"####################################################################################"
)
logging.info(self.model)
logging.info(
"####################################################################################"
)
logging.info(
"############################### GNN METRICS ###############################"
)
logging.info(
"####################################################################################"
)
logging.info(obj_config_to_log(self.model_info))
logging.info(
"####################################################################################"
)
logging.info(
"############################### METRICS (EXPLANATION PROCESS) ###############################"
)
logging.info(
"####################################################################################"
)
for metric in self.metrics + self.attacks:
logging.info(obj_config_to_str(metric))
for threshold_conf in self.thresholds_configs:
logging.info(obj_config_to_str(threshold_conf))
logging.info("Proceeding to explanations..")
logging.info(
"####################################################################################"
)
logging.info(
"############################### RUNNING EXPLANATIONS ... ###############################"
)
logging.info(
"####################################################################################"
)

View File

@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file)
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))