diff --git a/explaining_framework/utils/explaining/cmd_args.py b/explaining_framework/utils/explaining/cmd_args.py index 388d231..b2631cd 100644 --- a/explaining_framework/utils/explaining/cmd_args.py +++ b/explaining_framework/utils/explaining/cmd_args.py @@ -13,8 +13,9 @@ def parse_args() -> argparse.Namespace: help="The explaining configuration file path.", ) parser.add_argument( - "--mark_done", - action="store_true", - help="Mark yaml as done after a job has finished.", + "--gpu_id", + type=int, + help="If GPUs available, on which GPU to run experiment", + default=0, ) return parser.parse_args() diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 1594dd2..0ad7a70 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -122,7 +122,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"] class ExplainingOutline(object): - def __init__(self, explaining_cfg_path: str): + def __init__(self, explaining_cfg_path: str, gpu_id: int): + self.gpu_id = gpu_id self.explaining_cfg_path = explaining_cfg_path self.explaining_cfg = None self.explainer_cfg_path = None @@ -167,7 +168,7 @@ class ExplainingOutline(object): seed_everything(self.explaining_cfg.seed) def load_model_to_hardware(self): - auto_select_device() + auto_select_device(gpu_id=self.gpu_id) device = self.cfg.accelerator self.model = self.model.to(device) @@ -745,13 +746,76 @@ class ExplainingOutline(object): logging.info("Setting up experiment") logging.info("Date and Time: %s", now) logging.info("Save experiment to %s", self.out_dir) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### ORIGINAL CONFIG FILE ###############################" + ) + logging.info( + "####################################################################################" + ) logging.info(self.cfg) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### EXPLAINING CONFIG FILE ###############################" + ) + logging.info( + "####################################################################################" + ) logging.info(self.explaining_cfg) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### EXPLAINER CONFIG FILE ###############################" + ) + logging.info( + "####################################################################################" + ) logging.info(self.explainer_cfg) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### GNN ARCHITECTURE ###############################" + ) + logging.info( + "####################################################################################" + ) logging.info(self.model) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### GNN METRICS ###############################" + ) + logging.info( + "####################################################################################" + ) logging.info(obj_config_to_log(self.model_info)) + logging.info( + "####################################################################################" + ) + logging.info( + "############################### METRICS (EXPLANATION PROCESS) ###############################" + ) + logging.info( + "####################################################################################" + ) for metric in self.metrics + self.attacks: logging.info(obj_config_to_str(metric)) for threshold_conf in self.thresholds_configs: logging.info(obj_config_to_str(threshold_conf)) - logging.info("Proceeding to explanations..") + + logging.info( + "####################################################################################" + ) + logging.info( + "############################### RUNNING EXPLANATIONS ... ###############################" + ) + logging.info( + "####################################################################################" + ) diff --git a/main.py b/main.py index bcdbda7..b361fe2 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, if __name__ == "__main__": args = parse_args() - outline = ExplainingOutline(args.explaining_cfg_file) + outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))