explaining_framework/main.py

112 lines
5.1 KiB
Python
Raw Normal View History

2022-12-29 22:29:32 +00:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
2023-01-02 22:37:40 +00:00
import copy
2022-12-29 22:29:32 +00:00
import os
import time
2023-01-02 22:37:40 +00:00
import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
2022-12-29 22:29:32 +00:00
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
2022-12-30 18:41:56 +00:00
from explaining_framework.utils.explanation.adjust import Adjust
2023-01-02 22:37:40 +00:00
from explaining_framework.utils.explanation.io import (
2023-01-03 16:12:54 +00:00
explanation_verification, get_explanation, get_pred, load_explanation,
save_explanation)
2023-01-02 22:37:40 +00:00
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml)
# inference, time, force,
2022-12-30 18:41:56 +00:00
if __name__ == "__main__":
    # Script entry point. Builds the output directory tree for one explaining
    # run, dumps the configurations into it, then for every (item, index)
    # handed out by the outline writes:
    #   * one adjusted explanation per (apply_relu, apply_absolute) combination,
    #   * one thresholded explanation per (threshold approach, value),
    #   * one JSON file per metric evaluated on each thresholded explanation.
    # Results that already exist on disk are reused instead of recomputed.
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)

    # TODO: add indexes.
    # Global path: <out_dir>/<model signature>/<explainer name>_<algorithm cfg>
    global_path = os.path.join(
        outline.explaining_cfg.out_dir,
        outline.model_signature,
        outline.explaining_cfg.explainer.name
        + "_"
        + obj_config_to_str(outline.explaining_algorithm),
    )
    makedirs(global_path)

    # Save the configurations next to the results they produced.
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    # NOTE(review): `model_info` is not defined anywhere in the visible part of
    # this file -- presumably it should come from `outline`; confirm.
    write_json(model_info, os.path.join(global_path, "info.json"))
    write_yaml(
        outline.explaining_cfg,
        os.path.join(global_path, explaining_cfg.cfg_dest),
    )
    write_yaml(
        outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")
    )

    # Everything below is written under a sub-directory named after the
    # explaining algorithm configuration.
    global_path = os.path.join(
        global_path, obj_config_to_str(outline.explaining_algorithm)
    )
    makedirs(global_path)

    # The raw-explanation directory does not depend on the item, so it is
    # created once, before the loop.
    raw_path = os.path.join(global_path, "raw")
    makedirs(raw_path)

    # NOTE(review): `explainer` and `metrics` are used below but are not
    # defined in the visible part of this file -- presumably they are provided
    # by `outline`; confirm.
    item, index = outline.get_item()
    while item is not None and index is not None:
        # Location of the raw (un-adjusted) explanation of this item.
        explanation_path = os.path.join(raw_path, f"{index}.json")
        # NOTE(review): `explanation` (the raw explanation of `item`) is never
        # assigned in the visible code; it presumably has to be loaded from
        # `explanation_path` or computed from `item` here -- confirm.

        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                adjust = Adjust(
                    apply_relu=apply_relu, apply_absolute=apply_absolute
                )
                save_raw_path_ = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(save_raw_path_)

                # Adjust a copy so the raw explanation stays untouched for the
                # next (apply_relu, apply_absolute) combination, and keep the
                # *adjusted* result as the working explanation.
                explanation__ = adjust.forward(
                    copy.copy(explanation).to(cfg.accelerator)
                )
                explanation_path = os.path.join(save_raw_path_, f"{index}.json")
                get_pred(explainer, explanation__)
                save_explanation(explanation__, explanation_path)

                for threshold_approach in ["hard", "topk", "topk_hard"]:
                    if threshold_approach == "hard":
                        threshold_values = explaining_cfg.threshold_config.value
                    elif "topk" in threshold_approach:
                        threshold_values = [3, 5, 10, 20]

                    for threshold_value in threshold_values:
                        masking_path = os.path.join(
                            save_raw_path_,
                            f"threshold={threshold_approach}-value={threshold_value}",
                        )
                        makedirs(masking_path)
                        exp_threshold_path = os.path.join(
                            masking_path, f"{index}.json"
                        )

                        if is_exists(exp_threshold_path):
                            # Reuse the thresholded explanation from disk.
                            exp_threshold = load_explanation(exp_threshold_path)
                        else:
                            threshold_conf = {
                                "threshold_type": threshold_approach,
                                "value": threshold_value,
                            }
                            explainer.threshold_config = ThresholdConfig.cast(
                                threshold_conf
                            )
                            # Post-process a copy so the adjusted explanation
                            # can be reused for the next threshold setting.
                            expl = copy.copy(explanation__).to(cfg.accelerator)
                            exp_threshold = explainer._post_process(expl)
                            exp_threshold = exp_threshold.to(cfg.accelerator)
                            get_pred(explainer, exp_threshold)
                            save_explanation(exp_threshold, exp_threshold_path)

                        for metric in metrics:
                            metric_path = os.path.join(
                                masking_path, f"{obj_config_to_str(metric)}"
                            )
                            makedirs(metric_path)
                            metric_file = os.path.join(
                                metric_path, f"{index}.json"
                            )
                            # Skip metrics that were already computed.
                            if is_exists(metric_file):
                                continue
                            out = metric.forward(exp_threshold)
                            write_json({f"{metric.name}": out}, metric_file)

        # NOTE(review): nothing in the visible code fetches the next
        # (item, index) pair; unless that happens further down in the file,
        # `item, index = outline.get_item()` is needed here or the loop never
        # terminates -- confirm.