#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
"""Compute, adjust, threshold and score explanations for a trained GNN.

For every selected dataset item the script:

1. builds (or reloads from ``raw/<index>.json``) an explanation with the
   configured explaining algorithm and attaches model predictions to it;
2. for each ``Adjust(apply_relu, apply_absolute)`` combination, saves an
   adjusted copy of that explanation;
3. for each thresholding strategy and value, saves a thresholded copy of the
   adjusted explanation and writes every configured metric on it to JSON.

Raw explanations, thresholded explanations and metric values that already
exist on disk are reused; the raw explanation is recomputed anyway when
``explaining_cfg.explainer.force`` is set.
"""

import copy
import os
import time

import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.explanation.io import (
    explanation_verification,
    load_explanation,
    save_explanation,
)
from explaining_framework.utils.io import (
    is_exists,
    obj_config_to_str,
    read_json,
    write_json,
    write_yaml,
)

# inference, time, force,

# Thresholding strategies applied to every adjusted explanation.
THRESHOLD_APPROACHES = ("hard", "topk", "topk_hard")
# Values tried for the "topk" and "topk_hard" strategies; the values for
# "hard" come from ``explaining_cfg.threshold_config.value``.
TOPK_VALUES = (3, 5, 10, 20)


def get_pred(explainer, explanation):
    """Attach model predictions to *explanation* in place.

    Sets ``explanation.pred`` to element ``[0]`` of the model prediction on
    the explanation's graph. If the explanation carries a ``node_mask`` or an
    ``edge_mask``, also sets ``explanation.pred_exp`` to element ``[0]`` of
    ``explainer.get_masked_prediction`` computed with those masks.

    Args:
        explainer: the ``Explainer`` wrapping the model.
        explanation: the explanation to annotate (mutated, nothing returned).
    """
    explanation.pred = explainer.get_prediction(
        x=explanation.x, edge_index=explanation.edge_index
    )[0]
    data = explanation.to_dict()
    node_mask = data.get("node_mask")
    edge_mask = data.get("edge_mask")
    if node_mask is not None or edge_mask is not None:
        explanation.pred_exp = explainer.get_masked_prediction(
            x=explanation.x,
            edge_index=explanation.edge_index,
            node_mask=node_mask,
            edge_mask=edge_mask,
        )[0]


def get_explanation(explainer, item):
    """Run *explainer* on one dataset item and return the checked explanation.

    Args:
        explainer: the ``Explainer`` wrapping the model.
        item: a graph with ``x``, ``edge_index`` and ``y`` attributes.

    Returns:
        The explanation produced by ``explainer``.

    Raises:
        ValueError: if ``explanation_verification`` rejects the explanation.
    """
    explanation = explainer(
        x=item.x,
        edge_index=item.edge_index,
        # NOTE(review): the label doubles as ``index``; presumably it selects
        # the output to explain -- confirm against the algorithm in use.
        index=int(item.y),
        target=item.y,
    )
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not explanation_verification(explanation):
        raise ValueError("Explanation failed verification")
    return explanation


def _prepare_output_dir(model_signature, model_info, explaining_algorithm, explainer_cfg):
    """Create the output directory tree, dump the configs, return the leaf dir.

    Layout: ``<out_dir>/<model_signature>/<explainer name>_<algo>/<algo>``.
    """
    algorithm_str = obj_config_to_str(explaining_algorithm)

    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
    makedirs(global_path)
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    write_json(model_info, os.path.join(global_path, "info.json"))

    global_path = os.path.join(
        global_path, explaining_cfg.explainer.name + "_" + algorithm_str
    )
    makedirs(global_path)
    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))

    global_path = os.path.join(global_path, algorithm_str)
    makedirs(global_path)
    return global_path


def _load_or_compute_raw(explainer, item, explanation_path):
    """Reload the cached raw explanation, or compute it if absent or forced."""
    if is_exists(explanation_path) and not explaining_cfg.explainer.force:
        return load_explanation(explanation_path)
    return get_explanation(explainer, item)


def _threshold_values(threshold_approach):
    """Return the threshold values to try for *threshold_approach*."""
    if threshold_approach == "hard":
        return explaining_cfg.threshold_config.value
    return TOPK_VALUES


def _threshold_explanation(explainer, explanation, threshold_approach, threshold_value):
    """Return a thresholded copy of *explanation* without leaking the config.

    ``explainer.threshold_config`` is restored afterwards. If it stayed set,
    later ``explainer(...)`` calls (``get_explanation``) would presumably
    return already-thresholded "raw" explanations, because PyG's
    ``Explainer.__call__`` post-processes with that config. Confirm against
    the installed PyG version.
    """
    previous_config = getattr(explainer, "threshold_config", None)
    explainer.threshold_config = ThresholdConfig.cast(
        {"threshold_type": threshold_approach, "value": threshold_value}
    )
    try:
        expl = copy.copy(explanation).to(cfg.accelerator)
        return explainer._post_process(expl).to(cfg.accelerator)
    finally:
        explainer.threshold_config = previous_config


def _compute_metrics(metrics, exp_threshold, masking_path, index):
    """Write each metric on *exp_threshold* to JSON, skipping cached results."""
    for metric in metrics:
        metric_path = os.path.join(masking_path, f"{obj_config_to_str(metric)}")
        makedirs(metric_path)
        result_path = os.path.join(metric_path, f"{index}.json")
        if is_exists(result_path):
            continue
        out = metric.forward(exp_threshold)
        write_json({f"{metric.name}": out}, result_path)


def _threshold_and_score(explainer, adjusted, adjust_dir, index, metrics):
    """Threshold *adjusted* with every strategy/value, then score each result."""
    for threshold_approach in THRESHOLD_APPROACHES:
        for threshold_value in _threshold_values(threshold_approach):
            masking_path = os.path.join(
                adjust_dir,
                f"threshold={threshold_approach}-value={threshold_value}",
            )
            makedirs(masking_path)
            exp_threshold_path = os.path.join(masking_path, f"{index}.json")
            if is_exists(exp_threshold_path):
                exp_threshold = load_explanation(exp_threshold_path)
            else:
                exp_threshold = _threshold_explanation(
                    explainer, adjusted, threshold_approach, threshold_value
                )
                get_pred(explainer, exp_threshold)
                save_explanation(exp_threshold, exp_threshold_path)
            _compute_metrics(metrics, exp_threshold, masking_path, index)


def main():
    """Run the explain / adjust / threshold / metric pipeline end to end."""
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)

    auto_select_device()

    # Load components
    dataset = outline.dataset
    model = outline.model.to(cfg.accelerator).eval()
    model_info = outline.model_info
    metrics = outline.metrics
    explaining_algorithm = outline.explaining_algorithm
    attacks = outline.attacks  # NOTE(review): loaded but not used below
    explainer_cfg = outline.explainer_cfg
    model_signature = outline.model_signature

    # Set seed
    seed_everything(explaining_cfg.seed)

    global_path = _prepare_output_dir(
        model_signature, model_info, explaining_algorithm, explainer_cfg
    )

    explainer = Explainer(
        model=model,
        algorithm=explaining_algorithm,
        explainer_config=dict(
            explanation_type=explaining_cfg.explanation_type,
            node_mask_type="object",
            edge_mask_type="object",
        ),
        model_config=dict(
            mode="regression",
            task_level=cfg.dataset.task,
            return_type=explaining_cfg.model_config.return_type,
        ),
    )

    specific_items = explaining_cfg.dataset.specific_items
    indexes = range(len(dataset)) if specific_items is None else specific_items

    raw_dir = os.path.join(global_path, "raw")
    makedirs(raw_dir)
    # NOTE(review): ``zip`` pairs the i-th index with the i-th dataset item.
    # That is only right if ``outline.dataset`` is already restricted to
    # ``specific_items``; otherwise the items should be ``dataset[index]``.
    for index, item in zip(indexes, dataset):
        item = item.to(cfg.accelerator)
        explanation_path = os.path.join(raw_dir, f"{index}.json")
        explanation = _load_or_compute_raw(explainer, item, explanation_path)
        explanation = explanation.to(cfg.accelerator)
        get_pred(explainer=explainer, explanation=explanation)
        save_explanation(explanation, explanation_path)

        for apply_relu in (True, False):
            for apply_absolute in (True, False):
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                adjust_dir = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(adjust_dir)
                # Adjust a fresh copy so every combination starts from the raw
                # explanation; rebinding ``explanation`` would compound them.
                # NOTE(review): copy.copy shares tensors -- assumes Adjust does
                # not modify masks in place; confirm.
                adjusted = adjust.forward(copy.copy(explanation).to(cfg.accelerator))
                get_pred(explainer, adjusted)
                save_explanation(adjusted, os.path.join(adjust_dir, f"{index}.json"))
                _threshold_and_score(explainer, adjusted, adjust_dir, index, metrics)


if __name__ == "__main__":
    main()