Fixing some minors bugs
This commit is contained in:
parent
ad7fe83916
commit
1c62a5a561
|
@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
|
||||||
explaining_cfg.threshold.config.type = "all"
|
explaining_cfg.threshold.config.type = "all"
|
||||||
|
|
||||||
explaining_cfg.threshold.value = CN()
|
explaining_cfg.threshold.value = CN()
|
||||||
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)]
|
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)]
|
||||||
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
|
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
|
||||||
|
|
||||||
# which objectives metrics to computes, either all or one in particular if implemented
|
# which objectives metrics to computes, either all or one in particular if implemented
|
||||||
|
|
|
@ -534,29 +534,30 @@ class ExplainingOutline(object):
|
||||||
self.graphstat = GraphStat()
|
self.graphstat = GraphStat()
|
||||||
|
|
||||||
def get_explanation(self, item: Data, path: str):
|
def get_explanation(self, item: Data, path: str):
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
try:
|
try:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
if explanation is None:
|
if explanation is None:
|
||||||
logging.error(
|
logging.error(
|
||||||
" EXP::Generated; Path %s; FAILED",
|
" EXP::Generated; Path %s; FAILED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP::Generated; Path %s; SUCCEEDED",
|
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
explanation = _load_explanation(path)
|
explanation = _load_explanation(
|
||||||
|
path,
|
||||||
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP::Loaded; Path %s; SUCCEEDED",
|
"EXP::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
explanation = explanation.to(self.cfg.accelerator)
|
explanation = explanation.to(self.cfg.accelerator)
|
||||||
else:
|
else:
|
||||||
|
@ -565,8 +566,7 @@ class ExplainingOutline(object):
|
||||||
get_pred(self.explainer, explanation)
|
get_pred(self.explainer, explanation)
|
||||||
_save_explanation(explanation, path)
|
_save_explanation(explanation, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"EXP::Generated; Path %s; SUCCEEDED",
|
"EXP::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
|
@ -575,19 +575,21 @@ class ExplainingOutline(object):
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_adjust = adjust.forward(item)
|
exp_adjust = adjust.forward(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
exp_adjust = _load_explanation(path)
|
exp_adjust = _load_explanation(
|
||||||
|
path,
|
||||||
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST::Loaded; Path %s; SUCCEEDED",
|
"ADJUST::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@ -595,75 +597,76 @@ class ExplainingOutline(object):
|
||||||
get_pred(self.explainer, exp_adjust)
|
get_pred(self.explainer, exp_adjust)
|
||||||
_save_explanation(exp_adjust, path)
|
_save_explanation(exp_adjust, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ADJUST::Generated; Path %s; SUCCEEDED",
|
"ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
return exp_adjust
|
return exp_adjust
|
||||||
|
|
||||||
def get_threshold(self, item: Explanation, path: str):
|
def get_threshold(self, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exp_threshold = _load_explanation(path)
|
exp_threshold = _load_explanation(
|
||||||
|
path,
|
||||||
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD::Loaded; Path %s; SUCCEEDED",
|
"THRESHOLD::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
get_pred(self.explainer, exp_threshold)
|
get_pred(self.explainer, exp_threshold)
|
||||||
_save_explanation(exp_threshold, path)
|
_save_explanation(exp_threshold, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"THRESHOLD::Generated; Path %s; SUCCEEDED",
|
"THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
if is_empty_graph(exp_threshold):
|
if is_empty_graph(exp_threshold):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED",
|
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
return exp_threshold
|
return exp_threshold
|
||||||
|
|
||||||
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
out_metric = metric.forward(item)
|
out_metric = metric.forward(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_metric = read_json(path)
|
out_metric = read_json(
|
||||||
|
path,
|
||||||
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC::Loaded; Path %s; SUCCEEDED",
|
"METRIC::Loaded; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_metric = metric.forward(item)
|
out_metric = metric.forward(item)
|
||||||
data = {f"{metric.name}": out_metric}
|
|
||||||
write_json(data, path)
|
|
||||||
if out_metric is None:
|
if out_metric is None:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC::Generated; Path %s; FAILED",
|
"METRIC::Generated; Path %s; FAILED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"METRIC::Generated; Path %s; SUCCEEDED",
|
"METRIC::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
|
data = {f"{metric.name}": out_metric}
|
||||||
|
write_json(data, path)
|
||||||
return out_metric
|
return out_metric
|
||||||
|
|
||||||
def get_stat(self, item: Data, path: str):
|
def get_stat(self, item: Data, path: str):
|
||||||
if self.graphstat is None:
|
if self.graphstat is None:
|
||||||
self.load_graphstat()
|
self.load_graphstat()
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if item.num_nodes <= 500:
|
if item.num_nodes <= 500:
|
||||||
|
@ -671,30 +674,31 @@ class ExplainingOutline(object):
|
||||||
write_json(stat, path)
|
write_json(stat, path)
|
||||||
|
|
||||||
def get_attack(self, attack: Attack, item: Data, path: str):
|
def get_attack(self, attack: Attack, item: Data, path: str):
|
||||||
if is_exists(path):
|
if is_exists(
|
||||||
|
path,
|
||||||
|
):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
try:
|
try:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
data_attack = _load_explanation(path)
|
data_attack = _load_explanation(
|
||||||
|
path,
|
||||||
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
_save_explanation(data_attack, path)
|
_save_explanation(data_attack, path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
"ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
|
||||||
(path),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(str(e))
|
logging.error(str(e))
|
||||||
|
|
|
@ -2,11 +2,14 @@ import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
from torch_geometric.graphgym.config import cfg
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
|
||||||
|
from explaining_framework.utils.io import read_json, write_json
|
||||||
|
|
||||||
|
|
||||||
def _get_explanation(explainer, item):
|
def _get_explanation(explainer, item):
|
||||||
explanation = explainer(
|
explanation = explainer(
|
||||||
|
@ -27,9 +30,7 @@ def is_empty_graph(data: Data) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def get_pred(explainer, explanation):
|
def get_pred(explainer, explanation):
|
||||||
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
|
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)
|
||||||
0
|
|
||||||
]
|
|
||||||
setattr(explanation, "pred", pred)
|
setattr(explanation, "pred", pred)
|
||||||
data = explanation.to_dict()
|
data = explanation.to_dict()
|
||||||
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
|
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
|
||||||
|
@ -38,7 +39,7 @@ def get_pred(explainer, explanation):
|
||||||
edge_index=explanation.edge_index,
|
edge_index=explanation.edge_index,
|
||||||
node_mask=data.get("node_mask"),
|
node_mask=data.get("node_mask"),
|
||||||
edge_mask=data.get("edge_mask"),
|
edge_mask=data.get("edge_mask"),
|
||||||
)[0]
|
)
|
||||||
setattr(explanation, "pred_exp", pred_masked)
|
setattr(explanation, "pred_exp", pred_masked)
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,15 +64,12 @@ def _save_explanation(exp: Explanation, path: str) -> None:
|
||||||
data = exp.clone().to_dict()
|
data = exp.clone().to_dict()
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
data[k] = v.detach().cpu().tolist()
|
data[k] = v.clone().detach().cpu().tolist()
|
||||||
|
write_json(data, path)
|
||||||
with open(path, "w") as f:
|
|
||||||
json.dump(data, f)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_explanation(path: str) -> Explanation:
|
def _load_explanation(path: str) -> Explanation:
|
||||||
with open(path, "r") as f:
|
data = read_json(data, path)
|
||||||
data = json.load(f)
|
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
if k == "edge_index" or k == "y":
|
if k == "edge_index" or k == "y":
|
||||||
|
|
|
@ -4,6 +4,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from explaining_framework.config.explaining_config import explaining_cfg
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
|
|
||||||
|
|
||||||
|
@ -89,19 +90,23 @@ def set_printing(logger_path):
|
||||||
Set up printing options
|
Set up printing options
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logging.root.handlers = []
|
logging.getLogger().setLevel(logging.DEBUG)
|
||||||
logging_cfg = {
|
formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s")
|
||||||
"level": logging.INFO,
|
|
||||||
"format": "%(asctime)s::%(levelname)s::%(message)s",
|
|
||||||
}
|
|
||||||
h_file = logging.FileHandler(logger_path)
|
h_file = logging.FileHandler(logger_path)
|
||||||
|
h_file.setLevel(logging.DEBUG)
|
||||||
|
h_file.setFormatter(formatter)
|
||||||
|
|
||||||
h_stdout = logging.StreamHandler(sys.stdout)
|
h_stdout = logging.StreamHandler(sys.stdout)
|
||||||
|
h_stdout.setLevel(logging.INFO)
|
||||||
|
h_stdout.setFormatter(formatter)
|
||||||
|
|
||||||
if explaining_cfg.print == "file":
|
if explaining_cfg.print == "file":
|
||||||
logging_cfg["handlers"] = [h_file]
|
logging.getLogger().addHandler(h_file)
|
||||||
elif explaining_cfg.print == "stdout":
|
elif explaining_cfg.print == "stdout":
|
||||||
logging_cfg["handlers"] = [h_stdout]
|
logging.getLogger().addHandler(h_stdout)
|
||||||
elif explaining_cfg.print == "both":
|
elif explaining_cfg.print == "both":
|
||||||
logging_cfg["handlers"] = [h_file, h_stdout]
|
logging.getLogger().addHandler(h_file)
|
||||||
|
logging.getLogger().addHandler(h_stdout)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Print option not supported")
|
raise ValueError("Print option not supported")
|
||||||
logging.basicConfig(**logging_cfg)
|
|
||||||
|
|
8
main.py
8
main.py
|
@ -27,7 +27,7 @@ if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||||
for item, index in tqdm(
|
for item, index in tqdm(
|
||||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||||
):
|
):
|
||||||
|
@ -42,7 +42,11 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
|
||||||
|
logging.info(
|
||||||
|
"Running %s: %s"
|
||||||
|
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
|
||||||
|
)
|
||||||
for item, index in tqdm(
|
for item, index in tqdm(
|
||||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue