Fixing some minors bugs

This commit is contained in:
araison 2023-01-13 11:22:21 +01:00
parent ad7fe83916
commit 1c62a5a561
5 changed files with 82 additions and 71 deletions

View File

@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.threshold.config.type = "all" explaining_cfg.threshold.config.type = "all"
explaining_cfg.threshold.value = CN() explaining_cfg.threshold.value = CN()
explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)] explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(10)]
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50] explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
# which objectives metrics to computes, either all or one in particular if implemented # which objectives metrics to computes, either all or one in particular if implemented

View File

@ -534,29 +534,30 @@ class ExplainingOutline(object):
self.graphstat = GraphStat() self.graphstat = GraphStat()
def get_explanation(self, item: Data, path: str): def get_explanation(self, item: Data, path: str):
if is_exists(path): if is_exists(
path,
):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
try: try:
explanation = _get_explanation(self.explainer, item) explanation = _get_explanation(self.explainer, item)
if explanation is None: if explanation is None:
logging.error( logging.error(
" EXP::Generated; Path %s; FAILED", " EXP::Generated; Path %s; FAILED" % (path,),
(path),
) )
else: else:
logging.debug( logging.debug(
"EXP::Generated; Path %s; SUCCEEDED", "EXP::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
return None return None
else: else:
explanation = _load_explanation(path) explanation = _load_explanation(
path,
)
logging.debug( logging.debug(
"EXP::Loaded; Path %s; SUCCEEDED", "EXP::Loaded; Path %s; SUCCEEDED" % (path,),
(path),
) )
explanation = explanation.to(self.cfg.accelerator) explanation = explanation.to(self.cfg.accelerator)
else: else:
@ -565,8 +566,7 @@ class ExplainingOutline(object):
get_pred(self.explainer, explanation) get_pred(self.explainer, explanation)
_save_explanation(explanation, path) _save_explanation(explanation, path)
logging.debug( logging.debug(
"EXP::Generated; Path %s; SUCCEEDED", "EXP::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
@ -575,19 +575,21 @@ class ExplainingOutline(object):
return explanation return explanation
def get_adjust(self, adjust: Adjust, item: Explanation, path: str): def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
if is_exists(path): if is_exists(
path,
):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item) exp_adjust = adjust.forward(item)
logging.debug( logging.debug(
"ADJUST::Generated; Path %s; SUCCEEDED", "ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
exp_adjust = _load_explanation(path) exp_adjust = _load_explanation(
path,
)
logging.debug( logging.debug(
"ADJUST::Loaded; Path %s; SUCCEEDED", "ADJUST::Loaded; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
@ -595,75 +597,76 @@ class ExplainingOutline(object):
get_pred(self.explainer, exp_adjust) get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path) _save_explanation(exp_adjust, path)
logging.debug( logging.debug(
"ADJUST::Generated; Path %s; SUCCEEDED", "ADJUST::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
return exp_adjust return exp_adjust
def get_threshold(self, item: Explanation, path: str): def get_threshold(self, item: Explanation, path: str):
if is_exists(path): if is_exists(
path,
):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item) exp_threshold = self.explainer._post_process(item)
logging.debug( logging.debug(
"THRESHOLD::Generated; Path %s; SUCCEEDED", "THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
exp_threshold = _load_explanation(path) exp_threshold = _load_explanation(
path,
)
logging.debug( logging.debug(
"THRESHOLD::Loaded; Path %s; SUCCEEDED", "THRESHOLD::Loaded; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
exp_threshold = self.explainer._post_process(item) exp_threshold = self.explainer._post_process(item)
get_pred(self.explainer, exp_threshold) get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path) _save_explanation(exp_threshold, path)
logging.debug( logging.debug(
"THRESHOLD::Generated; Path %s; SUCCEEDED", "THRESHOLD::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
if is_empty_graph(exp_threshold): if is_empty_graph(exp_threshold):
logging.warning( logging.warning(
"THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED", "THRESHOLD::Generated; Path %s; EMPTY GRAPH; FAILED" % (path,),
(path),
) )
return None return None
return exp_threshold return exp_threshold
def get_metric(self, metric: Metric, item: Explanation, path: str): def get_metric(self, metric: Metric, item: Explanation, path: str):
if is_exists(path): if is_exists(
path,
):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item) out_metric = metric.forward(item)
logging.debug( logging.debug(
"METRIC::Generated; Path %s; SUCCEEDED", "METRIC::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
out_metric = read_json(path) out_metric = read_json(
path,
)
logging.debug( logging.debug(
"METRIC::Loaded; Path %s; SUCCEEDED", "METRIC::Loaded; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
out_metric = metric.forward(item) out_metric = metric.forward(item)
data = {f"{metric.name}": out_metric}
write_json(data, path)
if out_metric is None: if out_metric is None:
logging.debug( logging.debug(
"METRIC::Generated; Path %s; FAILED", "METRIC::Generated; Path %s; FAILED" % (path,),
(path),
) )
else: else:
logging.debug( logging.debug(
"METRIC::Generated; Path %s; SUCCEEDED", "METRIC::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
data = {f"{metric.name}": out_metric}
write_json(data, path)
return out_metric return out_metric
def get_stat(self, item: Data, path: str): def get_stat(self, item: Data, path: str):
if self.graphstat is None: if self.graphstat is None:
self.load_graphstat() self.load_graphstat()
if is_exists(path): if is_exists(
path,
):
pass pass
else: else:
if item.num_nodes <= 500: if item.num_nodes <= 500:
@ -671,30 +674,31 @@ class ExplainingOutline(object):
write_json(stat, path) write_json(stat, path)
def get_attack(self, attack: Attack, item: Data, path: str): def get_attack(self, attack: Attack, item: Data, path: str):
if is_exists(path): if is_exists(
path,
):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
try: try:
data_attack = attack.get_attacked_prediction(item) data_attack = attack.get_attacked_prediction(item)
logging.debug( logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))
return None return None
else: else:
data_attack = _load_explanation(path) data_attack = _load_explanation(
path,
)
logging.debug( logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
else: else:
try: try:
data_attack = attack.get_attacked_prediction(item) data_attack = attack.get_attacked_prediction(item)
_save_explanation(data_attack, path) _save_explanation(data_attack, path)
logging.debug( logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED", "ATTACK::Generated; Path %s; SUCCEEDED" % (path,),
(path),
) )
except Exception as e: except Exception as e:
logging.error(str(e)) logging.error(str(e))

View File

@ -2,11 +2,14 @@ import copy
import json import json
import os import os
import numpy as np
import torch import torch
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.config import cfg
from explaining_framework.utils.io import read_json, write_json
def _get_explanation(explainer, item): def _get_explanation(explainer, item):
explanation = explainer( explanation = explainer(
@ -27,9 +30,7 @@ def is_empty_graph(data: Data) -> bool:
def get_pred(explainer, explanation): def get_pred(explainer, explanation):
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[ pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)
0
]
setattr(explanation, "pred", pred) setattr(explanation, "pred", pred)
data = explanation.to_dict() data = explanation.to_dict()
if not data.get("node_mask") is None or not data.get("edge_mask") is None: if not data.get("node_mask") is None or not data.get("edge_mask") is None:
@ -38,7 +39,7 @@ def get_pred(explainer, explanation):
edge_index=explanation.edge_index, edge_index=explanation.edge_index,
node_mask=data.get("node_mask"), node_mask=data.get("node_mask"),
edge_mask=data.get("edge_mask"), edge_mask=data.get("edge_mask"),
)[0] )
setattr(explanation, "pred_exp", pred_masked) setattr(explanation, "pred_exp", pred_masked)
@ -63,15 +64,12 @@ def _save_explanation(exp: Explanation, path: str) -> None:
data = exp.clone().to_dict() data = exp.clone().to_dict()
for k, v in data.items(): for k, v in data.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
data[k] = v.detach().cpu().tolist() data[k] = v.clone().detach().cpu().tolist()
write_json(data, path)
with open(path, "w") as f:
json.dump(data, f)
def _load_explanation(path: str) -> Explanation: def _load_explanation(path: str) -> Explanation:
with open(path, "r") as f: data = read_json(data, path)
data = json.load(f)
for k, v in data.items(): for k, v in data.items():
if isinstance(v, list): if isinstance(v, list):
if k == "edge_index" or k == "y": if k == "edge_index" or k == "y":

View File

@ -4,6 +4,7 @@ import os
import sys import sys
import yaml import yaml
from explaining_framework.config.explaining_config import explaining_cfg from explaining_framework.config.explaining_config import explaining_cfg
@ -89,19 +90,23 @@ def set_printing(logger_path):
Set up printing options Set up printing options
""" """
logging.root.handlers = [] logging.getLogger().setLevel(logging.DEBUG)
logging_cfg = { formatter = logging.Formatter("%(asctime)s::%(levelname)s::%(message)s")
"level": logging.INFO,
"format": "%(asctime)s::%(levelname)s::%(message)s",
}
h_file = logging.FileHandler(logger_path) h_file = logging.FileHandler(logger_path)
h_file.setLevel(logging.DEBUG)
h_file.setFormatter(formatter)
h_stdout = logging.StreamHandler(sys.stdout) h_stdout = logging.StreamHandler(sys.stdout)
h_stdout.setLevel(logging.INFO)
h_stdout.setFormatter(formatter)
if explaining_cfg.print == "file": if explaining_cfg.print == "file":
logging_cfg["handlers"] = [h_file] logging.getLogger().addHandler(h_file)
elif explaining_cfg.print == "stdout": elif explaining_cfg.print == "stdout":
logging_cfg["handlers"] = [h_stdout] logging.getLogger().addHandler(h_stdout)
elif explaining_cfg.print == "both": elif explaining_cfg.print == "both":
logging_cfg["handlers"] = [h_file, h_stdout] logging.getLogger().addHandler(h_file)
logging.getLogger().addHandler(h_stdout)
else: else:
raise ValueError("Print option not supported") raise ValueError("Print option not supported")
logging.basicConfig(**logging_cfg)

View File

@ -27,7 +27,7 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
for attack in outline.attacks: for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}") logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
): ):
@ -42,7 +42,11 @@ if __name__ == "__main__":
) )
for attack in outline.attacks: for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}") logging.info("Running %s: %s" % (attack.__class__.__name__, attack.name))
logging.info(
"Running %s: %s"
% (outline.explainer.__class__.__name__, outline.explaining_algorithm.name),
)
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
): ):