explaining_framework/main.py

141 lines
5.9 KiB
Python
Raw Normal View History

2022-12-29 22:29:32 +00:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
import os
import time
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
2022-12-29 22:29:32 +00:00
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.io import (obj_config_to_str, read_json,
write_json, write_yaml)
from explaining_framework.utils.explanation.adjust import Adjust
# inference, time, force,
def get_pred(explanation, force=False, explainer=None):
    """Attach model predictions on the full and the masked input to *explanation*.

    Sets ``explanation.pred`` (from ``explainer.get_prediction``) and
    ``explanation.pred_masked`` (from ``explainer.get_masked_prediction``,
    fed with the explanation's ``x``, ``edge_index``, ``node_mask`` and
    ``edge_mask``). Both are recomputed when either one is missing or when
    *force* is true; otherwise the explanation is returned untouched.

    Args:
        explanation: Explanation object exposing ``to_dict()`` plus the
            ``x``, ``edge_index``, ``node_mask`` and ``edge_mask`` attributes.
        force: Recompute both predictions even if they are already present.
        explainer: Explainer used to run the model. When omitted, the
            module-level ``explainer`` built in the ``__main__`` block is used.

    Returns:
        The same *explanation* object, updated in place.

    Raises:
        NameError: if no *explainer* is given and no module-level
            ``explainer`` exists.
    """
    dict_ = explanation.to_dict()
    # Explicit ``is None`` checks: the stored predictions are presumably
    # tensors, and the truth value of a multi-element tensor is ambiguous.
    missing = dict_.get("pred") is None or dict_.get("pred_masked") is None
    if not (missing or force):
        return explanation

    if explainer is None:
        # Backward-compatible fallback to the script-level global.
        explainer = globals().get("explainer")
        if explainer is None:
            raise NameError("name 'explainer' is not defined")

    pred = explainer.get_prediction(explanation)
    pred_masked = explainer.get_masked_prediction(
        x=explanation.x,
        edge_index=explanation.edge_index,
        node_mask=explanation.node_mask,
        edge_mask=explanation.edge_mask,
    )
    explanation.pred = pred
    explanation.pred_masked = pred_masked
    return explanation
if __name__ == "__main__":
    # `copy` is needed below but is not imported at the top of this file.
    import copy

    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)
    auto_select_device()

    # Load components
    dataset = outline.dataset.to(cfg.accelerator)
    model = outline.model.to(cfg.accelerator)
    model_info = outline.model_info
    metrics = outline.metrics
    explaining_algorithm = outline.explaining_algorithm
    attacks = outline.attacks
    explainer_cfg = outline.explainer_cfg
    model_signature = outline.model_signature

    # Set seed
    seed_everything(explaining_cfg.seed)

    # Output directory tree (built below):
    #   <out_dir>/<model_signature>/                      config.yaml, info.json
    #     <explainer name>_<algorithm cfg>/               explaining/explainer cfgs
    #       <algorithm cfg>/raw/<index>.json              raw explanations
    #       <algorithm cfg>/adjust-<adjust cfg>/<index>.json
    #         threshold=<type>-value=<v>/<index>.json     thresholded explanations
    #           <metric cfg>/<index>.json                 metric results
    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
    makedirs(global_path)
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    write_json(model_info, os.path.join(global_path, "info.json"))

    global_path = os.path.join(
        global_path,
        explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
    )
    makedirs(global_path)
    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))

    # NOTE(review): the algorithm config string is already part of the parent
    # directory name; confirm this extra nesting level is intended.
    global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
    makedirs(global_path)

    explainer = Explainer(
        model=model,
        algorithm=explaining_algorithm,
        explainer_config=dict(
            explanation_type=explaining_cfg.explanation_type,
            node_mask_type="object",
            edge_mask_type="object",
        ),
        model_config=dict(
            mode="regression",
            task_level=cfg.dataset.task,
            return_type=explaining_cfg.model_config.return_type,
        ),
    )

    # NOTE(review): load_explanation / save_explanation are not imported in
    # this file as shown; confirm which module provides them.
    raw_path = os.path.join(global_path, "raw")
    makedirs(raw_path)

    for index, item in enumerate(dataset):
        explanation_path = os.path.join(raw_path, f"{index}.json")

        # Reuse the cached raw explanation unless recomputation is forced.
        if os.path.exists(explanation_path) and not explaining_cfg.explainer.force:
            explanation = load_explanation(explanation_path)
        else:
            # NOTE(review): `index=item.y` passes the label as the model-output
            # index to explain; confirm this is intended.
            explanation = explainer(
                x=item.x,
                edge_index=item.edge_index,
                index=item.y,
                target=item.y,
            )
        explanation = get_pred(explanation, force=False)
        save_explanation(explanation, explanation_path)

        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                adjust_path = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(adjust_path)

                # Adjust a copy of the raw explanation so the four relu/abs
                # settings are independent instead of compounding.
                adjusted = adjust.forward(copy.deepcopy(explanation))
                adjusted = get_pred(adjusted, force=True)
                save_explanation(adjusted, os.path.join(adjust_path, f"{index}.json"))

                for threshold_approach in ["hard", "topk", "topk_hard"]:
                    for threshold_value in explaining_cfg.threshold_config.value:
                        masking_path = os.path.join(
                            adjust_path,
                            f"threshold={threshold_approach}-value={threshold_value}",
                        )
                        exp_threshold_path = os.path.join(masking_path, f"{index}.json")

                        if os.path.exists(exp_threshold_path):
                            exp_threshold = load_explanation(exp_threshold_path)
                        else:
                            threshold_conf = {
                                "threshold_type": threshold_approach,
                                "value": threshold_value,
                            }
                            explainer.threshold_config = ThresholdConfig.cast(threshold_conf)
                            exp_threshold = explainer._post_process(copy.copy(adjusted))
                            exp_threshold = get_pred(exp_threshold, force=True)
                            makedirs(masking_path)
                            save_explanation(exp_threshold, exp_threshold_path)

                        for metric in metrics:
                            metric_path = os.path.join(
                                masking_path, f"{obj_config_to_str(metric)}"
                            )
                            metric_file = os.path.join(metric_path, f"{index}.json")
                            if os.path.exists(metric_file):
                                continue
                            makedirs(metric_path)
                            out = metric.forward(exp_threshold)
                            write_json({f"{metric.name}": out}, metric_file)