2022-12-29 22:29:32 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
#
|
|
|
|
|
|
|
|
import os
|
2022-12-30 18:34:41 +00:00
|
|
|
import time
|
|
|
|
|
|
|
|
from torch_geometric import seed_everything
|
|
|
|
from torch_geometric.data.makedirs import makedirs
|
|
|
|
from torch_geometric.explain import Explainer
|
|
|
|
from torch_geometric.explain.config import ThresholdConfig
|
|
|
|
from torch_geometric.graphgym.config import cfg
|
|
|
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
|
|
|
2022-12-29 22:29:32 +00:00
|
|
|
from explaining_framework.config.explaining_config import explaining_cfg
|
|
|
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
2022-12-30 18:34:41 +00:00
|
|
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
2022-12-30 18:41:56 +00:00
|
|
|
from explaining_framework.utils.explanation.adjust import Adjust
|
2022-12-30 18:34:41 +00:00
|
|
|
from explaining_framework.utils.io import (obj_config_to_str, read_json,
|
|
|
|
write_json, write_yaml)
|
|
|
|
|
|
|
|
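# TODO(review): presumably meant to record inference time and honour the `force` flag when recomputing cached results; confirm.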
|
|
|
|
|
2022-12-30 18:41:56 +00:00
|
|
|
|
|
|
|
def get_pred(explanation, force=False):
    """Attach predictions on the full and the masked input to ``explanation``.

    Both predictions are (re)computed when either ``pred`` or ``pred_masked``
    is missing from ``explanation.to_dict()``, or when ``force`` is true.
    Otherwise the explanation is returned untouched.

    NOTE: this reads a module-level ``explainer`` (the ``Explainer`` instance
    created in the ``__main__`` section of this script); it must exist before
    this function is called.

    Args:
        explanation: Explanation object exposing ``to_dict()``, ``x``,
            ``edge_index``, ``node_mask`` and ``edge_mask``.
        force: Recompute both predictions even if they are already present.

    Returns:
        The same ``explanation`` object, with ``pred`` and ``pred_masked`` set.
    """
    dict_ = explanation.to_dict()
    # Compare against None explicitly: the stored predictions are presumably
    # tensors, whose truth value is ambiguous when they hold several elements.
    missing = dict_.get("pred") is None or dict_.get("pred_masked") is None
    if missing or force:
        pred = explainer.get_prediction(explanation)
        pred_masked = explainer.get_masked_prediction(
            x=explanation.x,
            edge_index=explanation.edge_index,
            node_mask=explanation.node_mask,
            edge_mask=explanation.edge_mask,
        )
        explanation.pred = pred
        explanation.pred_masked = pred_masked
    return explanation
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # `copy` is not imported at the top of this file, so import it here.
    import copy

    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)
    auto_select_device()

    # Load components
    dataset = outline.dataset.to(cfg.accelerator)
    model = outline.model.to(cfg.accelerator)
    model_info = outline.model_info
    metrics = outline.metrics
    explaining_algorithm = outline.explaining_algorithm
    attacks = outline.attacks
    explainer_cfg = outline.explainer_cfg
    model_signature = outline.model_signature

    # Set seed
    seed_everything(explaining_cfg.seed)

    # Model-level directory: graphgym config and model info.
    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
    makedirs(global_path)
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    write_json(model_info, os.path.join(global_path, "info.json"))

    # Explainer-level directory: explaining and explainer configs.
    global_path = os.path.join(
        global_path,
        explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
    )
    makedirs(global_path)
    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))

    global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
    makedirs(global_path)

    # Module-level name on purpose: `get_pred` looks up the global `explainer`.
    explainer = Explainer(
        model=model,
        algorithm=explaining_algorithm,
        explainer_config=dict(
            explanation_type=explaining_cfg.explanation_type,
            node_mask_type="object",
            edge_mask_type="object",
        ),
        model_config=dict(
            mode="regression",
            task_level=cfg.dataset.task,
            return_type=explaining_cfg.model_config.return_type,
        ),
    )

    # NOTE(review): `load_explanation` and `save_explanation` are neither
    # defined nor imported in the visible source; they must be provided for
    # this script to run.
    raw_path = os.path.join(global_path, "raw")
    makedirs(raw_path)

    for index, item in enumerate(dataset):
        explanation_path = os.path.join(raw_path, f"{index}.json")

        # Reuse a cached raw explanation unless the config forces recomputation.
        if os.path.exists(explanation_path) and not explaining_cfg.explainer.force:
            explanation = load_explanation(explanation_path)
        else:
            # NOTE(review): `index=item.y` passes the label as the explained
            # index; confirm this is intended for the configured task level.
            explanation = explainer(
                x=item.x,
                edge_index=item.edge_index,
                index=item.y,
                target=item.y,
            )
        explanation = get_pred(explanation, force=False)
        save_explanation(explanation, explanation_path)

        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                adjust_path = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(adjust_path)

                # Adjust a deep copy so each (relu, absolute) combination starts
                # from the raw explanation rather than from the previous
                # combination's output (deep, in case Adjust edits masks in place).
                exp_adjust = adjust.forward(copy.deepcopy(explanation))
                exp_adjust = get_pred(exp_adjust, force=True)
                save_explanation(exp_adjust, os.path.join(adjust_path, f"{index}.json"))

                for threshold_approach in ["hard", "topk", "topk_hard"]:
                    for threshold_value in explaining_cfg.threshold_config.value:
                        masking_path = os.path.join(
                            adjust_path,
                            f"threshold={threshold_approach}-value={threshold_value}",
                        )
                        makedirs(masking_path)
                        exp_threshold_path = os.path.join(masking_path, f"{index}.json")

                        if os.path.exists(exp_threshold_path):
                            exp_threshold = load_explanation(exp_threshold_path)
                        else:
                            threshold_conf = {
                                "threshold_type": threshold_approach,
                                "value": threshold_value,
                            }
                            explainer.threshold_config = ThresholdConfig.cast(
                                threshold_conf
                            )
                            exp_threshold = explainer._post_process(
                                copy.copy(exp_adjust)
                            )
                            exp_threshold = get_pred(exp_threshold, force=True)
                            save_explanation(exp_threshold, exp_threshold_path)

                        for metric in metrics:
                            metric_path = os.path.join(
                                masking_path, obj_config_to_str(metric)
                            )
                            metric_file = os.path.join(metric_path, f"{index}.json")
                            if os.path.exists(metric_file):
                                continue
                            makedirs(metric_path)
                            out = metric.forward(exp_threshold)
                            write_json({f"{metric.name}": out}, metric_file)
|