Fixing bug
This commit is contained in:
parent
074ff25c83
commit
a7dbba146e
79
main.py
79
main.py
|
@ -15,19 +15,25 @@ from torch_geometric.graphgym.utils.device import auto_select_device
|
|||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
from explaining_framework.utils.io import (obj_config_to_str, read_json,
|
||||
write_json, write_yaml)
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
|
||||
# inference, time, force,
|
||||
|
||||
def get_pred(explanation,force=False):
|
||||
|
||||
def get_pred(explanation, force=False):
|
||||
dict_ = explanation.to_dict()
|
||||
if dict_.get('pred') is None or dict_.get('pred_masked') or force:
|
||||
if dict_.get("pred") is None or dict_.get("pred_masked") or force:
|
||||
pred = explainer.get_prediction(explanation)
|
||||
pred_masked = explainer.get_masked_prediction(x=explanation.x,edge_index=explanation.edge_index,node_mask=explanation.node_mask,edge_mask=explanation.edge_mask)
|
||||
explanation.__setattr__('pred',pred)
|
||||
explanation.__setattr__('pred_masked',pred_masked)
|
||||
pred_masked = explainer.get_masked_prediction(
|
||||
x=explanation.x,
|
||||
edge_index=explanation.edge_index,
|
||||
node_mask=explanation.node_mask,
|
||||
edge_mask=explanation.edge_mask,
|
||||
)
|
||||
explanation.__setattr__("pred", pred)
|
||||
explanation.__setattr__("pred_masked", pred_masked)
|
||||
return explanation
|
||||
else:
|
||||
return explanation
|
||||
|
@ -55,9 +61,12 @@ if __name__ == "__main__":
|
|||
global_path = os.path.join(explaining_cfg.out_dir, model_signature)
|
||||
makedirs(global_path)
|
||||
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||
|
||||
global_path = os.path.join(global_path, explaining_cfg.explainer.name+'_'+obj_config_to_str(explaining_algorithm))
|
||||
global_path = os.path.join(
|
||||
global_path,
|
||||
explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
|
||||
)
|
||||
makedirs(global_path)
|
||||
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
||||
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
||||
|
@ -101,40 +110,50 @@ if __name__ == "__main__":
|
|||
index=item.y,
|
||||
target=item.y,
|
||||
)
|
||||
explanation = get_pred(explanation,force=False)
|
||||
save_explanation(explanation,explanation_path)
|
||||
for apply_relu in [True,False]:
|
||||
for apply_absolute in [True,False]:
|
||||
adjust = Adjust(apply_relu=apply_relu,apply_absolute=apply_absolute)
|
||||
save_raw_path = os.path.join(global_path,f'adjust-{obj_config_to_str(adjust)}')
|
||||
explanation = get_pred(explanation, force=False)
|
||||
save_explanation(explanation, explanation_path)
|
||||
for apply_relu in [True, False]:
|
||||
for apply_absolute in [True, False]:
|
||||
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
|
||||
save_raw_path = os.path.join(
|
||||
global_path, f"adjust-{obj_config_to_str(adjust)}"
|
||||
)
|
||||
makedirs(save_raw_path)
|
||||
explanation = adjust.forward(explanation)
|
||||
explanation = adjust.forward(explanation)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
explanation = get_pred(explanation,force=True)
|
||||
save_explanation(explanation,explanation_path)
|
||||
explanation = get_pred(explanation, force=True)
|
||||
save_explanation(explanation, explanation_path)
|
||||
|
||||
for threshold_approach in ['hard','topk','topk_hard']:
|
||||
for threshold_value in explaining_cfg.threshold_config.value:
|
||||
|
||||
masking_path =os.path.join(save_raw_path,f'threshold={threshold_approach}-value={value}')
|
||||
exp_threshold_path = os.path.join(masking_path,f'{index}.json')
|
||||
for threshold_approach in ["hard", "topk", "topk_hard"]:
|
||||
for threshold_value in explaining_cfg.threshold_config.value:
|
||||
|
||||
masking_path = os.path.join(
|
||||
save_raw_path,
|
||||
f"threshold={threshold_approach}-value={value}",
|
||||
)
|
||||
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
|
||||
if is_exists(exp_threshold_path):
|
||||
explanation = load_explanation(exp_threshold_path)
|
||||
else:
|
||||
threshold_conf = {'threshold_type':threshold_approach,'value':threshold_value}
|
||||
explainer.threshold_config = ThresholdConfig.cast(threshold_conf)
|
||||
threshold_conf = {
|
||||
"threshold_type": threshold_approach,
|
||||
"value": threshold_value,
|
||||
}
|
||||
explainer.threshold_config = ThresholdConfig.cast(
|
||||
threshold_conf
|
||||
)
|
||||
|
||||
expl = copy.copy(explanation)
|
||||
exp_threshold = explainer._post_process(expl)
|
||||
exp_threshold= get_pred(exp_threshold,force=True)
|
||||
exp_threshold = get_pred(exp_threshold, force=True)
|
||||
|
||||
save_explanation(exp_threshold,exp_threshold_path)
|
||||
save_explanation(exp_threshold, exp_threshold_path)
|
||||
for metric in metrics:
|
||||
metric_path =os.path.join(masking_path,f'{obj_config_to_str(metric)}')
|
||||
if is_exists(os.path.join(metric_path,f'{index}.json')):
|
||||
metric_path = os.path.join(
|
||||
masking_path, f"{obj_config_to_str(metric)}"
|
||||
)
|
||||
if is_exists(os.path.join(metric_path, f"{index}.json")):
|
||||
continue
|
||||
else:
|
||||
out = metric.forward(exp_threshold)
|
||||
write_json({f'{metric.name}':out})
|
||||
|
||||
write_json({f"{metric.name}": out})
|
||||
|
|
Loading…
Reference in New Issue