109 lines
4.4 KiB
Python
109 lines
4.4 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
#
|
|
|
|
import copy
|
|
import logging
|
|
import os
|
|
|
|
import torch
|
|
from torch_geometric import seed_everything
|
|
from torch_geometric.data.makedirs import makedirs
|
|
from torch_geometric.explain import Explainer
|
|
from torch_geometric.explain.config import ThresholdConfig
|
|
from torch_geometric.graphgym.config import cfg
|
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
from tqdm import tqdm
|
|
|
|
from explaining_framework.config.explaining_config import explaining_cfg
|
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
|
from explaining_framework.utils.explanation.adjust import Adjust
|
|
from explaining_framework.utils.io import (dump_cfg, is_exists,
|
|
obj_config_to_str, read_json,
|
|
set_printing, write_json)
|
|
|
|
def _config_dir(root, obj, kind=None):
    """Build ``root/<kind>/<obj_config_to_str(obj)>`` and call ``makedirs`` on it.

    Args:
        root: Parent directory.
        obj: Object whose ``obj_config_to_str`` value names the leaf directory.
        kind: Name of the intermediate directory. Defaults to the class name
            of ``obj``.

    Returns:
        The joined directory path.
    """
    if kind is None:
        kind = obj.__class__.__name__
    path = os.path.join(root, kind, obj_config_to_str(obj))
    makedirs(path)
    return path


def _iter_items(outline):
    """Yield ``(item, index)`` pairs from ``outline.get_item()``.

    Iteration stops as soon as either element of the returned pair is ``None``.
    The next pair is only fetched once the consumer asks for it, i.e. after
    the loop body for the current pair has finished.
    """
    item, index = outline.get_item()
    while not (item is None or index is None):
        yield item, index
        item, index = outline.get_item()


def _evaluate_explanation(outline, exp, attack_dir, index):
    """Run every adjust / threshold / metric combination on one explanation.

    Results are requested through ``outline.get_adjust``,
    ``outline.get_threshold`` and ``outline.get_metric``; each call receives a
    ``<index>.json`` path inside a directory nested under ``attack_dir``.

    Args:
        outline: The ``ExplainingOutline`` driving the run.
        exp: Explanation returned by ``outline.get_explanation`` (not ``None``).
        attack_dir: Directory under which the nested result folders are made.
        index: Item index, used as the JSON file stem.
    """
    filename = f"{index}.json"
    for adjust in outline.adjusts:
        adjust_dir = _config_dir(attack_dir, adjust)
        exp_adjust = outline.get_adjust(
            adjust=adjust, item=exp, path=os.path.join(adjust_dir, filename)
        )
        for threshold_conf in outline.thresholds_configs:
            # The threshold configuration is set on the outline before the
            # thresholded explanation is requested.
            outline.set_explainer_threshold_config(threshold_conf)
            masking_dir = _config_dir(
                adjust_dir, threshold_conf, kind="ThresholdConfig"
            )
            exp_masked = outline.get_threshold(
                item=exp_adjust, path=os.path.join(masking_dir, filename)
            )
            if exp_masked is None:
                # No thresholded explanation: no metrics for this config.
                continue
            for metric in outline.metrics:
                metric_dir = _config_dir(masking_dir, metric)
                outline.get_metric(
                    metric=metric,
                    item=exp_masked,
                    path=os.path.join(metric_dir, filename),
                )


def main():
    """Run the attack pass and then the explanation pass over the dataset.

    Pass 1 calls ``outline.get_attack`` for every (item, attack) pair with a
    path under ``outline.out_dir``. After ``outline.reload_dataloader()``,
    pass 2 calls ``get_attack`` and ``get_explanation`` for every pair with a
    path under ``outline.explainer_path`` and, when an explanation is
    returned, evaluates all adjust / threshold / metric combinations on it.
    An empty file named ``done`` is written to ``outline.out_dir`` once both
    passes have completed.
    """
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)

    # The context manager closes the progress bar even if a step raises.
    with tqdm(total=len(outline.dataset) * len(outline.attacks)) as pbar:
        # Pass 1: attacks only; the return value is not used here.
        for item, index in _iter_items(outline):
            for attack in outline.attacks:
                attack_dir = _config_dir(outline.out_dir, attack)
                outline.get_attack(
                    attack=attack,
                    item=item,
                    path=os.path.join(attack_dir, f"{index}.json"),
                )

        outline.reload_dataloader()

        # Pass 2: explanations and everything derived from them.
        for item, index in _iter_items(outline):
            for attack in outline.attacks:
                attack_dir = _config_dir(outline.explainer_path, attack)
                attack_file = os.path.join(attack_dir, f"{index}.json")
                # NOTE(review): the attack and the explanation are given the
                # same file path, and it lives under ``explainer_path`` rather
                # than the ``out_dir`` location used in pass 1 - confirm this
                # is intended.
                attack_data = outline.get_attack(
                    attack=attack, item=item, path=attack_file
                )
                exp = outline.get_explanation(item=attack_data, path=attack_file)
                # Progress counts (item, attack) pairs, whether or not an
                # explanation was produced.
                pbar.update(1)
                if exp is not None:
                    _evaluate_explanation(outline, exp, attack_dir, index)

        # Empty marker file signalling that the whole run finished.
        with open(os.path.join(outline.out_dir, "done"), "w") as f:
            f.write("")


if __name__ == "__main__":
    main()