Adding GPU selection features
This commit is contained in:
parent
644c90a41c
commit
7afa6c53f6
|
@ -13,8 +13,9 @@ def parse_args() -> argparse.Namespace:
|
||||||
help="The explaining configuration file path.",
|
help="The explaining configuration file path.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mark_done",
|
"--gpu_id",
|
||||||
action="store_true",
|
type=int,
|
||||||
help="Mark yaml as done after a job has finished.",
|
help="If GPUs available, on which GPU to run experiment",
|
||||||
|
default=0,
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
|
@ -122,7 +122,8 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
||||||
|
|
||||||
|
|
||||||
class ExplainingOutline(object):
|
class ExplainingOutline(object):
|
||||||
def __init__(self, explaining_cfg_path: str):
|
def __init__(self, explaining_cfg_path: str, gpu_id: int):
|
||||||
|
self.gpu_id = gpu_id
|
||||||
self.explaining_cfg_path = explaining_cfg_path
|
self.explaining_cfg_path = explaining_cfg_path
|
||||||
self.explaining_cfg = None
|
self.explaining_cfg = None
|
||||||
self.explainer_cfg_path = None
|
self.explainer_cfg_path = None
|
||||||
|
@ -167,7 +168,7 @@ class ExplainingOutline(object):
|
||||||
seed_everything(self.explaining_cfg.seed)
|
seed_everything(self.explaining_cfg.seed)
|
||||||
|
|
||||||
def load_model_to_hardware(self):
|
def load_model_to_hardware(self):
|
||||||
auto_select_device()
|
auto_select_device(gpu_id=self.gpu_id)
|
||||||
device = self.cfg.accelerator
|
device = self.cfg.accelerator
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
@ -745,13 +746,76 @@ class ExplainingOutline(object):
|
||||||
logging.info("Setting up experiment")
|
logging.info("Setting up experiment")
|
||||||
logging.info("Date and Time: %s", now)
|
logging.info("Date and Time: %s", now)
|
||||||
logging.info("Save experiment to %s", self.out_dir)
|
logging.info("Save experiment to %s", self.out_dir)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### ORIGINAL CONFIG FILE ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
logging.info(self.cfg)
|
logging.info(self.cfg)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### EXPLAINING CONFIG FILE ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
logging.info(self.explaining_cfg)
|
logging.info(self.explaining_cfg)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### EXPLAINER CONFIG FILE ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
logging.info(self.explainer_cfg)
|
logging.info(self.explainer_cfg)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### GNN ARCHITECTURE ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
logging.info(self.model)
|
logging.info(self.model)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### GNN METRICS ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
logging.info(obj_config_to_log(self.model_info))
|
logging.info(obj_config_to_log(self.model_info))
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### METRICS (EXPLANATION PROCESS) ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
for metric in self.metrics + self.attacks:
|
for metric in self.metrics + self.attacks:
|
||||||
logging.info(obj_config_to_str(metric))
|
logging.info(obj_config_to_str(metric))
|
||||||
for threshold_conf in self.thresholds_configs:
|
for threshold_conf in self.thresholds_configs:
|
||||||
logging.info(obj_config_to_str(threshold_conf))
|
logging.info(obj_config_to_str(threshold_conf))
|
||||||
logging.info("Proceeding to explanations..")
|
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"############################### RUNNING EXPLANATIONS ... ###############################"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"####################################################################################"
|
||||||
|
)
|
||||||
|
|
2
main.py
2
main.py
|
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||||
|
|
||||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue