Removing GPU selection features
This commit is contained in:
parent
7afa6c53f6
commit
c010e08e9b
|
@ -12,10 +12,4 @@ def parse_args() -> argparse.Namespace:
|
||||||
required=True,
|
required=True,
|
||||||
help="The explaining configuration file path.",
|
help="The explaining configuration file path.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--gpu_id",
|
|
||||||
type=int,
|
|
||||||
help="If GPUs available, on which GPU to run experiment",
|
|
||||||
default=0,
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
|
@ -122,8 +122,7 @@ all_threshold_type = ["topk_hard", "hard", "topk"]
|
||||||
|
|
||||||
|
|
||||||
class ExplainingOutline(object):
|
class ExplainingOutline(object):
|
||||||
def __init__(self, explaining_cfg_path: str, gpu_id: int):
|
def __init__(self, explaining_cfg_path: str):
|
||||||
self.gpu_id = gpu_id
|
|
||||||
self.explaining_cfg_path = explaining_cfg_path
|
self.explaining_cfg_path = explaining_cfg_path
|
||||||
self.explaining_cfg = None
|
self.explaining_cfg = None
|
||||||
self.explainer_cfg_path = None
|
self.explainer_cfg_path = None
|
||||||
|
@ -168,7 +167,7 @@ class ExplainingOutline(object):
|
||||||
seed_everything(self.explaining_cfg.seed)
|
seed_everything(self.explaining_cfg.seed)
|
||||||
|
|
||||||
def load_model_to_hardware(self):
|
def load_model_to_hardware(self):
|
||||||
auto_select_device(gpu_id=self.gpu_id)
|
auto_select_device()
|
||||||
device = self.cfg.accelerator
|
device = self.cfg.accelerator
|
||||||
self.model = self.model.to(device)
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
@ -567,18 +566,18 @@ class ExplainingOutline(object):
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
if explanation is None:
|
if explanation is None:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
" EXP || Generated; Path %s; FAILED",
|
" EXP::Generated; Path %s; FAILED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP || Generated; Path %s; SUCCEEDED",
|
"EXP::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
explanation = _load_explanation(path)
|
explanation = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP || Loaded; Path %s; SUCCEEDED",
|
"EXP::Loaded; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
explanation = explanation.to(self.cfg.accelerator)
|
explanation = explanation.to(self.cfg.accelerator)
|
||||||
|
@ -587,7 +586,7 @@ class ExplainingOutline(object):
|
||||||
get_pred(self.explainer, explanation)
|
get_pred(self.explainer, explanation)
|
||||||
_save_explanation(explanation, path)
|
_save_explanation(explanation, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP || Generated; Path %s; SUCCEEDED",
|
"EXP::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -598,14 +597,14 @@ class ExplainingOutline(object):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_adjust = adjust.forward(item)
|
exp_adjust = adjust.forward(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST || Generated; Path %s; SUCCEEDED",
|
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
exp_adjust = _load_explanation(path)
|
exp_adjust = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST || Loaded; Path %s; SUCCEEDED",
|
"ADJUST::Loaded; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -614,7 +613,7 @@ class ExplainingOutline(object):
|
||||||
get_pred(self.explainer, exp_adjust)
|
get_pred(self.explainer, exp_adjust)
|
||||||
_save_explanation(exp_adjust, path)
|
_save_explanation(exp_adjust, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST || Generated; Path %s; SUCCEEDED",
|
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
return exp_adjust
|
return exp_adjust
|
||||||
|
@ -624,13 +623,13 @@ class ExplainingOutline(object):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exp_threshold = _load_explanation(path)
|
exp_threshold = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD || Loaded; Path %s; SUCCEEDED",
|
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -638,12 +637,12 @@ class ExplainingOutline(object):
|
||||||
get_pred(self.explainer, exp_threshold)
|
get_pred(self.explainer, exp_threshold)
|
||||||
_save_explanation(exp_threshold, path)
|
_save_explanation(exp_threshold, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
if is_empty_graph(exp_threshold):
|
if is_empty_graph(exp_threshold):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED",
|
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
@ -654,13 +653,13 @@ class ExplainingOutline(object):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
out_metric = metric.forward(item)
|
out_metric = metric.forward(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC || Generated; Path %s; SUCCEEDED",
|
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_metric = read_json(path)
|
out_metric = read_json(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC || Loaded; Path %s; SUCCEEDED",
|
"METRIC::Loaded; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -669,12 +668,12 @@ class ExplainingOutline(object):
|
||||||
write_json(data, path)
|
write_json(data, path)
|
||||||
if out_metric is None:
|
if out_metric is None:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC || Generated; Path %s; FAILED",
|
"METRIC::Generated; Path %s; FAILED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC || Generated; Path %s; SUCCEEDED",
|
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
return out_metric
|
return out_metric
|
||||||
|
@ -694,21 +693,21 @@ class ExplainingOutline(object):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
data_attack = _load_explanation(path)
|
data_attack = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
_save_explanation(data_attack, path)
|
_save_explanation(data_attack, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
return data_attack
|
return data_attack
|
||||||
|
|
|
@ -4,7 +4,6 @@ import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from explaining_framework.config.explaining_config import explaining_cfg
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,11 +77,11 @@ def obj_config_to_log(obj) -> str:
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
config = get_dict_config(obj)
|
config = get_dict_config(obj)
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
logging.info(f"{k} : {v}")
|
logging.info(f"{k}={v}")
|
||||||
else:
|
else:
|
||||||
config = get_dict_config(obj.__dict__)
|
config = get_dict_config(obj.__dict__)
|
||||||
for k, v in config.items():
|
for k, v in config.items():
|
||||||
logging.info(f"{k} : {v}")
|
logging.info(f"{k}={v}")
|
||||||
|
|
||||||
|
|
||||||
def set_printing(logger_path):
|
def set_printing(logger_path):
|
||||||
|
@ -93,7 +92,7 @@ def set_printing(logger_path):
|
||||||
logging.root.handlers = []
|
logging.root.handlers = []
|
||||||
logging_cfg = {
|
logging_cfg = {
|
||||||
"level": logging.INFO,
|
"level": logging.INFO,
|
||||||
"format": "%(asctime)s:%(levelname)s:%(message)s",
|
"format": "%(asctime)s::%(levelname)s::%(message)s",
|
||||||
}
|
}
|
||||||
h_file = logging.FileHandler(logger_path)
|
h_file = logging.FileHandler(logger_path)
|
||||||
h_stdout = logging.StreamHandler(sys.stdout)
|
h_stdout = logging.StreamHandler(sys.stdout)
|
||||||
|
|
2
main.py
2
main.py
|
@ -25,7 +25,7 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
|
|
||||||
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue