Fixing some minors bugs
This commit is contained in:
parent
ad7fe83916
commit
1c62a5a561
|
@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
|
|||
explaining_cfg.threshold.config.type = "all"
|
||||
|
||||
explaining_cfg.threshold.value = CN()
|
||||
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)]
|
||||
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)]
|
||||
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
|
||||
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
|
|
|
@ -534,29 +534,30 @@ class ExplainingOutline(object):
|
|||
self.graphstat = GraphStat()
|
||||
|
||||
def get_explanation(self, item: Data, path: str):
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
try:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
if explanation is None:
|
||||
logging.error(
|
||||
" EXP::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
" EXP::Generated; Path %s; FAILED" % (path,),
|
||||
)
|
||||
|
||||
else:
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
else:
|
||||
explanation = _load_explanation(path)
|
||||
explanation = _load_explanation(
|
||||
path,
|
||||
)
|
||||
logging.debug(
|
||||
"EXP::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"EXP::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
explanation = explanation.to(self.cfg.accelerator)
|
||||
else:
|
||||
|
@ -565,8 +566,7 @@ class ExplainingOutline(object):
|
|||
get_pred(self.explainer, explanation)
|
||||
_save_explanation(explanation, path)
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
|
@ -575,19 +575,21 @@ class ExplainingOutline(object):
|
|||
return explanation
|
||||
|
||||
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
exp_adjust = adjust.forward(item)
|
||||
logging.debug(
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
|
||||
else:
|
||||
exp_adjust = _load_explanation(path)
|
||||
exp_adjust = _load_explanation(
|
||||
path,
|
||||
)
|
||||
logging.debug(
|
||||
"ADJUST::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ADJUST::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -595,75 +597,76 @@ class ExplainingOutline(object):
|
|||
get_pred(self.explainer, exp_adjust)
|
||||
_save_explanation(exp_adjust, path)
|
||||
logging.debug(
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
return exp_adjust
|
||||
|
||||
def get_threshold(self, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
exp_threshold = self.explainer._post_process(item)
|
||||
logging.debug(
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
else:
|
||||
exp_threshold = _load_explanation(path)
|
||||
exp_threshold = _load_explanation(
|
||||
path,
|
||||
)
|
||||
logging.debug(
|
||||
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"THRESHOLD::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
else:
|
||||
exp_threshold = self.explainer._post_process(item)
|
||||
get_pred(self.explainer, exp_threshold)
|
||||
_save_explanation(exp_threshold, path)
|
||||
logging.debug(
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
if is_empty_graph(exp_threshold):
|
||||
logging.warning(
|
||||
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
|
||||
(path),
|
||||
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED" % (path,),
|
||||
)
|
||||
return None
|
||||
return exp_threshold
|
||||
|
||||
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
out_metric = metric.forward(item)
|
||||
logging.debug(
|
||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
else:
|
||||
out_metric = read_json(path)
|
||||
out_metric = read_json(
|
||||
path,
|
||||
)
|
||||
logging.debug(
|
||||
"METRIC::Loaded; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"METRIC::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
else:
|
||||
out_metric = metric.forward(item)
|
||||
data = {f"{metric.name}": out_metric}
|
||||
write_json(data, path)
|
||||
if out_metric is None:
|
||||
logging.debug(
|
||||
"METRIC::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
"METRIC::Generated; Path %s; FAILED" % (path,),
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
data = {f"{metric.name}": out_metric}
|
||||
write_json(data, path)
|
||||
return out_metric
|
||||
|
||||
def get_stat(self, item: Data, path: str):
|
||||
if self.graphstat is None:
|
||||
self.load_graphstat()
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
if item.num_nodes <= 500:
|
||||
|
@ -671,30 +674,31 @@ class ExplainingOutline(object):
|
|||
write_json(stat, path)
|
||||
|
||||
def get_attack(self, attack: Attack, item: Data, path: str):
|
||||
if is_exists(path):
|
||||
if is_exists(
|
||||
path,
|
||||
):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
try:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
else:
|
||||
data_attack = _load_explanation(path)
|
||||
data_attack = _load_explanation(
|
||||
path,
|
||||
)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
_save_explanation(data_attack, path)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
|
|
|
@ -2,11 +2,14 @@ import copy
|
|||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
|
||||
from explaining_framework.utils.io import read_json, write_json
|
||||
|
||||
|
||||
def _get_explanation(explainer, item):
|
||||
explanation = explainer(
|
||||
|
@ -27,9 +30,7 @@ def is_empty_graph(data: Data) -> bool:
|
|||
|
||||
|
||||
def get_pred(explainer, explanation):
|
||||
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
|
||||
0
|
||||
]
|
||||
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)
|
||||
setattr(explanation, "pred", pred)
|
||||
data = explanation.to_dict()
|
||||
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
|
||||
|
@ -38,7 +39,7 @@ def get_pred(explainer, explanation):
|
|||
edge_index=explanation.edge_index,
|
||||
node_mask=data.get("node_mask"),
|
||||
edge_mask=data.get("edge_mask"),
|
||||
)[0]
|
||||
)
|
||||
setattr(explanation, "pred_exp", pred_masked)
|
||||
|
||||
|
||||
|
@ -63,15 +64,12 @@ def _save_explanation(exp: Explanation, path: str) -> None:
|
|||
data = exp.clone().to_dict()
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = v.detach().cpu().tolist()
|
||||
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
data[k] = v.clone().detach().cpu().tolist()
|
||||
write_json(data, path)
|
||||
|
||||
|
||||
def _load_explanation(path: str) -> Explanation:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
data = read_json(data, path)
|
||||
for k, v in data.items():
|
||||
if isinstance(v, list):
|
||||
if k == "edge_index" or k == "y":
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
|
||||
|
||||
|
@ -89,19 +90,23 @@ def set_printing(logger_path):
|
|||
Set up printing options
|
||||
|
||||
"""
|
||||
logging.root.handlers = []
|
||||
logging_cfg = {
|
||||
"level": logging.INFO,
|
||||
"format": "%(asctime)s::%(levelname)s::%(message)s",
|
||||
}
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s")
|
||||
|
||||
h_file = logging.FileHandler(logger_path)
|
||||
h_file.setLevel(logging.DEBUG)
|
||||
h_file.setFormatter(formatter)
|
||||
|
||||
h_stdout = logging.StreamHandler(sys.stdout)
|
||||
h_stdout.setLevel(logging.INFO)
|
||||
h_stdout.setFormatter(formatter)
|
||||
|
||||
if explaining_cfg.print == "file":
|
||||
logging_cfg["handlers"] = [h_file]
|
||||
logging.getLogger().addHandler(h_file)
|
||||
elif explaining_cfg.print == "stdout":
|
||||
logging_cfg["handlers"] = [h_stdout]
|
||||
logging.getLogger().addHandler(h_stdout)
|
||||
elif explaining_cfg.print == "both":
|
||||
logging_cfg["handlers"] = [h_file, h_stdout]
|
||||
logging.getLogger().addHandler(h_file)
|
||||
logging.getLogger().addHandler(h_stdout)
|
||||
else:
|
||||
raise ValueError("Print option not supported")
|
||||
logging.basicConfig(**logging_cfg)
|
||||
|
|
8
main.py
8
main.py
|
@ -27,7 +27,7 @@ if __name__ == "__main__":
|
|||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
for attack in outline.attacks:
|
||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
):
|
||||
|
@ -42,7 +42,11 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
for attack in outline.attacks:
|
||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||
logging.info(
|
||||
"Running %s: %s"
|
||||
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
|
||||
)
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
):
|
||||
|
|
Loading…
Reference in New Issue