Reformating, fixing
This commit is contained in:
parent
d9628ff947
commit
7e32d6fd3a
@ -57,7 +57,7 @@ def set_cfg(explaining_cfg):
|
||||
|
||||
explaining_cfg.dataset.name = "Cora"
|
||||
|
||||
explaining_cfg.dataset.specific_items = None
|
||||
explaining_cfg.dataset.items = None
|
||||
|
||||
explaining_cfg.run_topological_stat = True
|
||||
|
||||
@ -110,26 +110,33 @@ def set_cfg(explaining_cfg):
|
||||
# Thresholding options
|
||||
# ----------------------------------------------------------------------- #
|
||||
|
||||
explaining_cfg.threshold_config = CN()
|
||||
explaining_cfg.threshold = CN()
|
||||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
explaining_cfg.threshold.config = CN()
|
||||
explaining_cfg.threshold.config.type = "all"
|
||||
|
||||
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
|
||||
|
||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.threshold.value = CN()
|
||||
explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)]
|
||||
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
|
||||
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
explaining_cfg.metrics = CN()
|
||||
explaining_cfg.metrics.name = "all"
|
||||
explaining_cfg.metrics.sparsity = CN()
|
||||
explaining_cfg.metrics.sparsity.name = "all"
|
||||
explaining_cfg.metrics.fidelity = CN()
|
||||
explaining_cfg.metrics.fidelity.name = "all"
|
||||
explaining_cfg.metrics.accuracy = CN()
|
||||
explaining_cfg.metrics.accuracy.name = "all"
|
||||
|
||||
# Whether or not recomputing metrics if they already exist
|
||||
explaining_cfg.metrics.force = False
|
||||
|
||||
explaining_cfg.adjust = CN()
|
||||
explaining_cfg.adjust.strategy = "rpn"
|
||||
|
||||
explaining_cfg.attack = CN()
|
||||
explaining_cfg.attack.name = "all"
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
|
||||
|
@ -3,17 +3,6 @@ import itertools
|
||||
from typing import Any
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
@ -22,6 +11,7 @@ from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||
GraphXAIWrapper
|
||||
from explaining_framework.metric.accuracy import Accuracy
|
||||
from explaining_framework.metric.base import Metric
|
||||
from explaining_framework.metric.fidelity import Fidelity
|
||||
from explaining_framework.metric.robust import Attack
|
||||
from explaining_framework.metric.sparsity import Sparsity
|
||||
@ -34,6 +24,17 @@ from explaining_framework.utils.explanation.io import (
|
||||
explanation_verification, get_pred)
|
||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||
read_json, write_json, write_yaml)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
@ -120,6 +121,9 @@ class ExplainingOutline(object):
|
||||
self.attacks = None
|
||||
self.model_signature = None
|
||||
self.indexes = None
|
||||
self.sparsities = None
|
||||
self.fidelities = None
|
||||
self.accuracies = None
|
||||
self.explaining_algorithm = None
|
||||
self.explainer = None
|
||||
self.adjusts = None
|
||||
@ -248,8 +252,10 @@ class ExplainingOutline(object):
|
||||
ind = self.explaining_cfg.dataset.specific_items
|
||||
self.dataset = self.dataset[ind]
|
||||
|
||||
def load_dataset_to_dataloader(self):
|
||||
def load_dataset_to_dataloader(self, to_iter=True):
|
||||
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
||||
if to_iter:
|
||||
self.dataset = iter(self.dataset)
|
||||
|
||||
def load_explaining_algorithm(self):
|
||||
self.load_explainer_cfg()
|
||||
@ -381,25 +387,27 @@ class ExplainingOutline(object):
|
||||
]
|
||||
elif name is None:
|
||||
all_metrics = []
|
||||
self.accuraties = all_metrics
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
|
||||
)
|
||||
all_metrics = []
|
||||
self.accuracies = all_metrics
|
||||
|
||||
def load_metric(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
if self.accuraties is None:
|
||||
if self.accuracies is None:
|
||||
self.load_accuracy()
|
||||
if self.sparsities is None:
|
||||
self.load_sparsity()
|
||||
if self.fidelities is None:
|
||||
self.load_fidelity()
|
||||
|
||||
self.metrics = self.fidelities + self.accuraties + self.sparsities
|
||||
self.metrics = self.fidelities + self.accuracies + self.sparsities
|
||||
|
||||
def load_attack(self):
|
||||
if self.cfg is None:
|
||||
@ -432,16 +440,18 @@ class ExplainingOutline(object):
|
||||
strategy = self.explaining_cfg.adjust.strategy
|
||||
if strategy == "all":
|
||||
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
|
||||
elif isinstance(name, str):
|
||||
if name in all_adjusts_filters:
|
||||
all_metrics = [Adjust(strategy=name)]
|
||||
elif isinstance(strategy, str):
|
||||
if strategy in all_adjusts_filters:
|
||||
all_metrics = [Adjust(strategy=strategy)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
|
||||
)
|
||||
elif isinstance(name, list):
|
||||
elif isinstance(strategy, list):
|
||||
all_metrics = [
|
||||
Adjust(strategy=name_) for name_ in name if name_ in all_robust
|
||||
Adjust(strategy=name_)
|
||||
for name_ in strategy
|
||||
if name_ in all_adjusts_filters
|
||||
]
|
||||
elif name is None:
|
||||
all_metrics = []
|
||||
@ -450,7 +460,7 @@ class ExplainingOutline(object):
|
||||
def load_threshold(self):
|
||||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
threshold_type = self.explaining_cfg.threshold_config.type
|
||||
threshold_type = self.explaining_cfg.threshold.config.type
|
||||
if threshold_type == "all":
|
||||
th_hard = [
|
||||
{"threshold_type": "hard", "value": th_value}
|
||||
@ -523,11 +533,50 @@ class ExplainingOutline(object):
|
||||
explanation = _load_explanation(path)
|
||||
else:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
get_pred(self.explainer, explanation)
|
||||
_save_explanation(explanation, path)
|
||||
explanation = explanation.to(cfg.accelerator)
|
||||
|
||||
return explanation
|
||||
|
||||
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
exp_adjust = adjust.forward(item)
|
||||
else:
|
||||
exp_adjust = _load_explanation(path)
|
||||
else:
|
||||
exp_adjust = adjust.forward(item)
|
||||
get_pred(self.explainer, exp_adjust)
|
||||
_save_explanation(exp_adjust, path)
|
||||
exp_adjust = exp_adjust.to(cfg.accelerator)
|
||||
return exp_adjust
|
||||
|
||||
def get_threshold(self, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
exp_threshold = self.explainer._post_process(item)
|
||||
else:
|
||||
exp_threshold = _load_explanation(path)
|
||||
else:
|
||||
exp_threshold = self.explainer._post_process(item)
|
||||
get_pred(self.explainer, exp_threshold)
|
||||
_save_explanation(exp_threshold, path)
|
||||
exp_threshold = exp_threshold.to(cfg.accelerator)
|
||||
return exp_threshold
|
||||
|
||||
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||
if is_exists(path):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
out_metric = metric.forward(item)
|
||||
else:
|
||||
out_metric = read_json(path)
|
||||
else:
|
||||
out_metric = metric.forward(item)
|
||||
data = {f"{metric.name}": out_metric}
|
||||
write_json(data, path)
|
||||
return out_metric
|
||||
|
||||
def get_stat(self, item: Data, path: str):
|
||||
if self.graphstat is None:
|
||||
self.load_graphstat()
|
||||
|
131
main.py
131
main.py
@ -18,9 +18,6 @@ from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
from explaining_framework.utils.explanation.io import (
|
||||
explanation_verification, get_explanation, get_pred, load_explanation,
|
||||
save_explanation)
|
||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||
read_json, write_json, write_yaml)
|
||||
|
||||
@ -31,81 +28,63 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||
|
||||
# Load components
|
||||
# RAJOUTER INDEXES
|
||||
out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
|
||||
makedirs(out_dir)
|
||||
|
||||
# Global path
|
||||
global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm))
|
||||
makedirs(global_path)
|
||||
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||
write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml"))
|
||||
write_json(outline.model_info, os.path.join(out_dir, "info.json"))
|
||||
|
||||
makedirs(global_path)
|
||||
write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
||||
write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
||||
explainer_path = os.path.join(
|
||||
out_dir,
|
||||
outline.explaining_cfg.explainer.name
|
||||
+ "_"
|
||||
+ obj_config_to_str(outline.explaining_algorithm),
|
||||
)
|
||||
|
||||
global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm))
|
||||
makedirs(global_path)
|
||||
# SET UP EXPLAINER
|
||||
# Save explaining configuration
|
||||
item,index = outline.get_item()
|
||||
while not(item is None or index is None):
|
||||
raw_path = os.path.join(global_path, "raw")
|
||||
makedirs(raw_path)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
makedirs(explainer_path)
|
||||
write_yaml(
|
||||
outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest)
|
||||
)
|
||||
write_yaml(
|
||||
outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml")
|
||||
)
|
||||
|
||||
for apply_relu in [True, False]:
|
||||
for apply_absolute in [True, False]:
|
||||
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
|
||||
save_raw_path_ = os.path.join(
|
||||
global_path, f"adjust-{obj_config_to_str(adjust)}"
|
||||
specific_explainer_path = os.path.join(
|
||||
explainer_path, obj_config_to_str(outline.explaining_algorithm)
|
||||
)
|
||||
makedirs(specific_explainer_path)
|
||||
|
||||
raw_path = os.path.join(specific_explainer_path, "raw")
|
||||
makedirs(raw_path)
|
||||
|
||||
item, index = outline.get_item()
|
||||
while not (item is None or index is None):
|
||||
explanation_path = os.path.join(raw_path, f"{index}.json")
|
||||
raw_exp = outline.get_explanation(item=item, path=explanation_path)
|
||||
for adjust in outline.adjusts:
|
||||
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
|
||||
makedirs(adjust_path)
|
||||
exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json")
|
||||
exp_adjust = outline.get_adjust(
|
||||
adjust=adjust, item=raw_exp, path=exp_adjust_path
|
||||
)
|
||||
for threshold_conf in outline.thresholds_configs:
|
||||
outline.set_explainer_threshold_config(threshold_conf)
|
||||
masking_path = os.path.join(
|
||||
save_raw_path_,
|
||||
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
||||
)
|
||||
explanation__ = copy.copy(explanation).to(cfg.accelerator)
|
||||
makedirs(save_raw_path_)
|
||||
explanation = adjust.forward(explanation__)
|
||||
explanation_path = os.path.join(save_raw_path_, f"{index}.json")
|
||||
get_pred(explainer, explanation__)
|
||||
save_explanation(explanation__, explanation_path)
|
||||
|
||||
for threshold_approach in ["hard", "topk", "topk_hard"]:
|
||||
if threshold_approach == "hard":
|
||||
threshold_values = explaining_cfg.threshold_config.value
|
||||
elif "topk" in threshold_approach:
|
||||
threshold_values = [3, 5, 10, 20]
|
||||
for threshold_value in threshold_values:
|
||||
|
||||
masking_path = os.path.join(
|
||||
save_raw_path_,
|
||||
f"threshold={threshold_approach}-value={threshold_value}",
|
||||
)
|
||||
makedirs(masking_path)
|
||||
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
|
||||
if is_exists(exp_threshold_path):
|
||||
exp_threshold = load_explanation(exp_threshold_path)
|
||||
else:
|
||||
threshold_conf = {
|
||||
"threshold_type": threshold_approach,
|
||||
"value": threshold_value,
|
||||
}
|
||||
explainer.threshold_config = ThresholdConfig.cast(
|
||||
threshold_conf
|
||||
)
|
||||
|
||||
expl = copy.copy(explanation__).to(cfg.accelerator)
|
||||
exp_threshold = explainer._post_process(expl)
|
||||
exp_threshold = exp_threshold.to(cfg.accelerator)
|
||||
get_pred(explainer, exp_threshold)
|
||||
save_explanation(exp_threshold, exp_threshold_path)
|
||||
for metric in metrics:
|
||||
metric_path = os.path.join(
|
||||
masking_path, f"{obj_config_to_str(metric)}"
|
||||
)
|
||||
makedirs(metric_path)
|
||||
if is_exists(os.path.join(metric_path, f"{index}.json")):
|
||||
continue
|
||||
else:
|
||||
out = metric.forward(exp_threshold)
|
||||
write_json(
|
||||
{f"{metric.name}": out},
|
||||
os.path.join(metric_path, f"{index}.json"),
|
||||
)
|
||||
makedirs(masking_path)
|
||||
exp_masked_path = os.path.join(masking_path, f"{index}.json")
|
||||
exp_masked = outline.get_threshold(
|
||||
item=exp_adjust, path=exp_masked_path
|
||||
)
|
||||
for metric in outline.metrics:
|
||||
metric_path = os.path.join(
|
||||
masking_path, f"{obj_config_to_str(metric)}"
|
||||
)
|
||||
makedirs(metric_path)
|
||||
metric_path = os.path.join(metric_path, f"{index}.json")
|
||||
out_metric = outline.get_metric(
|
||||
metric=metric, item=exp_masked, path=metric_path
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user