Reformating, adding logging and progress bar
This commit is contained in:
parent
fbc685503c
commit
02bdbdc6ca
|
@ -1,11 +1,15 @@
|
||||||
import copy
|
import copy
|
||||||
|
import datetime
|
||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
from scgnn.scgnn import SCGNN
|
||||||
from torch_geometric import seed_everything
|
from torch_geometric import seed_everything
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
from torch_geometric.explain import Explainer
|
from torch_geometric.explain import Explainer
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
@ -35,9 +39,12 @@ from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||||
from explaining_framework.utils.explanation.adjust import Adjust
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
from explaining_framework.utils.explanation.io import (
|
from explaining_framework.utils.explanation.io import (
|
||||||
_get_explanation, _load_explanation, _save_explanation,
|
_get_explanation, _load_explanation, _save_explanation,
|
||||||
explanation_verification, get_pred)
|
explanation_verification, get_pred, is_empty_graph)
|
||||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
read_json, write_json, write_yaml)
|
obj_config_to_log,
|
||||||
|
obj_config_to_str, read_json,
|
||||||
|
set_printing, write_json,
|
||||||
|
write_yaml)
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -155,6 +162,7 @@ class ExplainingOutline(object):
|
||||||
self.load_adjust()
|
self.load_adjust()
|
||||||
self.load_threshold()
|
self.load_threshold()
|
||||||
self.load_graphstat()
|
self.load_graphstat()
|
||||||
|
self.setup_experiment()
|
||||||
|
|
||||||
seed_everything(self.explaining_cfg.seed)
|
seed_everything(self.explaining_cfg.seed)
|
||||||
|
|
||||||
|
@ -168,7 +176,8 @@ class ExplainingOutline(object):
|
||||||
self.load_dataset()
|
self.load_dataset()
|
||||||
try:
|
try:
|
||||||
item = next(self.dataset)
|
item = next(self.dataset)
|
||||||
item = item.to(cfg.accelerator)
|
device = self.cfg.accelerator
|
||||||
|
item = item.to(device)
|
||||||
return item
|
return item
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
@ -555,49 +564,118 @@ class ExplainingOutline(object):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
|
if explanation is None:
|
||||||
|
logging.warning(
|
||||||
|
" EXP || Generated; Path %s; FAILED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.debug(
|
||||||
|
"EXP || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
explanation = _load_explanation(path)
|
explanation = _load_explanation(path)
|
||||||
|
logging.debug(
|
||||||
|
"EXP || Loaded; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
explanation = explanation.to(self.cfg.accelerator)
|
explanation = explanation.to(self.cfg.accelerator)
|
||||||
else:
|
else:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
get_pred(self.explainer, explanation)
|
get_pred(self.explainer, explanation)
|
||||||
_save_explanation(explanation, path)
|
_save_explanation(explanation, path)
|
||||||
|
logging.debug(
|
||||||
|
"EXP || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_adjust = adjust.forward(item)
|
exp_adjust = adjust.forward(item)
|
||||||
|
logging.debug(
|
||||||
|
"ADJUST || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
exp_adjust = _load_explanation(path)
|
exp_adjust = _load_explanation(path)
|
||||||
|
logging.debug(
|
||||||
|
"ADJUST || Loaded; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
exp_adjust = adjust.forward(item)
|
exp_adjust = adjust.forward(item)
|
||||||
get_pred(self.explainer, exp_adjust)
|
get_pred(self.explainer, exp_adjust)
|
||||||
_save_explanation(exp_adjust, path)
|
_save_explanation(exp_adjust, path)
|
||||||
|
logging.debug(
|
||||||
|
"ADJUST || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
return exp_adjust
|
return exp_adjust
|
||||||
|
|
||||||
def get_threshold(self, item: Explanation, path: str):
|
def get_threshold(self, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
|
logging.debug(
|
||||||
|
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
exp_threshold = _load_explanation(path)
|
exp_threshold = _load_explanation(path)
|
||||||
|
logging.debug(
|
||||||
|
"THRESHOLD || Loaded; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
get_pred(self.explainer, exp_threshold)
|
get_pred(self.explainer, exp_threshold)
|
||||||
_save_explanation(exp_threshold, path)
|
_save_explanation(exp_threshold, path)
|
||||||
|
logging.debug(
|
||||||
|
"THRESHOLD || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
if is_empty_graph(exp_threshold):
|
||||||
|
logging.warning(
|
||||||
|
"THRESHOLD || Generated; Path %s; EMPTY GRAPH; FAILED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
return None
|
||||||
return exp_threshold
|
return exp_threshold
|
||||||
|
|
||||||
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
out_metric = metric.forward(item)
|
out_metric = metric.forward(item)
|
||||||
|
logging.debug(
|
||||||
|
"METRIC || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out_metric = read_json(path)
|
out_metric = read_json(path)
|
||||||
|
logging.debug(
|
||||||
|
"METRIC || Loaded; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
out_metric = metric.forward(item)
|
out_metric = metric.forward(item)
|
||||||
data = {f"{metric.name}": out_metric}
|
data = {f"{metric.name}": out_metric}
|
||||||
write_json(data, path)
|
write_json(data, path)
|
||||||
|
if out_metric is None:
|
||||||
|
logging.debug(
|
||||||
|
"METRIC || Generated; Path %s; FAILED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.debug(
|
||||||
|
"METRIC || Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
return out_metric
|
return out_metric
|
||||||
|
|
||||||
def get_stat(self, item: Data, path: str):
|
def get_stat(self, item: Data, path: str):
|
||||||
|
@ -614,9 +692,66 @@ class ExplainingOutline(object):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
|
logging.debug(
|
||||||
|
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
data_attack = _load_explanation(path)
|
data_attack = _load_explanation(path)
|
||||||
|
logging.debug(
|
||||||
|
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
_save_explanation(data_attack, path)
|
_save_explanation(data_attack, path)
|
||||||
|
logging.debug(
|
||||||
|
"ATTACK || Generated %s; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
return data_attack
|
return data_attack
|
||||||
|
|
||||||
|
def setup_experiment(self):
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
self.out_dir = os.path.join(
|
||||||
|
self.explaining_cfg.out_dir,
|
||||||
|
self.cfg.dataset.name,
|
||||||
|
self.model_signature,
|
||||||
|
)
|
||||||
|
makedirs(self.out_dir)
|
||||||
|
|
||||||
|
now_str = now.strftime("month=%m-day=%d-year=%Y-hour=%H-minute=%M-second=%S")
|
||||||
|
set_printing(f"{self.out_dir}/logging-{now_str}.log")
|
||||||
|
|
||||||
|
dump_cfg(self.cfg, os.path.join(self.out_dir, "config.yaml"))
|
||||||
|
write_json(self.model_info, os.path.join(self.out_dir, "info.json"))
|
||||||
|
|
||||||
|
self.explainer_path = os.path.join(
|
||||||
|
self.out_dir,
|
||||||
|
self.explaining_cfg.explainer.name,
|
||||||
|
obj_config_to_str(self.explaining_algorithm),
|
||||||
|
)
|
||||||
|
makedirs(self.explainer_path)
|
||||||
|
dump_cfg(
|
||||||
|
self.explainer_cfg,
|
||||||
|
os.path.join(self.explainer_path, "explainer_cfg.yaml"),
|
||||||
|
)
|
||||||
|
dump_cfg(
|
||||||
|
self.explaining_cfg,
|
||||||
|
os.path.join(self.explainer_path, self.explaining_cfg.cfg_dest),
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Setting up experiment")
|
||||||
|
logging.info("Date and Time: %s", now)
|
||||||
|
logging.info("Save experiment to %s", self.out_dir)
|
||||||
|
logging.info(self.cfg)
|
||||||
|
logging.info(self.explaining_cfg)
|
||||||
|
logging.info(self.explainer_cfg)
|
||||||
|
logging.info(self.model)
|
||||||
|
logging.info(obj_config_to_log(self.model_info))
|
||||||
|
for metric in self.metrics + self.attacks:
|
||||||
|
logging.info(obj_config_to_str(metric))
|
||||||
|
for threshold_conf in self.thresholds_configs:
|
||||||
|
logging.info(obj_config_to_str(threshold_conf))
|
||||||
|
logging.info("Proceeding to explanations..")
|
||||||
|
|
|
@ -16,7 +16,6 @@ def _get_explanation(explainer, item):
|
||||||
target=item.y,
|
target=item.y,
|
||||||
)
|
)
|
||||||
if not explanation_verification(explanation):
|
if not explanation_verification(explanation):
|
||||||
# WARNING + LOG
|
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
explanation = explanation.to(cfg.accelerator)
|
explanation = explanation.to(cfg.accelerator)
|
||||||
|
@ -50,6 +49,8 @@ def explanation_verification(exp: Explanation) -> bool:
|
||||||
is_nan = mask.isnan().any().item()
|
is_nan = mask.isnan().any().item()
|
||||||
is_inf = mask.isinf().any().item()
|
is_inf = mask.isinf().any().item()
|
||||||
is_ok = exp.validate()
|
is_ok = exp.validate()
|
||||||
|
is_const = mask.max() == mask.min()
|
||||||
|
|
||||||
if is_nan or is_inf or not is_ok:
|
if is_nan or is_inf or not is_ok:
|
||||||
is_good = False
|
is_good = False
|
||||||
return is_good
|
return is_good
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
|
|
||||||
|
|
||||||
def read_json(path: str) -> dict:
|
def read_json(path: str) -> dict:
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
|
@ -68,3 +72,37 @@ def obj_config_to_str(obj) -> str:
|
||||||
else:
|
else:
|
||||||
config = get_dict_config(obj.__dict__)
|
config = get_dict_config(obj.__dict__)
|
||||||
return "-".join([f"{k}={v}" for k, v in config.items()])
|
return "-".join([f"{k}={v}" for k, v in config.items()])
|
||||||
|
|
||||||
|
|
||||||
|
def obj_config_to_log(obj) -> str:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
config = get_dict_config(obj)
|
||||||
|
for k, v in config.items():
|
||||||
|
logging.info(f"{k} : {v}")
|
||||||
|
else:
|
||||||
|
config = get_dict_config(obj.__dict__)
|
||||||
|
for k, v in config.items():
|
||||||
|
logging.info(f"{k} : {v}")
|
||||||
|
|
||||||
|
|
||||||
|
def set_printing(logger_path):
|
||||||
|
"""
|
||||||
|
Set up printing options
|
||||||
|
|
||||||
|
"""
|
||||||
|
logging.root.handlers = []
|
||||||
|
logging_cfg = {
|
||||||
|
"level": logging.INFO,
|
||||||
|
"format": "%(asctime)s:%(levelname)s:%(message)s",
|
||||||
|
}
|
||||||
|
h_file = logging.FileHandler(logger_path)
|
||||||
|
h_stdout = logging.StreamHandler(sys.stdout)
|
||||||
|
if explaining_cfg.print == "file":
|
||||||
|
logging_cfg["handlers"] = [h_file]
|
||||||
|
elif explaining_cfg.print == "stdout":
|
||||||
|
logging_cfg["handlers"] = [h_stdout]
|
||||||
|
elif explaining_cfg.print == "both":
|
||||||
|
logging_cfg["handlers"] = [h_file, h_stdout]
|
||||||
|
else:
|
||||||
|
raise ValueError("Print option not supported")
|
||||||
|
logging.basicConfig(**logging_cfg)
|
||||||
|
|
118
main.py
118
main.py
|
@ -3,8 +3,8 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric import seed_everything
|
from torch_geometric import seed_everything
|
||||||
|
@ -13,6 +13,7 @@ from torch_geometric.explain import Explainer
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
from torch_geometric.graphgym.config import cfg
|
from torch_geometric.graphgym.config import cfg
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from explaining_framework.config.explaining_config import explaining_cfg
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||||
|
@ -20,60 +21,36 @@ from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||||
from explaining_framework.utils.explanation.adjust import Adjust
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
from explaining_framework.utils.io import (dump_cfg, is_exists,
|
from explaining_framework.utils.io import (dump_cfg, is_exists,
|
||||||
obj_config_to_str, read_json,
|
obj_config_to_str, read_json,
|
||||||
write_json)
|
set_printing, write_json)
|
||||||
|
|
||||||
# inference, time, force,
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
out_dir = os.path.join(
|
|
||||||
outline.explaining_cfg.out_dir,
|
|
||||||
outline.cfg.dataset.name,
|
|
||||||
outline.model_signature,
|
|
||||||
)
|
|
||||||
makedirs(out_dir)
|
|
||||||
|
|
||||||
dump_cfg(outline.cfg, os.path.join(out_dir, "config.yaml"))
|
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
|
||||||
write_json(outline.model_info, os.path.join(out_dir, "info.json"))
|
|
||||||
|
|
||||||
explainer_path = os.path.join(
|
|
||||||
out_dir,
|
|
||||||
outline.explaining_cfg.explainer.name,
|
|
||||||
obj_config_to_str(outline.explaining_algorithm),
|
|
||||||
)
|
|
||||||
makedirs(explainer_path)
|
|
||||||
dump_cfg(
|
|
||||||
outline.explainer_cfg,
|
|
||||||
os.path.join(explainer_path, "explainer_cfg.yaml"),
|
|
||||||
)
|
|
||||||
dump_cfg(
|
|
||||||
outline.explaining_cfg,
|
|
||||||
os.path.join(explainer_path, explaining_cfg.cfg_dest),
|
|
||||||
)
|
|
||||||
|
|
||||||
item, index = outline.get_item()
|
item, index = outline.get_item()
|
||||||
while not (item is None or index is None):
|
while not (item is None or index is None):
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
attack_path = os.path.join(
|
attack_path = os.path.join(
|
||||||
out_dir, attack.__class__.__name__, obj_config_to_str(attack)
|
outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack)
|
||||||
)
|
)
|
||||||
makedirs(attack_path)
|
makedirs(attack_path)
|
||||||
data_attack_path = os.path.join(attack_path, f"{index}.json")
|
data_attack_path = os.path.join(attack_path, f"{index}.json")
|
||||||
data_attack = outline.get_attack(
|
data_attack = outline.get_attack(
|
||||||
attack=attack, item=item, path=data_attack_path
|
attack=attack, item=item, path=data_attack_path
|
||||||
)
|
)
|
||||||
|
|
||||||
item, index = outline.get_item()
|
item, index = outline.get_item()
|
||||||
|
|
||||||
outline.reload_dataloader()
|
outline.reload_dataloader()
|
||||||
makedirs(explainer_path)
|
|
||||||
|
|
||||||
item, index = outline.get_item()
|
item, index = outline.get_item()
|
||||||
while not (item is None or index is None):
|
while not (item is None or index is None):
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
attack_path_ = os.path.join(
|
attack_path_ = os.path.join(
|
||||||
explainer_path, attack.__class__.__name__, obj_config_to_str(attack)
|
outline.explainer_path,
|
||||||
|
attack.__class__.__name__,
|
||||||
|
obj_config_to_str(attack),
|
||||||
)
|
)
|
||||||
makedirs(attack_path_)
|
makedirs(attack_path_)
|
||||||
data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
|
data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
|
||||||
|
@ -81,47 +58,48 @@ if __name__ == "__main__":
|
||||||
attack=attack, item=item, path=data_attack_path_
|
attack=attack, item=item, path=data_attack_path_
|
||||||
)
|
)
|
||||||
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
||||||
for adjust in outline.adjusts:
|
pbar.update(1)
|
||||||
adjust_path = os.path.join(
|
if exp is None:
|
||||||
attack_path_, adjust.__class__.__name__, obj_config_to_str(adjust)
|
continue
|
||||||
)
|
else:
|
||||||
makedirs(adjust_path)
|
for adjust in outline.adjusts:
|
||||||
exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
|
adjust_path = os.path.join(
|
||||||
exp_adjust = outline.get_adjust(
|
attack_path_,
|
||||||
adjust=adjust, item=exp, path=exp_adjust_path
|
adjust.__class__.__name__,
|
||||||
)
|
obj_config_to_str(adjust),
|
||||||
for threshold_conf in outline.thresholds_configs:
|
|
||||||
outline.set_explainer_threshold_config(threshold_conf)
|
|
||||||
masking_path = os.path.join(
|
|
||||||
adjust_path,
|
|
||||||
"ThresholdConfig",
|
|
||||||
obj_config_to_str(threshold_conf),
|
|
||||||
)
|
)
|
||||||
makedirs(masking_path)
|
makedirs(adjust_path)
|
||||||
exp_masked_path = os.path.join(masking_path, f"{index}.json")
|
exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
|
||||||
exp_masked = outline.get_threshold(
|
exp_adjust = outline.get_adjust(
|
||||||
item=exp_adjust, path=exp_masked_path
|
adjust=adjust, item=exp, path=exp_adjust_path
|
||||||
)
|
)
|
||||||
for metric in outline.metrics:
|
for threshold_conf in outline.thresholds_configs:
|
||||||
metric_path = os.path.join(
|
outline.set_explainer_threshold_config(threshold_conf)
|
||||||
masking_path,
|
masking_path = os.path.join(
|
||||||
metric.__class__.__name__,
|
adjust_path,
|
||||||
obj_config_to_str(metric),
|
|
||||||
)
|
|
||||||
makedirs(metric_path)
|
|
||||||
metric_path = os.path.join(metric_path, f"{index}.json")
|
|
||||||
out_metric = outline.get_metric(
|
|
||||||
metric=metric, item=exp_masked, path=metric_path
|
|
||||||
)
|
|
||||||
print("#################################")
|
|
||||||
print("Attack", attack.name)
|
|
||||||
print(
|
|
||||||
"ThresholdConfig",
|
"ThresholdConfig",
|
||||||
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
obj_config_to_str(threshold_conf),
|
||||||
)
|
)
|
||||||
print("Metric", metric.name)
|
makedirs(masking_path)
|
||||||
print("Val", out_metric)
|
exp_masked_path = os.path.join(masking_path, f"{index}.json")
|
||||||
print("Index", index)
|
exp_masked = outline.get_threshold(
|
||||||
print("#################################")
|
item=exp_adjust, path=exp_masked_path
|
||||||
|
)
|
||||||
|
if exp_masked is None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for metric in outline.metrics:
|
||||||
|
metric_path = os.path.join(
|
||||||
|
masking_path,
|
||||||
|
metric.__class__.__name__,
|
||||||
|
obj_config_to_str(metric),
|
||||||
|
)
|
||||||
|
makedirs(metric_path)
|
||||||
|
metric_path = os.path.join(metric_path, f"{index}.json")
|
||||||
|
out_metric = outline.get_metric(
|
||||||
|
metric=metric, item=exp_masked, path=metric_path
|
||||||
|
)
|
||||||
|
|
||||||
item, index = outline.get_item()
|
item, index = outline.get_item()
|
||||||
|
with open(os.path.join(outline.out_dir, "done"), "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
Loading…
Reference in New Issue