Removing GPU selection features

This commit is contained in:
araison 2023-01-09 13:08:42 +01:00
parent 7afa6c53f6
commit c010e08e9b
4 changed files with 24 additions and 32 deletions

View File

@ -12,10 +12,4 @@ def parse_args() -> argparse.Namespace:
required=True, required=True,
help="The explaining configuration file path.", help="The explaining configuration file path.",
) )
parser.add_argument(
"--gpu_id",
type=int,
help="If GPUs available, on which GPU to run experiment",
default=0,
)
return parser.parse_args() return parser.parse_args()

View File

@ -122,8 +122,7 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
class ExplainingOutline(object): class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str, gpu_id: int): def __init__(self, explaining_cfg_path: str):
self.gpu_id = gpu_id
self.explaining_cfg_path = explaining_cfg_path self.explaining_cfg_path = explaining_cfg_path
self.explaining_cfg = None self.explaining_cfg = None
self.explainer_cfg_path = None self.explainer_cfg_path = None
@ -168,7 +167,7 @@ class ExplainingOutline(object):
seed_everything(self.explaining_cfg.seed) seed_everything(self.explaining_cfg.seed)
def load_model_to_hardware(self): def load_model_to_hardware(self):
auto_select_device(gpu_id=self.gpu_id) auto_select_device()
device = self.cfg.accelerator device = self.cfg.accelerator
self.model = self.model.to(device) self.model = self.model.to(device)
@ -567,18 +566,18 @@ class ExplainingOutline(object):
explanation = _get_explanation(self.explainer, item) explanation = _get_explanation(self.explainer, item)
if explanation is None: if explanation is None:
logging.warning( logging.warning(
" EXP || Generated; Path %s; FAILED", " EXP::Generated; Path %s; FAILED",
(path), (path),
) )
else: else:
logging.debug( logging.debug(
"EXP || Generated; Path %s; SUCCEEDED", "EXP::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
explanation = _load_explanation(path) explanation = _load_explanation(path)
logging.debug( logging.debug(
"EXP || Loaded; Path %s; SUCCEEDED", "EXP::Loaded; Path %s; SUCCEEDED",
(path), (path),
) )
explanation = explanation.to(self.cfg.accelerator) explanation = explanation.to(self.cfg.accelerator)
@ -587,7 +586,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, explanation) get_pred(self.explainer, explanation)
_save_explanation(explanation, path) _save_explanation(explanation, path)
logging.debug( logging.debug(
"EXP || Generated; Path %s; SUCCEEDED", "EXP::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
@ -598,14 +597,14 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item) exp_adjust = adjust.forward(item)
logging.debug( logging.debug(
"ADJUST || Generated; Path %s; SUCCEEDED", "ADJUST::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
exp_adjust = _load_explanation(path) exp_adjust = _load_explanation(path)
logging.debug( logging.debug(
"ADJUST || Loaded; Path %s; SUCCEEDED", "ADJUST::Loaded; Path %s; SUCCEEDED",
(path), (path),
) )
@ -614,7 +613,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_adjust) get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path) _save_explanation(exp_adjust, path)
logging.debug( logging.debug(
"ADJUST || Generated; Path %s; SUCCEEDED", "ADJUST::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
return exp_adjust return exp_adjust
@ -624,13 +623,13 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item) exp_threshold = self.explainer._post_process(item)
logging.debug( logging.debug(
"THRESHOLD || Generated; Path %s; SUCCEEDED", "THRESHOLD::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
exp_threshold = _load_explanation(path) exp_threshold = _load_explanation(path)
logging.debug( logging.debug(
"THRESHOLD || Loaded; Path %s; SUCCEEDED", "THRESHOLD::Loaded; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
@ -638,12 +637,12 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_threshold) get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path) _save_explanation(exp_threshold, path)
logging.debug( logging.debug(
"THRESHOLD || Generated; Path %s; SUCCEEDED", "THRESHOLD::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
if is_empty_graph(exp_threshold): if is_empty_graph(exp_threshold):
logging.warning( logging.warning(
"THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED", "THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
(path), (path),
) )
return None return None
@ -654,13 +653,13 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item) out_metric = metric.forward(item)
logging.debug( logging.debug(
"METRIC || Generated; Path %s; SUCCEEDED", "METRIC::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
out_metric = read_json(path) out_metric = read_json(path)
logging.debug( logging.debug(
"METRIC || Loaded; Path %s; SUCCEEDED", "METRIC::Loaded; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
@ -669,12 +668,12 @@ class ExplainingOutline(object):
write_json(data, path) write_json(data, path)
if out_metric is None: if out_metric is None:
logging.debug( logging.debug(
"METRIC || Generated; Path %s; FAILED", "METRIC::Generated; Path %s; FAILED",
(path), (path),
) )
else: else:
logging.debug( logging.debug(
"METRIC || Generated; Path %s; SUCCEEDED", "METRIC::Generated; Path %s; SUCCEEDED",
(path), (path),
) )
return out_metric return out_metric
@ -694,21 +693,21 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
data_attack = attack.get_attacked_prediction(item) data_attack = attack.get_attacked_prediction(item)
logging.debug( logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated %s; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
data_attack = _load_explanation(path) data_attack = _load_explanation(path)
logging.debug( logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated %s; Path %s; SUCCEEDED",
(path), (path),
) )
else: else:
data_attack = attack.get_attacked_prediction(item) data_attack = attack.get_attacked_prediction(item)
_save_explanation(data_attack, path) _save_explanation(data_attack, path)
logging.debug( logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated %s; Path %s; SUCCEEDED",
(path), (path),
) )
return data_attack return data_attack

View File

@ -4,7 +4,6 @@ import os
import sys import sys
import yaml import yaml
from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.config.explaining_config import explaining_cfg
@ -78,11 +77,11 @@ def obj_config_to_log(obj) -> str:
if isinstance(obj, dict): if isinstance(obj, dict):
config = get_dict_config(obj) config = get_dict_config(obj)
for k, v in config.items(): for k, v in config.items():
logging.info(f"{k} : {v}") logging.info(f"{k}={v}")
else: else:
config = get_dict_config(obj.__dict__) config = get_dict_config(obj.__dict__)
for k, v in config.items(): for k, v in config.items():
logging.info(f"{k} : {v}") logging.info(f"{k}={v}")
def set_printing(logger_path): def set_printing(logger_path):
@ -93,7 +92,7 @@ def set_printing(logger_path):
logging.root.handlers = [] logging.root.handlers = []
logging_cfg = { logging_cfg = {
"level": logging.INFO, "level": logging.INFO,
"format": "%(asctime)s:%(levelname)s:%(message)s", "format": "%(asctime)s::%(levelname)s::%(message)s",
} }
h_file = logging.FileHandler(logger_path) h_file = logging.FileHandler(logger_path)
h_stdout = logging.StreamHandler(sys.stdout) h_stdout = logging.StreamHandler(sys.stdout)

View File

@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) outline = ExplainingOutline(args.explaining_cfg_file)
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))