explaining_framework/main.py

128 lines
5.3 KiB
Python
Raw Normal View History

2022-12-29 22:29:32 +00:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
"""Entry point of the explaining framework.

The script runs two phases over every attack configured in the outline:

1. For each dataset item, call ``outline.get_attack`` with a per-item path
   ``<out_dir>/<AttackClass>/<attack config>/<index>.json``.
2. For each dataset item, call ``outline.get_attack`` and
   ``outline.get_explanation`` with a per-item path under
   ``<explainer_path>/<AttackClass>/<attack config>/``. Then, for every
   adjust, threshold configuration and metric, call the corresponding
   ``outline.get_*`` with a path nested below that directory.

``notify_done`` is called on an attack's directory only when no step of that
phase returned ``None`` for any processed item.
"""
import copy
import logging
import os

import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
from tqdm import tqdm

from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (dump_cfg, is_exists, notify_done,
                                           obj_config_to_str, read_json,
                                           set_printing, write_json)


def _config_dir(parent, obj, name=None):
    """Create (if needed) and return ``<parent>/<name>/<config string of obj>``.

    Args:
        parent: Directory under which the new directory is created.
        obj: Object whose configuration string (``obj_config_to_str``) names
            the leaf directory.
        name: Intermediate directory name; defaults to ``obj``'s class name.
    """
    if name is None:
        name = obj.__class__.__name__
    path = os.path.join(parent, name, obj_config_to_str(obj))
    makedirs(path)
    return path


def _iter_items(outline):
    """Return a progress-bar iterator over ``(item, index)`` pairs of the dataset."""
    return tqdm(zip(outline.dataset, outline.indexes), total=len(outline.dataset))


def _run_raw_attacks(outline, force_fastforward):
    """Phase 1: run every attack on every dataset item under ``outline.out_dir``.

    Args:
        outline: The ``ExplainingOutline`` driving the run.
        force_fastforward: When true, items whose output file already exists
            are skipped.
    """
    for attack in outline.attacks:
        logging.info("Running %s: %s", attack.__class__.__name__, attack.name)
        # The directory depends only on the attack, so build it once. This also
        # keeps it bound when the dataset is empty (previously a NameError).
        attack_path = _config_dir(outline.out_dir, attack)
        done_attack_raw = True
        for item, index in _iter_items(outline):
            data_attack_path = os.path.join(attack_path, f"{index}.json")
            if force_fastforward and os.path.exists(data_attack_path):
                continue
            # Move to the device only for items that are actually processed.
            item = item.to(outline.cfg.accelerator)
            data_attack = outline.get_attack(
                attack=attack, item=item, path=data_attack_path
            )
            if data_attack is None:
                # Was assigned to an unused ``done`` variable, so a failed
                # item never prevented ``notify_done`` below.
                done_attack_raw = False
        if done_attack_raw:
            notify_done(attack_path)


def _process_explanation(outline, exp, attack_path, index):
    """Run every adjust -> threshold -> metric step on one explanation.

    Args:
        outline: The ``ExplainingOutline`` driving the run.
        exp: Explanation returned by ``outline.get_explanation``.
        attack_path: Per-attack directory under ``outline.explainer_path``.
        index: Dataset index of the item, used as the output file stem.

    Returns:
        ``False`` if any ``outline.get_threshold`` call returned ``None``,
        ``True`` otherwise.
    """
    complete = True
    for adjust in outline.adjusts:
        adjust_path = _config_dir(attack_path, adjust)
        exp_adjust = outline.get_adjust(
            adjust=adjust, item=exp, path=os.path.join(adjust_path, f"{index}.json")
        )
        for threshold_conf in outline.thresholds_configs:
            outline.set_explainer_threshold_config(threshold_conf)
            masking_path = _config_dir(
                adjust_path, threshold_conf, name="ThresholdConfig"
            )
            exp_masked = outline.get_threshold(
                item=exp_adjust, path=os.path.join(masking_path, f"{index}.json")
            )
            if exp_masked is None:
                # Was assigned to an unused ``done`` variable, so threshold
                # failures were never reported as an incomplete run.
                complete = False
                continue
            for metric in outline.metrics:
                metric_path = _config_dir(masking_path, metric)
                outline.get_metric(
                    metric=metric,
                    item=exp_masked,
                    path=os.path.join(metric_path, f"{index}.json"),
                )
    return complete


def _run_explanations(outline, force_fastforward):
    """Phase 2: explain attacked items and derive adjusts, thresholds and metrics.

    Args:
        outline: The ``ExplainingOutline`` driving the run.
        force_fastforward: When true, items whose per-item file under the
            explainer directory already exists are skipped.
    """
    for attack in outline.attacks:
        logging.info("Running %s: %s", attack.__class__.__name__, attack.name)
        logging.info(
            "Running %s: %s",
            outline.explainer.__class__.__name__,
            outline.explaining_algorithm.name,
        )
        attack_path_ = _config_dir(outline.explainer_path, attack)
        done_attack_exp = True
        done_exp = True
        for item, index in _iter_items(outline):
            data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
            if force_fastforward and os.path.exists(data_attack_path_):
                continue
            item = item.to(outline.cfg.accelerator)
            # NOTE(review): get_attack and get_explanation receive the SAME path
            # here, and get_attack is not given the phase-1 path under
            # outline.out_dir. Confirm this is intended, i.e. that get_attack
            # does not persist its result to ``path``.
            attack_data = outline.get_attack(
                attack=attack, item=item, path=data_attack_path_
            )
            if attack_data is None:
                done_attack_exp = False
                continue
            exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
            if exp is None:
                done_exp = False
                continue
            if not _process_explanation(outline, exp, attack_path_, index):
                done_exp = False
        # Was ``done_exp and done_exp``, which ignored attack failures.
        if done_attack_exp and done_exp:
            notify_done(attack_path_)


def main():
    """Parse command-line arguments and run both phases."""
    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
    _run_raw_attacks(outline, args.force_fastforward)
    _run_explanations(outline, args.force_fastforward)


if __name__ == "__main__":
    main()