#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Run the explaining pipeline over every item yielded by an ExplainingOutline.

For each item:

1. The raw explanation is loaded from ``<global_path>/raw/<index>.json`` if it
   exists, otherwise computed and saved there.
2. For every combination of ``Adjust(apply_relu, apply_absolute)`` a copy of
   the raw explanation is adjusted, passed to ``get_pred`` and saved under
   ``<global_path>/adjust-<adjust config>/<index>.json``.
3. Each adjusted explanation is thresholded with the "hard" strategy (values
   from ``explaining_cfg.threshold_config.value``) and with "topk" /
   "topk_hard" (k in ``TOPK_VALUES``). The results are saved under
   ``threshold=<approach>-value=<value>/``.
4. Every metric is evaluated on each thresholded explanation and written to
   ``<metric config>/<index>.json``.

Thresholded explanations and metric results whose ``<index>.json`` already
exists are reused or skipped, so an interrupted run can be resumed.
"""
import copy
import os
import time

import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.explanation.io import (
    explanation_verification,
    get_explanation,
    get_pred,
    load_explanation,
    save_explanation,
)
from explaining_framework.utils.io import (
    is_exists,
    obj_config_to_str,
    read_json,
    write_json,
    write_yaml,
)

# TODO: inference, time, force

# Threshold strategies applied to every adjusted explanation.
THRESHOLD_APPROACHES = ["hard", "topk", "topk_hard"]
# k values tried for the "topk" and "topk_hard" strategies.
TOPK_VALUES = [3, 5, 10, 20]


if __name__ == "__main__":
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)

    # TODO: add indexes

    # Output root: <out_dir>/<model_signature>/<explainer name>_<algorithm config>
    global_path = os.path.join(
        outline.explaining_cfg.out_dir,
        outline.model_signature,
        outline.explaining_cfg.explainer.name
        + "_"
        + obj_config_to_str(outline.explaining_algorithm),
    )
    makedirs(global_path)

    # Save the configurations of this run next to its outputs.
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    # TODO(review): `model_info` is not defined anywhere in this script, so
    # this line raises NameError; wire in its real source.
    write_json(model_info, os.path.join(global_path, "info.json"))
    write_yaml(
        outline.explaining_cfg,
        os.path.join(global_path, explaining_cfg.cfg_dest),
    )
    write_yaml(
        outline.explainer_cfg,
        os.path.join(global_path, "explainer_cfg.yaml"),
    )

    # NOTE(review): the algorithm config string is appended again here even
    # though it is already part of the directory name above; kept as-is.
    global_path = os.path.join(
        global_path, obj_config_to_str(outline.explaining_algorithm)
    )
    makedirs(global_path)

    # SET UP EXPLAINER
    # TODO(review): `explainer` and `metrics` are used below but are not
    # defined in this script; they must be set up here.

    # The raw-explanation directory does not depend on the item: create it once.
    raw_path = os.path.join(global_path, "raw")
    makedirs(raw_path)

    item, index = outline.get_item()
    while not (item is None or index is None):
        # Reuse the cached raw explanation, or compute and cache it
        # (same caching pattern as the thresholded explanations below).
        explanation_path = os.path.join(raw_path, f"{index}.json")
        if is_exists(explanation_path):
            explanation = load_explanation(explanation_path)
        else:
            # NOTE(review): assumes get_explanation(explainer, item) -- confirm
            # against explaining_framework.utils.explanation.io.
            explanation = get_explanation(explainer, item)
            save_explanation(explanation, explanation_path)

        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                adjust_path = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(adjust_path)

                # Adjust a copy and keep `explanation` untouched, so every
                # flag combination starts from the raw explanation instead of
                # stacking on the previous combination's output.
                # NOTE(review): copy.copy is shallow; if Adjust modifies
                # tensors in place, switch to copy.deepcopy.
                adjusted = copy.copy(explanation).to(cfg.accelerator)
                adjusted = adjust.forward(adjusted)
                get_pred(explainer, adjusted)
                save_explanation(adjusted, os.path.join(adjust_path, f"{index}.json"))

                for threshold_approach in THRESHOLD_APPROACHES:
                    if threshold_approach == "hard":
                        # NOTE(review): reads the module-level explaining_cfg,
                        # not outline.explaining_cfg; confirm they match.
                        threshold_values = explaining_cfg.threshold_config.value
                    else:  # "topk" / "topk_hard"
                        threshold_values = TOPK_VALUES

                    for threshold_value in threshold_values:
                        masking_path = os.path.join(
                            adjust_path,
                            f"threshold={threshold_approach}-value={threshold_value}",
                        )
                        makedirs(masking_path)

                        exp_threshold_path = os.path.join(masking_path, f"{index}.json")
                        if is_exists(exp_threshold_path):
                            exp_threshold = load_explanation(exp_threshold_path)
                        else:
                            explainer.threshold_config = ThresholdConfig.cast(
                                {
                                    "threshold_type": threshold_approach,
                                    "value": threshold_value,
                                }
                            )
                            # _post_process is a private Explainer method; it
                            # applies the threshold config set just above.
                            expl = copy.copy(adjusted).to(cfg.accelerator)
                            exp_threshold = explainer._post_process(expl)
                            exp_threshold = exp_threshold.to(cfg.accelerator)
                            get_pred(explainer, exp_threshold)
                            save_explanation(exp_threshold, exp_threshold_path)

                        for metric in metrics:
                            metric_path = os.path.join(
                                masking_path, f"{obj_config_to_str(metric)}"
                            )
                            makedirs(metric_path)
                            metric_file = os.path.join(metric_path, f"{index}.json")
                            if is_exists(metric_file):
                                continue  # already evaluated in a previous run
                            out = metric.forward(exp_threshold)
                            write_json({f"{metric.name}": out}, metric_file)

        # Advance to the next item; without this the loop never terminates.
        item, index = outline.get_item()