Reformating, fixing

This commit is contained in:
araison 2023-01-04 10:41:34 +01:00
parent d9628ff947
commit 7e32d6fd3a
3 changed files with 145 additions and 110 deletions

View File

@ -57,7 +57,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.dataset.name = "Cora" explaining_cfg.dataset.name = "Cora"
explaining_cfg.dataset.specific_items = None explaining_cfg.dataset.items = None
explaining_cfg.run_topological_stat = True explaining_cfg.run_topological_stat = True
@ -110,26 +110,33 @@ def set_cfg(explaining_cfg):
# Thresholding options # Thresholding options
# ----------------------------------------------------------------------- # # ----------------------------------------------------------------------- #
explaining_cfg.threshold_config = CN() explaining_cfg.threshold = CN()
explaining_cfg.threshold_config.threshold_type = None explaining_cfg.threshold.config = CN()
explaining_cfg.threshold.config.type = "all"
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)] explaining_cfg.threshold.value = CN()
explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)]
explaining_cfg.threshold_config.relu_and_normalize = True explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
# Select device: 'cpu', 'cuda', 'auto'
# which objectives metrics to computes, either all or one in particular if implemented # which objectives metrics to computes, either all or one in particular if implemented
explaining_cfg.metrics = CN() explaining_cfg.metrics = CN()
explaining_cfg.metrics.name = "all" explaining_cfg.metrics.sparsity = CN()
explaining_cfg.metrics.sparsity.name = "all"
explaining_cfg.metrics.fidelity = CN()
explaining_cfg.metrics.fidelity.name = "all"
explaining_cfg.metrics.accuracy = CN()
explaining_cfg.metrics.accuracy.name = "all"
# Whether or not recomputing metrics if they already exist # Whether or not recomputing metrics if they already exist
explaining_cfg.metrics.force = False
explaining_cfg.adjust = CN()
explaining_cfg.adjust.strategy = "rpn"
explaining_cfg.attack = CN() explaining_cfg.attack = CN()
explaining_cfg.attack.name = "all" explaining_cfg.attack.name = "all"
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto" explaining_cfg.accelerator = "auto"

View File

@ -3,17 +3,6 @@ import itertools
from typing import Any from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -22,6 +11,7 @@ from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \ from explaining_framework.explainers.wrappers.from_graphxai import \
GraphXAIWrapper GraphXAIWrapper
from explaining_framework.metric.accuracy import Accuracy from explaining_framework.metric.accuracy import Accuracy
from explaining_framework.metric.base import Metric
from explaining_framework.metric.fidelity import Fidelity from explaining_framework.metric.fidelity import Fidelity
from explaining_framework.metric.robust import Attack from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity from explaining_framework.metric.sparsity import Sparsity
@ -34,6 +24,17 @@ from explaining_framework.utils.explanation.io import (
explanation_verification, get_pred) explanation_verification, get_pred)
from explaining_framework.utils.io import (is_exists, obj_config_to_str, from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml) read_json, write_json, write_yaml)
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
all__captum = [ all__captum = [
"LRP", "LRP",
@ -120,6 +121,9 @@ class ExplainingOutline(object):
self.attacks = None self.attacks = None
self.model_signature = None self.model_signature = None
self.indexes = None self.indexes = None
self.sparsities = None
self.fidelities = None
self.accuracies = None
self.explaining_algorithm = None self.explaining_algorithm = None
self.explainer = None self.explainer = None
self.adjusts = None self.adjusts = None
@ -248,8 +252,10 @@ class ExplainingOutline(object):
ind = self.explaining_cfg.dataset.specific_items ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind] self.dataset = self.dataset[ind]
def load_dataset_to_dataloader(self): def load_dataset_to_dataloader(self, to_iter=True):
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
if to_iter:
self.dataset = iter(self.dataset)
def load_explaining_algorithm(self): def load_explaining_algorithm(self):
self.load_explainer_cfg() self.load_explainer_cfg()
@ -381,25 +387,27 @@ class ExplainingOutline(object):
] ]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.accuraties = all_metrics else:
raise ValueError(
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
)
else: else:
raise ValueError( all_metrics = []
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset" self.accuracies = all_metrics
)
def load_metric(self): def load_metric(self):
if self.cfg is None: if self.cfg is None:
self.load_cfg() self.load_cfg()
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
if self.accuraties is None: if self.accuracies is None:
self.load_accuracy() self.load_accuracy()
if self.sparsities is None: if self.sparsities is None:
self.load_sparsity() self.load_sparsity()
if self.fidelities is None: if self.fidelities is None:
self.load_fidelity() self.load_fidelity()
self.metrics = self.fidelities + self.accuraties + self.sparsities self.metrics = self.fidelities + self.accuracies + self.sparsities
def load_attack(self): def load_attack(self):
if self.cfg is None: if self.cfg is None:
@ -432,16 +440,18 @@ class ExplainingOutline(object):
strategy = self.explaining_cfg.adjust.strategy strategy = self.explaining_cfg.adjust.strategy
if strategy == "all": if strategy == "all":
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters] self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
elif isinstance(name, str): elif isinstance(strategy, str):
if name in all_adjusts_filters: if strategy in all_adjusts_filters:
all_metrics = [Adjust(strategy=name)] all_metrics = [Adjust(strategy=strategy)]
else: else:
raise ValueError( raise ValueError(
f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}" f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
) )
elif isinstance(name, list): elif isinstance(strategy, list):
all_metrics = [ all_metrics = [
Adjust(strategy=name_) for name_ in name if name_ in all_robust Adjust(strategy=name_)
for name_ in strategy
if name_ in all_adjusts_filters
] ]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
@ -450,7 +460,7 @@ class ExplainingOutline(object):
def load_threshold(self): def load_threshold(self):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
threshold_type = self.explaining_cfg.threshold_config.type threshold_type = self.explaining_cfg.threshold.config.type
if threshold_type == "all": if threshold_type == "all":
th_hard = [ th_hard = [
{"threshold_type": "hard", "value": th_value} {"threshold_type": "hard", "value": th_value}
@ -523,11 +533,50 @@ class ExplainingOutline(object):
explanation = _load_explanation(path) explanation = _load_explanation(path)
else: else:
explanation = _get_explanation(self.explainer, item) explanation = _get_explanation(self.explainer, item)
get_pred(self.explainer, explanation)
_save_explanation(explanation, path) _save_explanation(explanation, path)
explanation = explanation.to(cfg.accelerator) explanation = explanation.to(cfg.accelerator)
return explanation return explanation
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item)
else:
exp_adjust = _load_explanation(path)
else:
exp_adjust = adjust.forward(item)
get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path)
exp_adjust = exp_adjust.to(cfg.accelerator)
return exp_adjust
def get_threshold(self, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item)
else:
exp_threshold = _load_explanation(path)
else:
exp_threshold = self.explainer._post_process(item)
get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path)
exp_threshold = exp_threshold.to(cfg.accelerator)
return exp_threshold
def get_metric(self, metric: Metric, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item)
else:
out_metric = read_json(path)
else:
out_metric = metric.forward(item)
data = {f"{metric.name}": out_metric}
write_json(data, path)
return out_metric
def get_stat(self, item: Data, path: str): def get_stat(self, item: Data, path: str):
if self.graphstat is None: if self.graphstat is None:
self.load_graphstat() self.load_graphstat()

131
main.py
View File

@ -18,9 +18,6 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.explanation.io import (
explanation_verification, get_explanation, get_pred, load_explanation,
save_explanation)
from explaining_framework.utils.io import (is_exists, obj_config_to_str, from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml) read_json, write_json, write_yaml)
@ -31,81 +28,63 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file) outline = ExplainingOutline(args.explaining_cfg_file)
# Load components out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
# RAJOUTER INDEXES makedirs(out_dir)
# Global path write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml"))
global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm)) write_json(outline.model_info, os.path.join(out_dir, "info.json"))
makedirs(global_path)
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
write_json(model_info, os.path.join(global_path, "info.json"))
makedirs(global_path) explainer_path = os.path.join(
write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) out_dir,
write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) outline.explaining_cfg.explainer.name
+ "_"
+ obj_config_to_str(outline.explaining_algorithm),
)
global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm)) makedirs(explainer_path)
makedirs(global_path) write_yaml(
# SET UP EXPLAINER outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest)
# Save explaining configuration )
item,index = outline.get_item() write_yaml(
while not(item is None or index is None): outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml")
raw_path = os.path.join(global_path, "raw") )
makedirs(raw_path)
explanation_path = os.path.join(save_raw_path, f"{index}.json")
for apply_relu in [True, False]: specific_explainer_path = os.path.join(
for apply_absolute in [True, False]: explainer_path, obj_config_to_str(outline.explaining_algorithm)
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) )
save_raw_path_ = os.path.join( makedirs(specific_explainer_path)
global_path, f"adjust-{obj_config_to_str(adjust)}"
raw_path = os.path.join(specific_explainer_path, "raw")
makedirs(raw_path)
item, index = outline.get_item()
while not (item is None or index is None):
explanation_path = os.path.join(raw_path, f"{index}.json")
raw_exp = outline.get_explanation(item=item, path=explanation_path)
for adjust in outline.adjusts:
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
makedirs(adjust_path)
exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json")
exp_adjust = outline.get_adjust(
adjust=adjust, item=raw_exp, path=exp_adjust_path
)
for threshold_conf in outline.thresholds_configs:
outline.set_explainer_threshold_config(threshold_conf)
masking_path = os.path.join(
save_raw_path_,
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
) )
explanation__ = copy.copy(explanation).to(cfg.accelerator) makedirs(masking_path)
makedirs(save_raw_path_) exp_masked_path = os.path.join(masking_path, f"{index}.json")
explanation = adjust.forward(explanation__) exp_masked = outline.get_threshold(
explanation_path = os.path.join(save_raw_path_, f"{index}.json") item=exp_adjust, path=exp_masked_path
get_pred(explainer, explanation__) )
save_explanation(explanation__, explanation_path) for metric in outline.metrics:
metric_path = os.path.join(
for threshold_approach in ["hard", "topk", "topk_hard"]: masking_path, f"{obj_config_to_str(metric)}"
if threshold_approach == "hard": )
threshold_values = explaining_cfg.threshold_config.value makedirs(metric_path)
elif "topk" in threshold_approach: metric_path = os.path.join(metric_path, f"{index}.json")
threshold_values = [3, 5, 10, 20] out_metric = outline.get_metric(
for threshold_value in threshold_values: metric=metric, item=exp_masked, path=metric_path
)
masking_path = os.path.join(
save_raw_path_,
f"threshold={threshold_approach}-value={threshold_value}",
)
makedirs(masking_path)
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
if is_exists(exp_threshold_path):
exp_threshold = load_explanation(exp_threshold_path)
else:
threshold_conf = {
"threshold_type": threshold_approach,
"value": threshold_value,
}
explainer.threshold_config = ThresholdConfig.cast(
threshold_conf
)
expl = copy.copy(explanation__).to(cfg.accelerator)
exp_threshold = explainer._post_process(expl)
exp_threshold = exp_threshold.to(cfg.accelerator)
get_pred(explainer, exp_threshold)
save_explanation(exp_threshold, exp_threshold_path)
for metric in metrics:
metric_path = os.path.join(
masking_path, f"{obj_config_to_str(metric)}"
)
makedirs(metric_path)
if is_exists(os.path.join(metric_path, f"{index}.json")):
continue
else:
out = metric.forward(exp_threshold)
write_json(
{f"{metric.name}": out},
os.path.join(metric_path, f"{index}.json"),
)