Reformating, fixing
This commit is contained in:
parent
d9628ff947
commit
7e32d6fd3a
|
@ -57,7 +57,7 @@ def set_cfg(explaining_cfg):
|
||||||
|
|
||||||
explaining_cfg.dataset.name = "Cora"
|
explaining_cfg.dataset.name = "Cora"
|
||||||
|
|
||||||
explaining_cfg.dataset.specific_items = None
|
explaining_cfg.dataset.items = None
|
||||||
|
|
||||||
explaining_cfg.run_topological_stat = True
|
explaining_cfg.run_topological_stat = True
|
||||||
|
|
||||||
|
@ -110,26 +110,33 @@ def set_cfg(explaining_cfg):
|
||||||
# Thresholding options
|
# Thresholding options
|
||||||
# ----------------------------------------------------------------------- #
|
# ----------------------------------------------------------------------- #
|
||||||
|
|
||||||
explaining_cfg.threshold_config = CN()
|
explaining_cfg.threshold = CN()
|
||||||
|
|
||||||
explaining_cfg.threshold_config.threshold_type = None
|
explaining_cfg.threshold.config = CN()
|
||||||
|
explaining_cfg.threshold.config.type = "all"
|
||||||
|
|
||||||
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
|
explaining_cfg.threshold.value = CN()
|
||||||
|
explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)]
|
||||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]
|
||||||
|
|
||||||
# Select device: 'cpu', 'cuda', 'auto'
|
|
||||||
|
|
||||||
# which objectives metrics to computes, either all or one in particular if implemented
|
# which objectives metrics to computes, either all or one in particular if implemented
|
||||||
explaining_cfg.metrics = CN()
|
explaining_cfg.metrics = CN()
|
||||||
explaining_cfg.metrics.name = "all"
|
explaining_cfg.metrics.sparsity = CN()
|
||||||
|
explaining_cfg.metrics.sparsity.name = "all"
|
||||||
|
explaining_cfg.metrics.fidelity = CN()
|
||||||
|
explaining_cfg.metrics.fidelity.name = "all"
|
||||||
|
explaining_cfg.metrics.accuracy = CN()
|
||||||
|
explaining_cfg.metrics.accuracy.name = "all"
|
||||||
|
|
||||||
# Whether or not recomputing metrics if they already exist
|
# Whether or not recomputing metrics if they already exist
|
||||||
explaining_cfg.metrics.force = False
|
|
||||||
|
explaining_cfg.adjust = CN()
|
||||||
|
explaining_cfg.adjust.strategy = "rpn"
|
||||||
|
|
||||||
explaining_cfg.attack = CN()
|
explaining_cfg.attack = CN()
|
||||||
explaining_cfg.attack.name = "all"
|
explaining_cfg.attack.name = "all"
|
||||||
|
|
||||||
|
# Select device: 'cpu', 'cuda', 'auto'
|
||||||
explaining_cfg.accelerator = "auto"
|
explaining_cfg.accelerator = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,17 +3,6 @@ import itertools
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric import seed_everything
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -22,6 +11,7 @@ from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||||
from explaining_framework.explainers.wrappers.from_graphxai import \
|
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||||
GraphXAIWrapper
|
GraphXAIWrapper
|
||||||
from explaining_framework.metric.accuracy import Accuracy
|
from explaining_framework.metric.accuracy import Accuracy
|
||||||
|
from explaining_framework.metric.base import Metric
|
||||||
from explaining_framework.metric.fidelity import Fidelity
|
from explaining_framework.metric.fidelity import Fidelity
|
||||||
from explaining_framework.metric.robust import Attack
|
from explaining_framework.metric.robust import Attack
|
||||||
from explaining_framework.metric.sparsity import Sparsity
|
from explaining_framework.metric.sparsity import Sparsity
|
||||||
|
@ -34,6 +24,17 @@ from explaining_framework.utils.explanation.io import (
|
||||||
explanation_verification, get_pred)
|
explanation_verification, get_pred)
|
||||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||||
read_json, write_json, write_yaml)
|
read_json, write_json, write_yaml)
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -120,6 +121,9 @@ class ExplainingOutline(object):
|
||||||
self.attacks = None
|
self.attacks = None
|
||||||
self.model_signature = None
|
self.model_signature = None
|
||||||
self.indexes = None
|
self.indexes = None
|
||||||
|
self.sparsities = None
|
||||||
|
self.fidelities = None
|
||||||
|
self.accuracies = None
|
||||||
self.explaining_algorithm = None
|
self.explaining_algorithm = None
|
||||||
self.explainer = None
|
self.explainer = None
|
||||||
self.adjusts = None
|
self.adjusts = None
|
||||||
|
@ -248,8 +252,10 @@ class ExplainingOutline(object):
|
||||||
ind = self.explaining_cfg.dataset.specific_items
|
ind = self.explaining_cfg.dataset.specific_items
|
||||||
self.dataset = self.dataset[ind]
|
self.dataset = self.dataset[ind]
|
||||||
|
|
||||||
def load_dataset_to_dataloader(self):
|
def load_dataset_to_dataloader(self, to_iter=True):
|
||||||
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
||||||
|
if to_iter:
|
||||||
|
self.dataset = iter(self.dataset)
|
||||||
|
|
||||||
def load_explaining_algorithm(self):
|
def load_explaining_algorithm(self):
|
||||||
self.load_explainer_cfg()
|
self.load_explainer_cfg()
|
||||||
|
@ -381,25 +387,27 @@ class ExplainingOutline(object):
|
||||||
]
|
]
|
||||||
elif name is None:
|
elif name is None:
|
||||||
all_metrics = []
|
all_metrics = []
|
||||||
self.accuraties = all_metrics
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
all_metrics = []
|
||||||
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
|
self.accuracies = all_metrics
|
||||||
)
|
|
||||||
|
|
||||||
def load_metric(self):
|
def load_metric(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
if self.accuraties is None:
|
if self.accuracies is None:
|
||||||
self.load_accuracy()
|
self.load_accuracy()
|
||||||
if self.sparsities is None:
|
if self.sparsities is None:
|
||||||
self.load_sparsity()
|
self.load_sparsity()
|
||||||
if self.fidelities is None:
|
if self.fidelities is None:
|
||||||
self.load_fidelity()
|
self.load_fidelity()
|
||||||
|
|
||||||
self.metrics = self.fidelities + self.accuraties + self.sparsities
|
self.metrics = self.fidelities + self.accuracies + self.sparsities
|
||||||
|
|
||||||
def load_attack(self):
|
def load_attack(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
|
@ -432,16 +440,18 @@ class ExplainingOutline(object):
|
||||||
strategy = self.explaining_cfg.adjust.strategy
|
strategy = self.explaining_cfg.adjust.strategy
|
||||||
if strategy == "all":
|
if strategy == "all":
|
||||||
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
|
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
|
||||||
elif isinstance(name, str):
|
elif isinstance(strategy, str):
|
||||||
if name in all_adjusts_filters:
|
if strategy in all_adjusts_filters:
|
||||||
all_metrics = [Adjust(strategy=name)]
|
all_metrics = [Adjust(strategy=strategy)]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
|
f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
|
||||||
)
|
)
|
||||||
elif isinstance(name, list):
|
elif isinstance(strategy, list):
|
||||||
all_metrics = [
|
all_metrics = [
|
||||||
Adjust(strategy=name_) for name_ in name if name_ in all_robust
|
Adjust(strategy=name_)
|
||||||
|
for name_ in strategy
|
||||||
|
if name_ in all_adjusts_filters
|
||||||
]
|
]
|
||||||
elif name is None:
|
elif name is None:
|
||||||
all_metrics = []
|
all_metrics = []
|
||||||
|
@ -450,7 +460,7 @@ class ExplainingOutline(object):
|
||||||
def load_threshold(self):
|
def load_threshold(self):
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
threshold_type = self.explaining_cfg.threshold_config.type
|
threshold_type = self.explaining_cfg.threshold.config.type
|
||||||
if threshold_type == "all":
|
if threshold_type == "all":
|
||||||
th_hard = [
|
th_hard = [
|
||||||
{"threshold_type": "hard", "value": th_value}
|
{"threshold_type": "hard", "value": th_value}
|
||||||
|
@ -523,11 +533,50 @@ class ExplainingOutline(object):
|
||||||
explanation = _load_explanation(path)
|
explanation = _load_explanation(path)
|
||||||
else:
|
else:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
|
get_pred(self.explainer, explanation)
|
||||||
_save_explanation(explanation, path)
|
_save_explanation(explanation, path)
|
||||||
explanation = explanation.to(cfg.accelerator)
|
explanation = explanation.to(cfg.accelerator)
|
||||||
|
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
|
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||||
|
if is_exists(path):
|
||||||
|
if self.explaining_cfg.explainer.force:
|
||||||
|
exp_adjust = adjust.forward(item)
|
||||||
|
else:
|
||||||
|
exp_adjust = _load_explanation(path)
|
||||||
|
else:
|
||||||
|
exp_adjust = adjust.forward(item)
|
||||||
|
get_pred(self.explainer, exp_adjust)
|
||||||
|
_save_explanation(exp_adjust, path)
|
||||||
|
exp_adjust = exp_adjust.to(cfg.accelerator)
|
||||||
|
return exp_adjust
|
||||||
|
|
||||||
|
def get_threshold(self, item: Explanation, path: str):
|
||||||
|
if is_exists(path):
|
||||||
|
if self.explaining_cfg.explainer.force:
|
||||||
|
exp_threshold = self.explainer._post_process(item)
|
||||||
|
else:
|
||||||
|
exp_threshold = _load_explanation(path)
|
||||||
|
else:
|
||||||
|
exp_threshold = self.explainer._post_process(item)
|
||||||
|
get_pred(self.explainer, exp_threshold)
|
||||||
|
_save_explanation(exp_threshold, path)
|
||||||
|
exp_threshold = exp_threshold.to(cfg.accelerator)
|
||||||
|
return exp_threshold
|
||||||
|
|
||||||
|
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||||
|
if is_exists(path):
|
||||||
|
if self.explaining_cfg.explainer.force:
|
||||||
|
out_metric = metric.forward(item)
|
||||||
|
else:
|
||||||
|
out_metric = read_json(path)
|
||||||
|
else:
|
||||||
|
out_metric = metric.forward(item)
|
||||||
|
data = {f"{metric.name}": out_metric}
|
||||||
|
write_json(data, path)
|
||||||
|
return out_metric
|
||||||
|
|
||||||
def get_stat(self, item: Data, path: str):
|
def get_stat(self, item: Data, path: str):
|
||||||
if self.graphstat is None:
|
if self.graphstat is None:
|
||||||
self.load_graphstat()
|
self.load_graphstat()
|
||||||
|
|
131
main.py
131
main.py
|
@ -18,9 +18,6 @@ from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||||
from explaining_framework.utils.explanation.adjust import Adjust
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
from explaining_framework.utils.explanation.io import (
|
|
||||||
explanation_verification, get_explanation, get_pred, load_explanation,
|
|
||||||
save_explanation)
|
|
||||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||||
read_json, write_json, write_yaml)
|
read_json, write_json, write_yaml)
|
||||||
|
|
||||||
|
@ -31,81 +28,63 @@ if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
|
|
||||||
# Load components
|
out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
|
||||||
# RAJOUTER INDEXES
|
makedirs(out_dir)
|
||||||
|
|
||||||
# Global path
|
write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml"))
|
||||||
global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm))
|
write_json(outline.model_info, os.path.join(out_dir, "info.json"))
|
||||||
makedirs(global_path)
|
|
||||||
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
|
||||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
|
||||||
|
|
||||||
makedirs(global_path)
|
explainer_path = os.path.join(
|
||||||
write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
out_dir,
|
||||||
write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
outline.explaining_cfg.explainer.name
|
||||||
|
+ "_"
|
||||||
|
+ obj_config_to_str(outline.explaining_algorithm),
|
||||||
|
)
|
||||||
|
|
||||||
global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm))
|
makedirs(explainer_path)
|
||||||
makedirs(global_path)
|
write_yaml(
|
||||||
# SET UP EXPLAINER
|
outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest)
|
||||||
# Save explaining configuration
|
)
|
||||||
item,index = outline.get_item()
|
write_yaml(
|
||||||
while not(item is None or index is None):
|
outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml")
|
||||||
raw_path = os.path.join(global_path, "raw")
|
)
|
||||||
makedirs(raw_path)
|
|
||||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
|
||||||
|
|
||||||
for apply_relu in [True, False]:
|
specific_explainer_path = os.path.join(
|
||||||
for apply_absolute in [True, False]:
|
explainer_path, obj_config_to_str(outline.explaining_algorithm)
|
||||||
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
|
)
|
||||||
save_raw_path_ = os.path.join(
|
makedirs(specific_explainer_path)
|
||||||
global_path, f"adjust-{obj_config_to_str(adjust)}"
|
|
||||||
|
raw_path = os.path.join(specific_explainer_path, "raw")
|
||||||
|
makedirs(raw_path)
|
||||||
|
|
||||||
|
item, index = outline.get_item()
|
||||||
|
while not (item is None or index is None):
|
||||||
|
explanation_path = os.path.join(raw_path, f"{index}.json")
|
||||||
|
raw_exp = outline.get_explanation(item=item, path=explanation_path)
|
||||||
|
for adjust in outline.adjusts:
|
||||||
|
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
|
||||||
|
makedirs(adjust_path)
|
||||||
|
exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json")
|
||||||
|
exp_adjust = outline.get_adjust(
|
||||||
|
adjust=adjust, item=raw_exp, path=exp_adjust_path
|
||||||
|
)
|
||||||
|
for threshold_conf in outline.thresholds_configs:
|
||||||
|
outline.set_explainer_threshold_config(threshold_conf)
|
||||||
|
masking_path = os.path.join(
|
||||||
|
save_raw_path_,
|
||||||
|
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
||||||
)
|
)
|
||||||
explanation__ = copy.copy(explanation).to(cfg.accelerator)
|
makedirs(masking_path)
|
||||||
makedirs(save_raw_path_)
|
exp_masked_path = os.path.join(masking_path, f"{index}.json")
|
||||||
explanation = adjust.forward(explanation__)
|
exp_masked = outline.get_threshold(
|
||||||
explanation_path = os.path.join(save_raw_path_, f"{index}.json")
|
item=exp_adjust, path=exp_masked_path
|
||||||
get_pred(explainer, explanation__)
|
)
|
||||||
save_explanation(explanation__, explanation_path)
|
for metric in outline.metrics:
|
||||||
|
metric_path = os.path.join(
|
||||||
for threshold_approach in ["hard", "topk", "topk_hard"]:
|
masking_path, f"{obj_config_to_str(metric)}"
|
||||||
if threshold_approach == "hard":
|
)
|
||||||
threshold_values = explaining_cfg.threshold_config.value
|
makedirs(metric_path)
|
||||||
elif "topk" in threshold_approach:
|
metric_path = os.path.join(metric_path, f"{index}.json")
|
||||||
threshold_values = [3, 5, 10, 20]
|
out_metric = outline.get_metric(
|
||||||
for threshold_value in threshold_values:
|
metric=metric, item=exp_masked, path=metric_path
|
||||||
|
)
|
||||||
masking_path = os.path.join(
|
|
||||||
save_raw_path_,
|
|
||||||
f"threshold={threshold_approach}-value={threshold_value}",
|
|
||||||
)
|
|
||||||
makedirs(masking_path)
|
|
||||||
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
|
|
||||||
if is_exists(exp_threshold_path):
|
|
||||||
exp_threshold = load_explanation(exp_threshold_path)
|
|
||||||
else:
|
|
||||||
threshold_conf = {
|
|
||||||
"threshold_type": threshold_approach,
|
|
||||||
"value": threshold_value,
|
|
||||||
}
|
|
||||||
explainer.threshold_config = ThresholdConfig.cast(
|
|
||||||
threshold_conf
|
|
||||||
)
|
|
||||||
|
|
||||||
expl = copy.copy(explanation__).to(cfg.accelerator)
|
|
||||||
exp_threshold = explainer._post_process(expl)
|
|
||||||
exp_threshold = exp_threshold.to(cfg.accelerator)
|
|
||||||
get_pred(explainer, exp_threshold)
|
|
||||||
save_explanation(exp_threshold, exp_threshold_path)
|
|
||||||
for metric in metrics:
|
|
||||||
metric_path = os.path.join(
|
|
||||||
masking_path, f"{obj_config_to_str(metric)}"
|
|
||||||
)
|
|
||||||
makedirs(metric_path)
|
|
||||||
if is_exists(os.path.join(metric_path, f"{index}.json")):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
out = metric.forward(exp_threshold)
|
|
||||||
write_json(
|
|
||||||
{f"{metric.name}": out},
|
|
||||||
os.path.join(metric_path, f"{index}.json"),
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue