Reformating, fixing

This commit is contained in:
araison 2023-01-04 10:41:34 +01:00
parent d9628ff947
commit 7e32d6fd3a
3 changed files with 145 additions and 110 deletions

View File

@ -57,7 +57,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.dataset.name = "Cora"
explaining_cfg.dataset.specific_items = None
explaining_cfg.dataset.items = None
explaining_cfg.run_topological_stat = True
@ -110,26 +110,33 @@ def set_cfg(explaining_cfg):
# Thresholding options
# ----------------------------------------------------------------------- #
explaining_cfg.threshold_config = CN()
explaining_cfg.threshold = CN()
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold.config = CN()
explaining_cfg.threshold.config.type = "all"
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
explaining_cfg.threshold_config.relu_and_normalize = True
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.threshold.value = CN()
explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)]
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
# which objectives metrics to computes, either all or one in particular if implemented
explaining_cfg.metrics = CN()
explaining_cfg.metrics.name = "all"
explaining_cfg.metrics.sparsity = CN()
explaining_cfg.metrics.sparsity.name = "all"
explaining_cfg.metrics.fidelity = CN()
explaining_cfg.metrics.fidelity.name = "all"
explaining_cfg.metrics.accuracy = CN()
explaining_cfg.metrics.accuracy.name = "all"
# Whether or not recomputing metrics if they already exist
explaining_cfg.metrics.force = False
explaining_cfg.adjust = CN()
explaining_cfg.adjust.strategy = "rpn"
explaining_cfg.attack = CN()
explaining_cfg.attack.name = "all"
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"

View File

@ -3,17 +3,6 @@ import itertools
from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -22,6 +11,7 @@ from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \
GraphXAIWrapper
from explaining_framework.metric.accuracy import Accuracy
from explaining_framework.metric.base import Metric
from explaining_framework.metric.fidelity import Fidelity
from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity
@ -34,6 +24,17 @@ from explaining_framework.utils.explanation.io import (
explanation_verification, get_pred)
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml)
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
all__captum = [
"LRP",
@ -120,6 +121,9 @@ class ExplainingOutline(object):
self.attacks = None
self.model_signature = None
self.indexes = None
self.sparsities = None
self.fidelities = None
self.accuracies = None
self.explaining_algorithm = None
self.explainer = None
self.adjusts = None
@ -248,8 +252,10 @@ class ExplainingOutline(object):
ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind]
def load_dataset_to_dataloader(self):
def load_dataset_to_dataloader(self, to_iter=True):
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
if to_iter:
self.dataset = iter(self.dataset)
def load_explaining_algorithm(self):
self.load_explainer_cfg()
@ -381,25 +387,27 @@ class ExplainingOutline(object):
]
elif name is None:
all_metrics = []
self.accuraties = all_metrics
else:
raise ValueError(
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
)
else:
all_metrics = []
self.accuracies = all_metrics
def load_metric(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
if self.accuraties is None:
if self.accuracies is None:
self.load_accuracy()
if self.sparsities is None:
self.load_sparsity()
if self.fidelities is None:
self.load_fidelity()
self.metrics = self.fidelities + self.accuraties + self.sparsities
self.metrics = self.fidelities + self.accuracies + self.sparsities
def load_attack(self):
if self.cfg is None:
@ -432,16 +440,18 @@ class ExplainingOutline(object):
strategy = self.explaining_cfg.adjust.strategy
if strategy == "all":
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
elif isinstance(name, str):
if name in all_adjusts_filters:
all_metrics = [Adjust(strategy=name)]
elif isinstance(strategy, str):
if strategy in all_adjusts_filters:
all_metrics = [Adjust(strategy=strategy)]
else:
raise ValueError(
f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
)
elif isinstance(name, list):
elif isinstance(strategy, list):
all_metrics = [
Adjust(strategy=name_) for name_ in name if name_ in all_robust
Adjust(strategy=name_)
for name_ in strategy
if name_ in all_adjusts_filters
]
elif name is None:
all_metrics = []
@ -450,7 +460,7 @@ class ExplainingOutline(object):
def load_threshold(self):
if self.explaining_cfg is None:
self.load_explaining_cfg()
threshold_type = self.explaining_cfg.threshold_config.type
threshold_type = self.explaining_cfg.threshold.config.type
if threshold_type == "all":
th_hard = [
{"threshold_type": "hard", "value": th_value}
@ -523,11 +533,50 @@ class ExplainingOutline(object):
explanation = _load_explanation(path)
else:
explanation = _get_explanation(self.explainer, item)
get_pred(self.explainer, explanation)
_save_explanation(explanation, path)
explanation = explanation.to(cfg.accelerator)
return explanation
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
exp_adjust = adjust.forward(item)
else:
exp_adjust = _load_explanation(path)
else:
exp_adjust = adjust.forward(item)
get_pred(self.explainer, exp_adjust)
_save_explanation(exp_adjust, path)
exp_adjust = exp_adjust.to(cfg.accelerator)
return exp_adjust
def get_threshold(self, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
exp_threshold = self.explainer._post_process(item)
else:
exp_threshold = _load_explanation(path)
else:
exp_threshold = self.explainer._post_process(item)
get_pred(self.explainer, exp_threshold)
_save_explanation(exp_threshold, path)
exp_threshold = exp_threshold.to(cfg.accelerator)
return exp_threshold
def get_metric(self, metric: Metric, item: Explanation, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
out_metric = metric.forward(item)
else:
out_metric = read_json(path)
else:
out_metric = metric.forward(item)
data = {f"{metric.name}": out_metric}
write_json(data, path)
return out_metric
def get_stat(self, item: Data, path: str):
if self.graphstat is None:
self.load_graphstat()

109
main.py
View File

@ -18,9 +18,6 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.explanation.io import (
explanation_verification, get_explanation, get_pred, load_explanation,
save_explanation)
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml)
@ -31,81 +28,63 @@ if __name__ == "__main__":
args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file)
# Load components
# RAJOUTER INDEXES
out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
makedirs(out_dir)
# Global path
global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm))
makedirs(global_path)
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
write_json(model_info, os.path.join(global_path, "info.json"))
write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml"))
write_json(outline.model_info, os.path.join(out_dir, "info.json"))
makedirs(global_path)
write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
explainer_path = os.path.join(
out_dir,
outline.explaining_cfg.explainer.name
+ "_"
+ obj_config_to_str(outline.explaining_algorithm),
)
makedirs(explainer_path)
write_yaml(
outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest)
)
write_yaml(
outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml")
)
specific_explainer_path = os.path.join(
explainer_path, obj_config_to_str(outline.explaining_algorithm)
)
makedirs(specific_explainer_path)
raw_path = os.path.join(specific_explainer_path, "raw")
makedirs(raw_path)
global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm))
makedirs(global_path)
# SET UP EXPLAINER
# Save explaining configuration
item, index = outline.get_item()
while not (item is None or index is None):
raw_path = os.path.join(global_path, "raw")
makedirs(raw_path)
explanation_path = os.path.join(save_raw_path, f"{index}.json")
for apply_relu in [True, False]:
for apply_absolute in [True, False]:
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
save_raw_path_ = os.path.join(
global_path, f"adjust-{obj_config_to_str(adjust)}"
explanation_path = os.path.join(raw_path, f"{index}.json")
raw_exp = outline.get_explanation(item=item, path=explanation_path)
for adjust in outline.adjusts:
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
makedirs(adjust_path)
exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json")
exp_adjust = outline.get_adjust(
adjust=adjust, item=raw_exp, path=exp_adjust_path
)
explanation__ = copy.copy(explanation).to(cfg.accelerator)
makedirs(save_raw_path_)
explanation = adjust.forward(explanation__)
explanation_path = os.path.join(save_raw_path_, f"{index}.json")
get_pred(explainer, explanation__)
save_explanation(explanation__, explanation_path)
for threshold_approach in ["hard", "topk", "topk_hard"]:
if threshold_approach == "hard":
threshold_values = explaining_cfg.threshold_config.value
elif "topk" in threshold_approach:
threshold_values = [3, 5, 10, 20]
for threshold_value in threshold_values:
for threshold_conf in outline.thresholds_configs:
outline.set_explainer_threshold_config(threshold_conf)
masking_path = os.path.join(
save_raw_path_,
f"threshold={threshold_approach}-value={threshold_value}",
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
)
makedirs(masking_path)
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
if is_exists(exp_threshold_path):
exp_threshold = load_explanation(exp_threshold_path)
else:
threshold_conf = {
"threshold_type": threshold_approach,
"value": threshold_value,
}
explainer.threshold_config = ThresholdConfig.cast(
threshold_conf
exp_masked_path = os.path.join(masking_path, f"{index}.json")
exp_masked = outline.get_threshold(
item=exp_adjust, path=exp_masked_path
)
expl = copy.copy(explanation__).to(cfg.accelerator)
exp_threshold = explainer._post_process(expl)
exp_threshold = exp_threshold.to(cfg.accelerator)
get_pred(explainer, exp_threshold)
save_explanation(exp_threshold, exp_threshold_path)
for metric in metrics:
for metric in outline.metrics:
metric_path = os.path.join(
masking_path, f"{obj_config_to_str(metric)}"
)
makedirs(metric_path)
if is_exists(os.path.join(metric_path, f"{index}.json")):
continue
else:
out = metric.forward(exp_threshold)
write_json(
{f"{metric.name}": out},
os.path.join(metric_path, f"{index}.json"),
metric_path = os.path.join(metric_path, f"{index}.json")
out_metric = outline.get_metric(
metric=metric, item=exp_masked, path=metric_path
)