From a7dbba146ed0c592845e47bd78d9da26b88845d2 Mon Sep 17 00:00:00 2001 From: araison Date: Fri, 30 Dec 2022 19:41:56 +0100 Subject: [PATCH] Fixing bug --- main.py | 79 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index 834d8ac..e52bb1f 100644 --- a/main.py +++ b/main.py @@ -15,19 +15,25 @@ from torch_geometric.graphgym.utils.device import auto_select_device from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.outline import ExplainingOutline +from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.io import (obj_config_to_str, read_json, write_json, write_yaml) -from explaining_framework.utils.explanation.adjust import Adjust # inference, time, force, -def get_pred(explanation,force=False): + +def get_pred(explanation, force=False): dict_ = explanation.to_dict() - if dict_.get('pred') is None or dict_.get('pred_masked') or force: + if dict_.get("pred") is None or dict_.get("pred_masked") or force: pred = explainer.get_prediction(explanation) - pred_masked = explainer.get_masked_prediction(x=explanation.x,edge_index=explanation.edge_index,node_mask=explanation.node_mask,edge_mask=explanation.edge_mask) - explanation.__setattr__('pred',pred) - explanation.__setattr__('pred_masked',pred_masked) + pred_masked = explainer.get_masked_prediction( + x=explanation.x, + edge_index=explanation.edge_index, + node_mask=explanation.node_mask, + edge_mask=explanation.edge_mask, + ) + explanation.__setattr__("pred", pred) + explanation.__setattr__("pred_masked", pred_masked) return explanation else: return explanation @@ -55,9 +61,12 @@ if __name__ == "__main__": global_path = os.path.join(explaining_cfg.out_dir, model_signature) makedirs(global_path) write_yaml(cfg, os.path.join(global_path, "config.yaml")) - write_json(model_info, os.path.join(global_path, "info.json")) + write_json(model_info, os.path.join(global_path, "info.json")) - global_path = os.path.join(global_path, explaining_cfg.explainer.name+'_'+obj_config_to_str(explaining_algorithm)) + global_path = os.path.join( + global_path, + explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm), + ) makedirs(global_path) write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) @@ -101,40 +110,50 @@ if __name__ == "__main__": index=item.y, target=item.y, ) - explanation = get_pred(explanation,force=False) - save_explanation(explanation,explanation_path) - for apply_relu in [True,False]: - for apply_absolute in [True,False]: - adjust = Adjust(apply_relu=apply_relu,apply_absolute=apply_absolute) - save_raw_path = os.path.join(global_path,f'adjust-{obj_config_to_str(adjust)}') + explanation = get_pred(explanation, force=False) + save_explanation(explanation, explanation_path) + for apply_relu in [True, False]: + for apply_absolute in [True, False]: + adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) + save_raw_path = os.path.join( + global_path, f"adjust-{obj_config_to_str(adjust)}" + ) makedirs(save_raw_path) - explanation = adjust.forward(explanation) + explanation = adjust.forward(explanation) explanation_path = os.path.join(save_raw_path, f"{index}.json") - explanation = get_pred(explanation,force=True) - save_explanation(explanation,explanation_path) + explanation = get_pred(explanation, force=True) + save_explanation(explanation, explanation_path) - for threshold_approach in ['hard','topk','topk_hard']: - for threshold_value in explaining_cfg.threshold_config.value: - - masking_path =os.path.join(save_raw_path,f'threshold={threshold_approach}-value={value}') - exp_threshold_path = os.path.join(masking_path,f'{index}.json') + for threshold_approach in ["hard", "topk", "topk_hard"]: + for threshold_value in explaining_cfg.threshold_config.value: + masking_path = os.path.join( + save_raw_path, + f"threshold={threshold_approach}-value={value}", + ) + exp_threshold_path = os.path.join(masking_path, f"{index}.json") if is_exists(exp_threshold_path): explanation = load_explanation(exp_threshold_path) else: - threshold_conf = {'threshold_type':threshold_approach,'value':threshold_value} - explainer.threshold_config = ThresholdConfig.cast(threshold_conf) + threshold_conf = { + "threshold_type": threshold_approach, + "value": threshold_value, + } + explainer.threshold_config = ThresholdConfig.cast( + threshold_conf + ) expl = copy.copy(explanation) exp_threshold = explainer._post_process(expl) - exp_threshold= get_pred(exp_threshold,force=True) + exp_threshold = get_pred(exp_threshold, force=True) - save_explanation(exp_threshold,exp_threshold_path) + save_explanation(exp_threshold, exp_threshold_path) for metric in metrics: - metric_path =os.path.join(masking_path,f'{obj_config_to_str(metric)}') - if is_exists(os.path.join(metric_path,f'{index}.json')): + metric_path = os.path.join( + masking_path, f"{obj_config_to_str(metric)}" + ) + if is_exists(os.path.join(metric_path, f"{index}.json")): continue else: out = metric.forward(exp_threshold) - write_json({f'{metric.name}':out}) - + write_json({f"{metric.name}": out})