Removing GPU selection features

This commit is contained in:
araison 2023-01-09 13:08:42 +01:00
parent 7afa6c53f6
commit c010e08e9b
4 changed files with 24 additions and 32 deletions

View File

@ -12,10 +12,4 @@ def parse_args() -> argparse.Namespace:
required=True,
help="The explaining configuration file path.",
)
parser.add_argument(
"--gpu_id",
type=int,
help="If GPUs available, on which GPU to run experiment",
default=0,
)
return parser.parse_args()

View File

@ -122,8 +122,7 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str, gpu_id: int):
self.gpu_id = gpu_id
def __init__(self, explaining_cfg_path: str):
self.explaining_cfg_path = explaining_cfg_path
self.explaining_cfg = None
self.explainer_cfg_path = None
@ -168,7 +167,7 @@ class ExplainingOutline(object):
seed_everything(self.explaining_cfg.seed)
def load_model_to_hardware(self):
auto_select_device(gpu_id=self.gpu_id)
auto_select_device()
device = self.cfg.accelerator
self.model = self.model.to(device)
@ -567,18 +566,18 @@ class ExplainingOutline(object):
explanation = _get_explanation(self.explainer, item)
if explanation is None:
logging.warning(
" EXP || Generated; Path %s; FAILED",
" EXP::Generated; Path %s; FAILED",
(path),
)
else:
logging.debug(
"EXP || Generated; Path %s; SUCCEEDED",
"EXP::Generated; Path %s; SUCCEEDED",
(path),
)
else:
explanation = _load_explanation(path)
logging.debug(
"EXP || Loaded; Path %s; SUCCEEDED",
"EXP::Loaded; Path %s; SUCCEEDED",
(path),
)
explanation = explanation.to(self.cfg.accelerator)
@ -587,7 +586,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, explanation)
_save_explanation(explanation, path)
logging.debug(
"EXP || Generated; Path %s; SUCCEEDED",
"EXP::Generated; Path %s; SUCCEEDED",
(path),
)
@ -598,14 +597,14 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item)
logging.debug(
"ADJUST || Generated; Path %s; SUCCEEDED",
"ADJUST::Generated; Path %s; SUCCEEDED",
(path),
)
else:
exp_adjust = _load_explanation(path)
logging.debug(
"ADJUST || Loaded; Path %s; SUCCEEDED",
"ADJUST::Loaded; Path %s; SUCCEEDED",
(path),
)
@ -614,7 +613,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path)
logging.debug(
"ADJUST || Generated; Path %s; SUCCEEDED",
"ADJUST::Generated; Path %s; SUCCEEDED",
(path),
)
return exp_adjust
@ -624,13 +623,13 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item)
logging.debug(
"THRESHOLD || Generated; Path %s; SUCCEEDED",
"THRESHOLD::Generated; Path %s; SUCCEEDED",
(path),
)
else:
exp_threshold = _load_explanation(path)
logging.debug(
"THRESHOLD || Loaded; Path %s; SUCCEEDED",
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
(path),
)
else:
@ -638,12 +637,12 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path)
logging.debug(
"THRESHOLD || Generated; Path %s; SUCCEEDED",
"THRESHOLD::Generated; Path %s; SUCCEEDED",
(path),
)
if is_empty_graph(exp_threshold):
logging.warning(
"THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED",
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
(path),
)
return None
@ -654,13 +653,13 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item)
logging.debug(
"METRIC || Generated; Path %s; SUCCEEDED",
"METRIC::Generated; Path %s; SUCCEEDED",
(path),
)
else:
out_metric = read_json(path)
logging.debug(
"METRIC || Loaded; Path %s; SUCCEEDED",
"METRIC::Loaded; Path %s; SUCCEEDED",
(path),
)
else:
@ -669,12 +668,12 @@ class ExplainingOutline(object):
write_json(data, path)
if out_metric is None:
logging.debug(
"METRIC || Generated; Path %s; FAILED",
"METRIC::Generated; Path %s; FAILED",
(path),
)
else:
logging.debug(
"METRIC || Generated; Path %s; SUCCEEDED",
"METRIC::Generated; Path %s; SUCCEEDED",
(path),
)
return out_metric
@ -694,21 +693,21 @@ class ExplainingOutline(object):
if self.explaining_cfg.explainer.force:
data_attack = attack.get_attacked_prediction(item)
logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED",
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
)
else:
data_attack = _load_explanation(path)
logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED",
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
)
else:
data_attack = attack.get_attacked_prediction(item)
_save_explanation(data_attack, path)
logging.debug(
"ATTACK || Generated %s; Path %s; SUCCEEDED",
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
)
return data_attack

View File

@ -4,7 +4,6 @@ import os
import sys
import yaml
from explaining_framework.config.explaining_config import explaining_cfg
@ -78,11 +77,11 @@ def obj_config_to_log(obj) -> str:
if isinstance(obj, dict):
config = get_dict_config(obj)
for k, v in config.items():
logging.info(f"{k} : {v}")
logging.info(f"{k}={v}")
else:
config = get_dict_config(obj.__dict__)
for k, v in config.items():
logging.info(f"{k} : {v}")
logging.info(f"{k}={v}")
def set_printing(logger_path):
@ -93,7 +92,7 @@ def set_printing(logger_path):
logging.root.handlers = []
logging_cfg = {
"level": logging.INFO,
"format": "%(asctime)s:%(levelname)s:%(message)s",
"format": "%(asctime)s::%(levelname)s::%(message)s",
}
h_file = logging.FileHandler(logger_path)
h_stdout = logging.StreamHandler(sys.stdout)

View File

@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
outline = ExplainingOutline(args.explaining_cfg_file)
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))