#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
"""Run the explaining pipeline described by an explaining config file.

The script makes two passes over the dataset exposed by ``ExplainingOutline``:

1. For every item and every configured attack, call ``outline.get_attack``
   with a per-item JSON path under ``outline.out_dir``.
2. After ``outline.reload_dataloader()``, for every item and attack, compute
   the attacked item (path under ``outline.explainer_path``) and explain it.
   Then, for each adjust / threshold config / metric, call the matching
   ``outline`` getter with a per-item JSON path nested under the previous
   stage's directory.

An empty ``done`` file is written to ``outline.out_dir`` once both passes finish.
"""

import copy
import logging
import os

import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
from tqdm import tqdm

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (dump_cfg, is_exists,
                                           obj_config_to_str, read_json,
                                           set_printing, write_json)


def _prepare_dir(root, name, config_obj):
    """Create ``root/name/<obj_config_to_str(config_obj)>`` and return its path."""
    directory = os.path.join(root, name, obj_config_to_str(config_obj))
    makedirs(directory)
    return directory


def _item_file(directory, index):
    """Return the per-item output path ``directory/<index>.json``."""
    return os.path.join(directory, f"{index}.json")


def _iter_items(outline: ExplainingOutline):
    """Yield ``(item, index)`` from ``outline.get_item()`` until either is None.

    Because this is a generator, ``get_item`` is called for the next item only
    after the caller has finished processing the current one, exactly like a
    hand-written ``while`` loop that fetches at the end of its body.
    """
    item, index = outline.get_item()
    while not (item is None or index is None):
        yield item, index
        item, index = outline.get_item()


def _run_attacks(outline: ExplainingOutline) -> None:
    """First pass: call ``get_attack`` for every (item, attack) pair.

    The returned attacked data is not used here. Presumably ``get_attack``
    persists its result at ``path`` (TODO confirm in ``ExplainingOutline``).
    """
    for item, index in _iter_items(outline):
        for attack in outline.attacks:
            attack_dir = _prepare_dir(
                outline.out_dir, attack.__class__.__name__, attack
            )
            outline.get_attack(
                attack=attack, item=item, path=_item_file(attack_dir, index)
            )


def _adjust_threshold_and_score(outline, exp, adjust, attack_dir, index) -> None:
    """Apply one adjust to ``exp``, then every threshold config and metric.

    Outputs are nested as
    ``attack_dir/<Adjust>/<cfg>/ThresholdConfig/<cfg>/<Metric>/<cfg>/<index>.json``.
    Threshold configs whose masked explanation is None are skipped.
    """
    adjust_dir = _prepare_dir(attack_dir, adjust.__class__.__name__, adjust)
    exp_adjust = outline.get_adjust(
        adjust=adjust, item=exp, path=_item_file(adjust_dir, index)
    )
    for threshold_conf in outline.thresholds_configs:
        # The explainer's threshold must be set before get_threshold uses it.
        outline.set_explainer_threshold_config(threshold_conf)
        masking_dir = _prepare_dir(adjust_dir, "ThresholdConfig", threshold_conf)
        exp_masked = outline.get_threshold(
            item=exp_adjust, path=_item_file(masking_dir, index)
        )
        if exp_masked is None:
            continue
        for metric in outline.metrics:
            metric_dir = _prepare_dir(masking_dir, metric.__class__.__name__, metric)
            outline.get_metric(
                metric=metric, item=exp_masked, path=_item_file(metric_dir, index)
            )


def _explain_and_evaluate(outline: ExplainingOutline, pbar: tqdm) -> None:
    """Second pass: attack, explain, then adjust/threshold/score each item.

    ``pbar`` advances once per (item, attack) pair, whether or not an
    explanation was produced.
    """
    for item, index in _iter_items(outline):
        for attack in outline.attacks:
            attack_dir = _prepare_dir(
                outline.explainer_path, attack.__class__.__name__, attack
            )
            attack_file = _item_file(attack_dir, index)
            attack_data = outline.get_attack(
                attack=attack, item=item, path=attack_file
            )
            # NOTE(review): the same file path is passed to both get_attack and
            # get_explanation. Confirm this is intended and that one does not
            # overwrite or be mistaken for the other's output.
            exp = outline.get_explanation(item=attack_data, path=attack_file)
            pbar.update(1)
            if exp is None:
                continue
            for adjust in outline.adjusts:
                _adjust_threshold_and_score(outline, exp, adjust, attack_dir, index)


def main() -> None:
    """Parse CLI args, run both pipeline passes, and write the ``done`` marker."""
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)

    # The context manager closes the bar even if a pipeline step raises.
    with tqdm(total=len(outline.dataset) * len(outline.attacks)) as pbar:
        _run_attacks(outline)
        outline.reload_dataloader()
        _explain_and_evaluate(outline, pbar)

        # Empty marker file signalling that the whole run completed.
        with open(os.path.join(outline.out_dir, "done"), "w") as f:
            f.write("")


if __name__ == "__main__":
    main()