Adding GPU selection features
This commit is contained in:
parent
644c90a41c
commit
7afa6c53f6
@ -13,8 +13,9 @@ def parse_args() -> argparse.Namespace:
|
||||
help="The explaining configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mark_done",
|
||||
action="store_true",
|
||||
help="Mark yaml as done after a job has finished.",
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
help="If GPUs available, on which GPU to run experiment",
|
||||
default=0,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
@ -122,7 +122,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
||||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg_path: str):
|
||||
def __init__(self, explaining_cfg_path: str, gpu_id: int):
|
||||
self.gpu_id = gpu_id
|
||||
self.explaining_cfg_path = explaining_cfg_path
|
||||
self.explaining_cfg = None
|
||||
self.explainer_cfg_path = None
|
||||
@ -167,7 +168,7 @@ class ExplainingOutline(object):
|
||||
seed_everything(self.explaining_cfg.seed)
|
||||
|
||||
def load_model_to_hardware(self):
|
||||
auto_select_device()
|
||||
auto_select_device(gpu_id=self.gpu_id)
|
||||
device = self.cfg.accelerator
|
||||
self.model = self.model.to(device)
|
||||
|
||||
@ -745,13 +746,76 @@ class ExplainingOutline(object):
|
||||
logging.info("Setting up experiment")
|
||||
logging.info("Date and Time: %s", now)
|
||||
logging.info("Save experiment to %s", self.out_dir)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### ORIGINAL CONFIG FILE ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(self.cfg)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### EXPLAINING CONFIG FILE ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(self.explaining_cfg)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### EXPLAINER CONFIG FILE ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(self.explainer_cfg)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### GNN ARCHITECTURE ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(self.model)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### GNN METRICS ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(obj_config_to_log(self.model_info))
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### METRICS (EXPLANATION PROCESS) ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
for metric in self.metrics + self.attacks:
|
||||
logging.info(obj_config_to_str(metric))
|
||||
for threshold_conf in self.thresholds_configs:
|
||||
logging.info(obj_config_to_str(threshold_conf))
|
||||
logging.info("Proceeding to explanations..")
|
||||
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
logging.info(
|
||||
"############################### RUNNING EXPLANATIONS ... ###############################"
|
||||
)
|
||||
logging.info(
|
||||
"####################################################################################"
|
||||
)
|
||||
|
2
main.py
2
main.py
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
|
||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user