Removing GPU selection features
This commit is contained in:
parent
7afa6c53f6
commit
c010e08e9b
|
@ -12,10 +12,4 @@ def parse_args() -> argparse.Namespace:
|
|||
required=True,
|
||||
help="The explaining configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu_id",
|
||||
type=int,
|
||||
help="If GPUs available, on which GPU to run experiment",
|
||||
default=0,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
|
|
@ -122,8 +122,7 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
|||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg_path: str, gpu_id: int):
|
||||
self.gpu_id = gpu_id
|
||||
def __init__(self, explaining_cfg_path: str):
|
||||
self.explaining_cfg_path = explaining_cfg_path
|
||||
self.explaining_cfg = None
|
||||
self.explainer_cfg_path = None
|
||||
|
@ -168,7 +167,7 @@ class ExplainingOutline(object):
|
|||
seed_everything(self.explaining_cfg.seed)
|
||||
|
||||
def load_model_to_hardware(self):
|
||||
auto_select_device(gpu_id=self.gpu_id)
|
||||
auto_select_device()
|
||||
device = self.cfg.accelerator
|
||||
self.model = self.model.to(device)
|
||||
|
||||
|
@ -567,18 +566,18 @@ class ExplainingOutline(object):
|
|||
explanation = _get_explanation(self.explainer, item)
|
||||
if explanation is None:
|
||||
logging.warning(
|
||||
" EXP || Generated; Path %s; FAILED",
|
||||
" EXP::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
"EXP || Generated; Path %s; SUCCEEDED",
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
explanation = _load_explanation(path)
|
||||
logging.debug(
|
||||
"EXP || Loaded; Path %s; SUCCEEDED",
|
||||
"EXP::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
explanation = explanation.to(self.cfg.accelerator)
|
||||
|
@ -587,7 +586,7 @@ class ExplainingOutline(object):
|
|||
get_pred(self.explainer, explanation)
|
||||
_save_explanation(explanation, path)
|
||||
logging.debug(
|
||||
"EXP || Generated; Path %s; SUCCEEDED",
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
|
||||
|
@ -598,14 +597,14 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg.explainer.force:
|
||||
exp_adjust = adjust.forward(item)
|
||||
logging.debug(
|
||||
"ADJUST || Generated; Path %s; SUCCEEDED",
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
|
||||
else:
|
||||
exp_adjust = _load_explanation(path)
|
||||
logging.debug(
|
||||
"ADJUST || Loaded; Path %s; SUCCEEDED",
|
||||
"ADJUST::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
|
||||
|
@ -614,7 +613,7 @@ class ExplainingOutline(object):
|
|||
get_pred(self.explainer, exp_adjust)
|
||||
_save_explanation(exp_adjust, path)
|
||||
logging.debug(
|
||||
"ADJUST || Generated; Path %s; SUCCEEDED",
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
return exp_adjust
|
||||
|
@ -624,13 +623,13 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg.explainer.force:
|
||||
exp_threshold = self.explainer._post_process(item)
|
||||
logging.debug(
|
||||
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
exp_threshold = _load_explanation(path)
|
||||
logging.debug(
|
||||
"THRESHOLD || Loaded; Path %s; SUCCEEDED",
|
||||
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
|
@ -638,12 +637,12 @@ class ExplainingOutline(object):
|
|||
get_pred(self.explainer, exp_threshold)
|
||||
_save_explanation(exp_threshold, path)
|
||||
logging.debug(
|
||||
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
if is_empty_graph(exp_threshold):
|
||||
logging.warning(
|
||||
"THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED",
|
||||
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
|
||||
(path),
|
||||
)
|
||||
return None
|
||||
|
@ -654,13 +653,13 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg.explainer.force:
|
||||
out_metric = metric.forward(item)
|
||||
logging.debug(
|
||||
"METRIC || Generated; Path %s; SUCCEEDED",
|
||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
out_metric = read_json(path)
|
||||
logging.debug(
|
||||
"METRIC || Loaded; Path %s; SUCCEEDED",
|
||||
"METRIC::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
|
@ -669,12 +668,12 @@ class ExplainingOutline(object):
|
|||
write_json(data, path)
|
||||
if out_metric is None:
|
||||
logging.debug(
|
||||
"METRIC || Generated; Path %s; FAILED",
|
||||
"METRIC::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
"METRIC || Generated; Path %s; SUCCEEDED",
|
||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
return out_metric
|
||||
|
@ -694,21 +693,21 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg.explainer.force:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
logging.debug(
|
||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
|
||||
else:
|
||||
data_attack = _load_explanation(path)
|
||||
logging.debug(
|
||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
_save_explanation(data_attack, path)
|
||||
logging.debug(
|
||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
return data_attack
|
||||
|
|
|
@ -4,7 +4,6 @@ import os
|
|||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
|
||||
|
||||
|
@ -78,11 +77,11 @@ def obj_config_to_log(obj) -> str:
|
|||
if isinstance(obj, dict):
|
||||
config = get_dict_config(obj)
|
||||
for k, v in config.items():
|
||||
logging.info(f"{k} : {v}")
|
||||
logging.info(f"{k}={v}")
|
||||
else:
|
||||
config = get_dict_config(obj.__dict__)
|
||||
for k, v in config.items():
|
||||
logging.info(f"{k} : {v}")
|
||||
logging.info(f"{k}={v}")
|
||||
|
||||
|
||||
def set_printing(logger_path):
|
||||
|
@ -93,7 +92,7 @@ def set_printing(logger_path):
|
|||
logging.root.handlers = []
|
||||
logging_cfg = {
|
||||
"level": logging.INFO,
|
||||
"format": "%(asctime)s:%(levelname)s:%(message)s",
|
||||
"format": "%(asctime)s::%(levelname)s::%(message)s",
|
||||
}
|
||||
h_file = logging.FileHandler(logger_path)
|
||||
h_stdout = logging.StreamHandler(sys.stdout)
|
||||
|
|
2
main.py
2
main.py
|
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||
|
||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||
|
||||
|
|
Loading…
Reference in New Issue