#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
"""Drive the attack -> explanation -> adjust -> threshold -> metric pipeline.

Phase 1 runs every attack in the outline on every dataset item, handing
``outline.get_attack`` a per-item JSON path under ``outline.out_dir``.
Phase 2 explains each attacked item and then runs every configured adjust,
threshold config and metric on the explanation, each step receiving a
per-item JSON path under ``outline.explainer_path``.

After each attack, ``notify_done`` is called on that attack's directory only
when no step of the corresponding phase returned ``None``.
"""
import copy
import logging
import os

import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
from tqdm import tqdm

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
                                           obj_config_to_str, read_json,
                                           set_printing, write_json)


def _config_dir(root, obj):
    """Return ``root/<class name of obj>/<obj_config_to_str(obj)>``."""
    return os.path.join(root, obj.__class__.__name__, obj_config_to_str(obj))


def _item_file(directory, index):
    """Return the per-item JSON path ``directory/<index>.json``."""
    return os.path.join(directory, f"{index}.json")


def _compute_attacks(outline, args):
    """Phase 1: run every attack on every dataset item.

    Results are requested at ``out_dir/<Attack>/<config>/<index>.json``.
    With ``--force_fastforward``, items whose file already exists are skipped.
    ``notify_done`` is called on an attack's directory only if no call to
    ``get_attack`` returned ``None``.
    """
    for attack in outline.attacks:
        logging.info("Running %s: %s", attack.__class__.__name__, attack.name)
        # The directory depends only on the attack: create it once, which also
        # keeps it bound for notify_done when the dataset is empty.
        attack_path = _config_dir(outline.out_dir, attack)
        makedirs(attack_path)
        done_attack_raw = True
        for item, index in tqdm(
            zip(outline.dataset, outline.indexes), total=len(outline.dataset)
        ):
            data_attack_path = _item_file(attack_path, index)
            if args.force_fastforward and os.path.exists(data_attack_path):
                continue
            item = item.to(outline.cfg.accelerator)
            data_attack = outline.get_attack(
                attack=attack, item=item, path=data_attack_path
            )
            if data_attack is None:
                # Previously this assigned an unused ``done`` variable, so the
                # directory was marked done even after a failed attack.
                done_attack_raw = False
        if done_attack_raw:
            notify_done(attack_path)


def _run_metrics(outline, exp_masked, masking_path, index):
    """Run every configured metric on one thresholded explanation."""
    for metric in outline.metrics:
        metric_dir = _config_dir(masking_path, metric)
        makedirs(metric_dir)
        outline.get_metric(
            metric=metric, item=exp_masked, path=_item_file(metric_dir, index)
        )


def _process_explanation(outline, exp, attack_path_, index):
    """Apply every adjust, then every threshold config and metric, to ``exp``.

    Returns:
        ``False`` if any ``get_threshold`` call returned ``None``, else ``True``.
    """
    complete = True
    for adjust in outline.adjusts:
        adjust_path = _config_dir(attack_path_, adjust)
        makedirs(adjust_path)
        # NOTE(review): a None result from get_adjust is not checked here
        # (matching the original); confirm get_threshold tolerates it.
        exp_adjust = outline.get_adjust(
            adjust=adjust, item=exp, path=_item_file(adjust_path, index)
        )
        for threshold_conf in outline.thresholds_configs:
            # The threshold config is set on the explainer before thresholding;
            # presumably get_threshold reads it from there.
            outline.set_explainer_threshold_config(threshold_conf)
            masking_path = os.path.join(
                adjust_path, "ThresholdConfig", obj_config_to_str(threshold_conf)
            )
            makedirs(masking_path)
            exp_masked = outline.get_threshold(
                item=exp_adjust, path=_item_file(masking_path, index)
            )
            if exp_masked is None:
                # Previously assigned an unused ``done`` variable.
                complete = False
                continue
            _run_metrics(outline, exp_masked, masking_path, index)
    return complete


def _explain_attacks(outline, args):
    """Phase 2: explain every attacked item and post-process the explanation.

    Explanations are requested at
    ``explainer_path/<Attack>/<config>/<index>.json``. ``notify_done`` is
    called on that directory only if no attack, explanation or threshold step
    returned ``None``.
    """
    for attack in outline.attacks:
        logging.info("Running %s: %s", attack.__class__.__name__, attack.name)
        logging.info(
            "Running %s: %s",
            outline.explainer.__class__.__name__,
            outline.explaining_algorithm.name,
        )
        # Where phase 1 stored this attack's outputs.
        raw_attack_path = _config_dir(outline.out_dir, attack)
        attack_path_ = _config_dir(outline.explainer_path, attack)
        makedirs(attack_path_)
        done_attack_exp = True
        done_exp = True
        for item, index in tqdm(
            zip(outline.dataset, outline.indexes), total=len(outline.dataset)
        ):
            data_attack_path_ = _item_file(attack_path_, index)
            if args.force_fastforward and os.path.exists(data_attack_path_):
                continue
            item = item.to(outline.cfg.accelerator)
            # Use the phase-1 attack path. The explanation path is reserved
            # for get_explanation's output; sharing it made the attack and the
            # explanation collide on one file.
            attack_data = outline.get_attack(
                attack=attack, item=item, path=_item_file(raw_attack_path, index)
            )
            if attack_data is None:
                done_attack_exp = False
                continue
            exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
            if exp is None:
                done_exp = False
                continue
            if not _process_explanation(outline, exp, attack_path_, index):
                done_exp = False
        # Previously ``done_exp and done_exp``, which ignored attack failures.
        if done_attack_exp and done_exp:
            notify_done(attack_path_)


def main():
    """Parse CLI arguments, build the outline and run both phases."""
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
    _compute_attacks(outline, args)
    _explain_attacks(outline, args)


if __name__ == "__main__":
    main()