#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
"""Explain every dataset item, then adjust, threshold and score the explanations.

Every intermediate explanation and metric result is written under a directory
tree derived from the model signature and the explaining-algorithm config.
"""

import copy
import os
import time

from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (obj_config_to_str, read_json,
                                           write_json, write_yaml)

# NOTE(review): `load_explanation` and `save_explanation` are called below but
# are neither imported nor defined in this file; import them from the module
# that provides them before running this script.

# Thresholding strategies applied to every adjusted explanation.
THRESHOLD_APPROACHES = ("hard", "topk", "topk_hard")


def get_pred(explanation, force=False, explainer=None):
    """Attach predictions on the full and on the masked input to ``explanation``.

    Predictions are (re)computed when ``pred`` or ``pred_masked`` is missing,
    or when ``force`` is true; otherwise ``explanation`` is returned untouched.

    Args:
        explanation: Explanation exposing ``to_dict()``, ``x``, ``edge_index``,
            ``node_mask`` and ``edge_mask``.
        force: Recompute predictions even if both are already present.
        explainer: Object providing ``get_prediction`` and
            ``get_masked_prediction``; required whenever predictions must be
            computed.

    Returns:
        The same ``explanation`` object, with ``pred`` and ``pred_masked`` set
        when they were computed.

    Raises:
        ValueError: If predictions must be computed but ``explainer`` is None.
    """
    dict_ = explanation.to_dict()
    # Use explicit `is None` checks: predictions may be tensors, whose truth
    # value is ambiguous.
    if (
        not force
        and dict_.get("pred") is not None
        and dict_.get("pred_masked") is not None
    ):
        return explanation
    if explainer is None:
        raise ValueError("get_pred() needs an `explainer` to compute predictions")
    pred = explainer.get_prediction(explanation)
    pred_masked = explainer.get_masked_prediction(
        x=explanation.x,
        edge_index=explanation.edge_index,
        node_mask=explanation.node_mask,
        edge_mask=explanation.edge_mask,
    )
    explanation.pred = pred
    explanation.pred_masked = pred_masked
    return explanation


def _prepare_output_dir(model_signature, model_info, explaining_algorithm, explainer_cfg):
    """Create the nested output directories and dump every configuration.

    Returns:
        The innermost directory under which all explanations are stored.
    """
    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
    makedirs(global_path)
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    write_json(model_info, os.path.join(global_path, "info.json"))

    algorithm_str = obj_config_to_str(explaining_algorithm)
    global_path = os.path.join(
        global_path, explaining_cfg.explainer.name + "_" + algorithm_str
    )
    makedirs(global_path)
    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))

    global_path = os.path.join(global_path, algorithm_str)
    makedirs(global_path)
    return global_path


def _get_raw_explanation(explainer, item, explanation_path, force):
    """Load the cached raw explanation of ``item``, or compute it.

    A cached file is reused unless ``force`` is true. In both cases missing
    predictions are filled in and the result is (re)saved to
    ``explanation_path``.
    """
    if os.path.exists(explanation_path) and not force:
        explanation = load_explanation(explanation_path)
    else:
        # NOTE(review): the label `item.y` is passed as `index`; confirm this
        # matches the configured task level.
        explanation = explainer(
            x=item.x,
            edge_index=item.edge_index,
            index=item.y,
            target=item.y,
        )
    explanation = get_pred(explanation, force=False, explainer=explainer)
    save_explanation(explanation, explanation_path)
    return explanation


def _threshold_explanation(explainer, explanation, threshold_approach, threshold_value):
    """Return a thresholded copy of ``explanation`` with refreshed predictions.

    ``explainer.threshold_config`` is set only for this post-processing step
    and then restored, so the setting does not leak into later uses of the
    shared ``explainer``.
    """
    previous_config = explainer.threshold_config
    explainer.threshold_config = ThresholdConfig.cast(
        {"threshold_type": threshold_approach, "value": threshold_value}
    )
    try:
        # Threshold a copy so the adjusted explanation is reusable for the
        # next threshold setting.
        exp_threshold = explainer._post_process(copy.deepcopy(explanation))
    finally:
        explainer.threshold_config = previous_config
    return get_pred(exp_threshold, force=True, explainer=explainer)


def _evaluate_metrics(metrics, explanation, masking_path, index):
    """Compute every metric on ``explanation`` and store each result as JSON.

    Results already present on disk are not recomputed.
    """
    for metric in metrics:
        metric_path = os.path.join(masking_path, f"{obj_config_to_str(metric)}")
        result_path = os.path.join(metric_path, f"{index}.json")
        if os.path.exists(result_path):
            continue
        makedirs(metric_path)
        out = metric.forward(explanation)
        # NOTE(review): write_json must be able to serialize `out` (it may be a
        # tensor) — confirm.
        write_json({f"{metric.name}": out}, result_path)


def _adjust_threshold_and_score(explainer, metrics, raw_explanation, adjust, global_path, index):
    """Adjust ``raw_explanation``, then threshold and score it for every setting.

    The adjusted explanation is always recomputed and saved under
    ``adjust-<config>/``. Thresholded variants and metric results are reused
    from disk when already present.
    """
    adjust_path = os.path.join(global_path, f"adjust-{obj_config_to_str(adjust)}")
    makedirs(adjust_path)
    # Adjust a copy so every relu/abs combination starts from the raw
    # explanation instead of compounding on the previous adjustment.
    explanation = adjust.forward(copy.deepcopy(raw_explanation))
    explanation = get_pred(explanation, force=True, explainer=explainer)
    save_explanation(explanation, os.path.join(adjust_path, f"{index}.json"))

    for threshold_approach in THRESHOLD_APPROACHES:
        for threshold_value in explaining_cfg.threshold_config.value:
            masking_path = os.path.join(
                adjust_path,
                f"threshold={threshold_approach}-value={threshold_value}",
            )
            makedirs(masking_path)
            exp_threshold_path = os.path.join(masking_path, f"{index}.json")
            if os.path.exists(exp_threshold_path):
                exp_threshold = load_explanation(exp_threshold_path)
            else:
                exp_threshold = _threshold_explanation(
                    explainer, explanation, threshold_approach, threshold_value
                )
                save_explanation(exp_threshold, exp_threshold_path)
            _evaluate_metrics(metrics, exp_threshold, masking_path, index)


def main():
    """Run the full explain / adjust / threshold / score pipeline."""
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)
    auto_select_device()

    # Load components
    dataset = outline.dataset.to(cfg.accelerator)
    model = outline.model.to(cfg.accelerator)
    model_info = outline.model_info
    metrics = outline.metrics
    explaining_algorithm = outline.explaining_algorithm
    explainer_cfg = outline.explainer_cfg
    model_signature = outline.model_signature

    # Set seed
    seed_everything(explaining_cfg.seed)

    global_path = _prepare_output_dir(
        model_signature, model_info, explaining_algorithm, explainer_cfg
    )

    explainer = Explainer(
        model=model,
        algorithm=explaining_algorithm,
        explainer_config=dict(
            explanation_type=explaining_cfg.explanation_type,
            node_mask_type="object",
            edge_mask_type="object",
        ),
        model_config=dict(
            mode="regression",
            task_level=cfg.dataset.task,
            return_type=explaining_cfg.model_config.return_type,
        ),
    )

    save_raw_path = os.path.join(global_path, "raw")
    makedirs(save_raw_path)

    # Compute (or load) the raw explanation of each item, then post-process it.
    for index, item in enumerate(dataset):
        raw_explanation = _get_raw_explanation(
            explainer,
            item,
            os.path.join(save_raw_path, f"{index}.json"),
            force=explaining_cfg.explainer.force,
        )
        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                _adjust_threshold_and_score(
                    explainer, metrics, raw_explanation, adjust, global_path, index
                )


if __name__ == "__main__":
    main()