Fixing some minor bugs

This commit is contained in:
araison 2023-01-13 11:22:21 +01:00
parent ad7fe83916
commit 1c62a5a561
5 changed files with 82 additions and 71 deletions

View File

@@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.threshold.config.type = "all"
explaining_cfg.threshold.value = CN()
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)]
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)]
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
# which objective metrics to compute, either all or one in particular if implemented
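The only functional change in this hunk is the range bounds of the hard-threshold list. A minimal sketch of the effect, keeping just the two comprehensions (no yacs config node involved):

old_hard = [(i * 10) / 100 for i in range(1, 10)]  # [0.1, 0.2, ..., 0.9]
new_hard = [(i * 10) / 100 for i in range(10)]     # [0.0, 0.1, ..., 0.9]

assert new_hard == [0.0] + old_hard

In other words, the new list also covers the 0.0 threshold (keep everything), which presumably is the omission being fixed.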

View File

@@ -534,29 +534,30 @@ class ExplainingOutline(object):
self.graphstat = GraphStat()
def get_explanation(self, item: Data, path: str):
if is_exists(path):
if is_exists(
path,
):
if self.explaining_cfg.explainer.force:
try:
explanation = _get_explanation(self.explainer, item)
if explanation is None:
logging.error(
" EXP::Generated; Path %s; FAILED",
(path),
" EXP::Generated; Path %s; FAILED" % (path,),
)
else:
logging.debug(
"EXP::Generated; Path %s; SUCCEEDED",
(path),
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
)
except Exception as e:
logging.error(str(e))
return None
else:
explanation = _load_explanation(path)
explanation = _load_explanation(
path,
)
logging.debug(
"EXP::Loaded; Path %s; SUCCEEDED",
(path),
"EXP::Loaded; Path %s; SUCCEEDED" % (path,),
)
explanation = explanation.to(self.cfg.accelerator)
else:
@@ -565,8 +566,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, explanation)
_save_explanation(explanation, path)
logging.debug(
"EXP::Generated; Path %s; SUCCEEDED",
(path),
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
)
except Exception as e:
logging.error(str(e))
@@ -575,19 +575,21 @@ class ExplainingOutline(object):
return explanation
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
if is_exists(path):
if is_exists(
path,
):
if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item)
logging.debug(
"ADJUST::Generated; Path %s; SUCCEEDED",
(path),
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
)
else:
exp_adjust = _load_explanation(path)
exp_adjust = _load_explanation(
path,
)
logging.debug(
"ADJUST::Loaded; Path %s; SUCCEEDED",
(path),
"ADJUST::Loaded; Path %s; SUCCEEDED" % (path,),
)
else:
@@ -595,75 +597,76 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path)
logging.debug(
"ADJUST::Generated; Path %s; SUCCEEDED",
(path),
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
)
return exp_adjust
def get_threshold(self, item: Explanation, path: str):
if is_exists(path):
if is_exists(
path,
):
if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item)
logging.debug(
"THRESHOLD::Generated; Path %s; SUCCEEDED",
(path),
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
)
else:
exp_threshold = _load_explanation(path)
exp_threshold = _load_explanation(
path,
)
logging.debug(
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
(path),
"THRESHOLD::Loaded; Path %s; SUCCEEDED" % (path,),
)
else:
exp_threshold = self.explainer._post_process(item)
get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path)
logging.debug(
"THRESHOLD::Generated; Path %s; SUCCEEDED",
(path),
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
)
if is_empty_graph(exp_threshold):
logging.warning(
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
(path),
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED" % (path,),
)
return None
return exp_threshold
def get_metric(self, metric: Metric, item: Explanation, path: str):
if is_exists(path):
if is_exists(
path,
):
if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item)
logging.debug(
"METRIC::Generated; Path %s; SUCCEEDED",
(path),
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
)
else:
out_metric = read_json(path)
out_metric = read_json(
path,
)
logging.debug(
"METRIC::Loaded; Path %s; SUCCEEDED",
(path),
"METRIC::Loaded; Path %s; SUCCEEDED" % (path,),
)
else:
out_metric = metric.forward(item)
data = {f"{metric.name}": out_metric}
write_json(data, path)
if out_metric is None:
logging.debug(
"METRIC::Generated; Path %s; FAILED",
(path),
"METRIC::Generated; Path %s; FAILED" % (path,),
)
else:
logging.debug(
"METRIC::Generated; Path %s; SUCCEEDED",
(path),
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
)
data = {f"{metric.name}": out_metric}
write_json(data, path)
return out_metric
def get_stat(self, item: Data, path: str):
if self.graphstat is None:
self.load_graphstat()
if is_exists(path):
if is_exists(
path,
):
pass
else:
if item.num_nodes <= 500:
@@ -671,30 +674,31 @@ class ExplainingOutline(object):
write_json(stat, path)
def get_attack(self, attack: Attack, item: Data, path: str):
if is_exists(path):
if is_exists(
path,
):
if self.explaining_cfg.explainer.force:
try:
data_attack = attack.get_attacked_prediction(item)
logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
)
except Exception as e:
logging.error(str(e))
return None
else:
data_attack = _load_explanation(path)
data_attack = _load_explanation(
path,
)
logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
)
else:
try:
data_attack = attack.get_attacked_prediction(item)
_save_explanation(data_attack, path)
logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
)
except Exception as e:
logging.error(str(e))
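All of the getters touched above (get_explanation, get_adjust, get_threshold, get_metric, get_attack) share the same compute-or-load shape around is_exists and explaining_cfg.explainer.force. A standalone sketch of that pattern, with compute/load/save as illustrative stand-ins for the framework helpers rather than its actual API:

import logging
import os

def compute_or_load(path, compute, load, save, force=False):
    """Return a cached artifact from `path`, or compute and persist it.

    Sketch of the pattern used by the getters above; `compute`, `load`
    and `save` are generic callables standing in for the framework helpers.
    """
    if os.path.exists(path):
        if force:
            # Cached file exists but recomputation was requested: rebuild, do not overwrite.
            result = compute()
            logging.debug("Generated; Path %s; SUCCEEDED", path)
        else:
            # Reuse the cached artifact.
            result = load(path)
            logging.debug("Loaded; Path %s; SUCCEEDED", path)
    else:
        # First run for this path: compute and persist.
        result = compute()
        save(result, path)
        logging.debug("Generated; Path %s; SUCCEEDED", path)
    return result

get_explanation and get_attack additionally wrap the compute step in try/except and return None on failure, which this sketch leaves out.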

View File

@@ -2,11 +2,14 @@ import copy
import json
import os
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from explaining_framework.utils.io import read_json, write_json
def _get_explanation(explainer, item):
explanation = explainer(
@@ -27,9 +30,7 @@ def is_empty_graph(data: Data) -> bool:
def get_pred(explainer, explanation):
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
0
]
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)
setattr(explanation, "pred", pred)
data = explanation.to_dict()
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
@@ -38,7 +39,7 @@ def get_pred(explainer, explanation):
edge_index=explanation.edge_index,
node_mask=data.get("node_mask"),
edge_mask=data.get("edge_mask"),
)[0]
)
setattr(explanation, "pred_exp", pred_masked)
@@ -63,15 +64,12 @@ def _save_explanation(exp: Explanation, path: str) -> None:
data = exp.clone().to_dict()
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.detach().cpu().tolist()
with open(path, "w") as f:
json.dump(data, f)
data[k] = v.clone().detach().cpu().tolist()
write_json(data, path)
def _load_explanation(path: str) -> Explanation:
with open(path, "r") as f:
data = json.load(f)
data = read_json(path)
for k, v in data.items():
if isinstance(v, list):
if k == "edge_index" or k == "y":

View File

@@ -4,6 +4,7 @@ import os
import sys
import yaml
from explaining_framework.config.explaining_config import explaining_cfg
@@ -89,19 +90,23 @@ def set_printing(logger_path):
Set up printing options
"""
logging.root.handlers = []
logging_cfg = {
"level": logging.INFO,
"format": "%(asctime)s::%(levelname)s::%(message)s",
}
logging.getLogger().setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s")
h_file = logging.FileHandler(logger_path)
h_file.setLevel(logging.DEBUG)
h_file.setFormatter(formatter)
h_stdout = logging.StreamHandler(sys.stdout)
h_stdout.setLevel(logging.INFO)
h_stdout.setFormatter(formatter)
if explaining_cfg.print == "file":
logging_cfg["handlers"] = [h_file]
logging.getLogger().addHandler(h_file)
elif explaining_cfg.print == "stdout":
logging_cfg["handlers"] = [h_stdout]
logging.getLogger().addHandler(h_stdout)
elif explaining_cfg.print == "both":
logging_cfg["handlers"] = [h_file, h_stdout]
logging.getLogger().addHandler(h_file)
logging.getLogger().addHandler(h_stdout)
else:
raise ValueError("Print option not supported")
logging.basicConfig(**logging_cfg)
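The handler-based setup that replaces logging.basicConfig gives each sink its own level: the root logger is opened up to DEBUG, the file handler records everything, and the stdout handler only passes INFO and above. A self-contained sketch of the "both" branch, with a throwaway filename standing in for logger_path:

import logging
import sys

logging.root.handlers = []
logging.getLogger().setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s")

h_file = logging.FileHandler("example.log")  # stand-in for logger_path
h_file.setLevel(logging.DEBUG)
h_file.setFormatter(formatter)

h_stdout = logging.StreamHandler(sys.stdout)
h_stdout.setLevel(logging.INFO)
h_stdout.setFormatter(formatter)

logging.getLogger().addHandler(h_file)
logging.getLogger().addHandler(h_stdout)

logging.debug("only written to example.log")           # filtered out by the stdout handler
logging.info("written to example.log and to stdout")   # passes both handlers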

View File

@@ -27,7 +27,7 @@ if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
):
@@ -42,7 +42,11 @@ if __name__ == "__main__":
)
for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
logging.info(
"Running %s: %s"
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
)
for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
):
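The logging calls in this file swap f-strings for explicit %-formatting; the rendered message is identical either way, and logging also accepts the arguments directly for lazy formatting. A small comparison with placeholder names:

import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")

explainer_name = "Explainer"      # placeholder values, not the framework's objects
algorithm_name = "GNNExplainer"

# f-string: the message is built before logging even checks the level
logging.info(f"Running {explainer_name}: {algorithm_name}")

# explicit %-formatting, as introduced by this commit
logging.info("Running %s: %s" % (explainer_name, algorithm_name))

# lazy form: logging formats the message only if the record is actually emitted
logging.info("Running %s: %s", explainer_name, algorithm_name)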