2022-12-29 22:29:32 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
#
|
|
|
|
|
2023-01-02 22:37:40 +00:00
|
|
|
import copy
|
2022-12-29 22:29:32 +00:00
|
|
|
import os
|
2022-12-30 18:34:41 +00:00
|
|
|
import time
|
|
|
|
|
2023-01-02 22:37:40 +00:00
|
|
|
import torch
|
2022-12-30 18:34:41 +00:00
|
|
|
from torch_geometric import seed_everything
|
|
|
|
from torch_geometric.data.makedirs import makedirs
|
|
|
|
from torch_geometric.explain import Explainer
|
|
|
|
from torch_geometric.explain.config import ThresholdConfig
|
|
|
|
from torch_geometric.graphgym.config import cfg
|
|
|
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
|
|
|
2022-12-29 22:29:32 +00:00
|
|
|
from explaining_framework.config.explaining_config import explaining_cfg
|
|
|
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
2022-12-30 18:34:41 +00:00
|
|
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
2022-12-30 18:41:56 +00:00
|
|
|
from explaining_framework.utils.explanation.adjust import Adjust
|
2023-01-02 22:37:40 +00:00
|
|
|
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
|
|
|
read_json, write_json, write_yaml)
|
2022-12-30 18:34:41 +00:00
|
|
|
|
|
|
|
# inference, time, force,
|
|
|
|
|
2022-12-30 18:41:56 +00:00
|
|
|
|
2022-12-30 18:34:41 +00:00
|
|
|
if __name__ == "__main__":
|
|
|
|
# Read the command-line arguments, then build the explaining outline from
# the explaining config file they name.
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file)

# Echo the loaded explaining configuration on stdout.
print(outline.explaining_cfg)
|
2022-12-30 18:34:41 +00:00
|
|
|
|
2023-01-04 09:41:34 +00:00
|
|
|
# Root directory for everything produced for this model:
#   <explaining_cfg.out_dir>/<model_signature>
out_dir = os.path.join(
    outline.explaining_cfg.out_dir,
    outline.model_signature,
)
makedirs(out_dir)

# Store `outline.cfg` (presumably the GraphGym model config -- confirm) as
# YAML and the model info as JSON at the root of the model directory.
config_yaml_path = os.path.join(out_dir, "config.yaml")
write_yaml(outline.cfg, config_yaml_path)
info_json_path = os.path.join(out_dir, "info.json")
write_json(outline.model_info, info_json_path)
|
2022-12-30 18:34:41 +00:00
|
|
|
|
2023-01-04 09:41:34 +00:00
|
|
|
# Per-explainer output directory inside the model directory `out_dir`:
#   <out_dir>/<explainer name>_<algorithm config string>
algo_config_str = obj_config_to_str(outline.explaining_algorithm)
explainer_path = os.path.join(
    out_dir,
    f"{outline.explaining_cfg.explainer.name}_{algo_config_str}",
)
makedirs(explainer_path)

# Snapshot the explaining config and the explainer config used for this run.
# `cfg_dest` is read from the config the outline actually loaded, like every
# other setting in this script. The module-level `explaining_cfg` import may
# hold only defaults if the outline keeps its own copy of the config.
write_yaml(
    outline.explaining_cfg,
    os.path.join(explainer_path, outline.explaining_cfg.cfg_dest),
)
write_yaml(
    outline.explainer_cfg,
    os.path.join(explainer_path, "explainer_cfg.yaml"),
)
|
2022-12-30 18:34:41 +00:00
|
|
|
|
2023-01-04 09:41:34 +00:00
|
|
|
# Sub-directory for the current algorithm configuration.
# NOTE(review): `explainer_path` is built with the same
# `obj_config_to_str(outline.explaining_algorithm)` suffix, so this string
# appears twice in the final path -- confirm this nesting is intended.
specific_explainer_path = os.path.join(
    explainer_path,
    obj_config_to_str(outline.explaining_algorithm),
)
makedirs(specific_explainer_path)

# Per-item explanation files (one `<index>.json` each) live under `raw/`.
raw_path = os.path.join(specific_explainer_path, "raw")
makedirs(raw_path)
|
2022-12-30 18:34:41 +00:00
|
|
|
|
2023-01-04 09:41:34 +00:00
|
|
|
item, index = outline.get_item()
|
|
|
|
while not (item is None or index is None):
|
|
|
|
explanation_path = os.path.join(raw_path, f"{index}.json")
|
|
|
|
raw_exp = outline.get_explanation(item=item, path=explanation_path)
|
|
|
|
for adjust in outline.adjusts:
|
|
|
|
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
|
|
|
|
makedirs(adjust_path)
|
2023-01-04 11:56:37 +00:00
|
|
|
exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
|
2023-01-04 09:41:34 +00:00
|
|
|
exp_adjust = outline.get_adjust(
|
|
|
|
adjust=adjust, item=raw_exp, path=exp_adjust_path
|
|
|
|
)
|
|
|
|
for threshold_conf in outline.thresholds_configs:
|
|
|
|
outline.set_explainer_threshold_config(threshold_conf)
|
|
|
|
masking_path = os.path.join(
|
2023-01-04 11:56:37 +00:00
|
|
|
adjust_path,
|
2023-01-04 09:41:34 +00:00
|
|
|
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
|
|
|
)
|
|
|
|
makedirs(masking_path)
|
|
|
|
exp_masked_path = os.path.join(masking_path, f"{index}.json")
|
|
|
|
exp_masked = outline.get_threshold(
|
|
|
|
item=exp_adjust, path=exp_masked_path
|
|
|
|
)
|
|
|
|
for metric in outline.metrics:
|
|
|
|
metric_path = os.path.join(
|
|
|
|
masking_path, f"{obj_config_to_str(metric)}"
|
|
|
|
)
|
|
|
|
makedirs(metric_path)
|
|
|
|
metric_path = os.path.join(metric_path, f"{index}.json")
|
|
|
|
out_metric = outline.get_metric(
|
|
|
|
metric=metric, item=exp_masked, path=metric_path
|
|
|
|
)
|