diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py
index 8fb333b..9e56b90 100644
--- a/explaining_framework/utils/explaining/outline.py
+++ b/explaining_framework/utils/explaining/outline.py
@@ -1,16 +1,8 @@
 import copy
+import itertools
 from typing import Any
 
 from eixgnn.eixgnn import EiXGNN
-from scgnn.scgnn import SCGNN
-from torch_geometric.data import Batch, Data
-from torch_geometric.explain import Explainer
-from torch_geometric.graphgym.config import cfg
-from torch_geometric.graphgym.loader import create_dataset
-from torch_geometric.graphgym.model_builder import cfg, create_model
-from torch_geometric.graphgym.utils.device import auto_select_device
-from torch_geometric.loader.dataloader import DataLoader
-
 from explaining_framework.config.explainer_config.eixgnn_config import \
     eixgnn_cfg
 from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@@ -22,8 +14,19 @@ from explaining_framework.metric.accuracy import Accuracy
 from explaining_framework.metric.fidelity import Fidelity
 from explaining_framework.metric.robust import Attack
 from explaining_framework.metric.sparsity import Sparsity
+from explaining_framework.stats.graph.graph_stat import GraphStat
 from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
                                                              _load_ckpt)
+from explaining_framework.utils.explanation.adjust import Adjust
+from explaining_framework.utils.explanation.io import (get_explanation,
+                                                       load_explanation,
+                                                       save_explanation)
+from explaining_framework.utils.io import is_exists
+from scgnn.scgnn import SCGNN
+from torch_geometric.data import Batch, Data
+from torch_geometric.explain import Explainer
+from torch_geometric.explain.config import ThresholdConfig
+from torch_geometric.graphgym.config import cfg
+from torch_geometric.graphgym.loader import create_dataset
+from torch_geometric.graphgym.model_builder import cfg, create_model
+from torch_geometric.graphgym.utils.device import auto_select_device
+from torch_geometric.loader.dataloader import DataLoader
 
 all__captum = [
     "LRP",
@@ -85,6 +88,10 @@ all_robust = [
 ]
 all_sparsity = ["l0"]
 
+adjust_pattern = "ranp"
+all_adjusts_filters = ["".join(filters) for i in range(len(adjust_pattern) + 1) for filters in itertools.permutations(adjust_pattern, i)]
+
+all_threshold_type = ["topk_hard", "hard", "topk"]
 
 class ExplainingOutline(object):
     def __init__(self, explaining_cfg_path: str):
@@ -100,17 +107,65 @@ class ExplainingOutline(object):
         self.metrics = None
         self.attacks = None
         self.model_signature = None
+        self.indexes = None
+        self.explaining_algorithm = None
+        self.explainer = None
+        self.fidelities = None
+        self.accuracies = None
+        self.sparsities = None
+        self.adjusts = None
+        self.thresholds_configs = None
+        self.graphstat = None
 
         self.load_explaining_cfg()
         self.load_model_info()
         self.load_cfg()
         self.load_dataset()
         self.load_model()
+        self.load_model_to_hardware()
         self.load_explainer_cfg()
+        self.load_explaining_algorithm()
         self.load_explainer()
         self.load_metric()
         self.load_attack()
         self.load_dataset_to_dataloader()
+        self.load_indexes()
+        self.load_adjust()
+        self.load_threshold()
+        self.load_graphstat()
+
+    def load_model_to_hardware(self):
+        auto_select_device()
+        device = self.cfg.accelerator
+        self.model = self.model.to(device)
+
+    def get_data(self):
+        if self.dataset is None:
+            self.load_dataset()
+        # The DataLoader is not itself an iterator, so keep a cached iterator over it.
+        if getattr(self, "dataset_iter", None) is None:
+            self.dataset_iter = iter(self.dataset)
+        try:
+            item = next(self.dataset_iter)
+            item = item.to(cfg.accelerator)
+            return item
+        except StopIteration:
+            return None
+
+    def load_indexes(self):
+        if self.explaining_cfg.dataset.specific_items is not None:
+            indexes = self.explaining_cfg.dataset.specific_items
+        else:
+            indexes = list(range(len(self.dataset)))
+        self.indexes = iter(indexes)
+
+    def get_index(self):
+        if self.indexes is None:
+            self.load_indexes()
+        try:
+            item = next(self.indexes)
+            return item
+        except StopIteration:
+            return None
+
+    def get_item(self):
+        item = self.get_data()
+        index = self.get_index()
+        return item, index
 
     def load_model_info(self):
         info = LoadModelInfo(
@@ -160,6 +215,7 @@ class ExplainingOutline(object):
         self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
         if self.model is None:
             raise ValueError("Model ckpt has not been loaded, ckpt file not found")
+        self.model = self.model.eval()
 
     def load_dataset(self):
         if self.cfg is None:
@@ -181,7 +237,7 @@ class ExplainingOutline(object):
     def load_dataset_to_dataloader(self):
         self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
 
-    def load_explainer(self):
+    def load_explaining_algorithm(self):
        self.load_explainer_cfg()
        if self.model is None:
            self.load_model()
@@ -219,54 +275,216 @@ class ExplainingOutline(object):
             raise ValueError(f"{name_} Metric is not supported yet")
         self.explaining_algorithm = explaining_algorithm
 
+    def load_explainer(self):
+        if self.explaining_algorithm is None:
+            self.load_explaining_algorithm()
+        explainer = Explainer(
+            model=self.model,
+            algorithm=self.explaining_algorithm,
+            explainer_config=dict(
+                explanation_type=self.explaining_cfg.explanation_type,
+                node_mask_type="object",
+                edge_mask_type="object",
+            ),
+            model_config=dict(
+                mode="regression",
+                task_level=self.cfg.dataset.task,
+                return_type=self.explaining_cfg.model_config.return_type,
+            ),
+        )
+        self.explainer = explainer
+
+    def load_fidelity(self):
+        if self.cfg is None:
+            self.load_cfg()
+        if self.explaining_cfg is None:
+            self.load_explaining_cfg()
+        name = self.explaining_cfg.metrics.fidelity.name
+        if name == "all":
+            all_metrics = [
+                Fidelity(name=name, model=self.model) for name in all_fidelity
+            ]
+        elif isinstance(name, str):
+            if name in all_fidelity:
+                all_metrics = [Fidelity(name=name, model=self.model)]
+            else:
+                raise ValueError(f"This fidelity metric {name} is not supported yet. Supported are {all_fidelity}")
+        elif isinstance(name, list):
+            all_metrics = [Fidelity(name=name_, model=self.model) for name_ in name if name_ in all_fidelity]
+        elif name is None:
+            all_metrics = []
+        self.fidelities = all_metrics
+
+    def load_sparsity(self):
+        if self.cfg is None:
+            self.load_cfg()
+        if self.explaining_cfg is None:
+            self.load_explaining_cfg()
+        name = self.explaining_cfg.metrics.sparsity.name
+        if name == "all":
+            all_metrics = [
+                Sparsity(name=name) for name in all_sparsity
+            ]
+        elif isinstance(name, str):
+            if name in all_sparsity:
+                all_metrics = [Sparsity(name=name)]
+            else:
+                raise ValueError(f"This sparsity metric {name} is not supported yet. Supported are {all_sparsity}")
+        elif isinstance(name, list):
+            all_metrics = [Sparsity(name=name_) for name_ in name if name_ in all_sparsity]
+        elif name is None:
+            all_metrics = []
+        self.sparsities = all_metrics
+
+    def load_accuracy(self):
+        if self.cfg is None:
+            self.load_cfg()
+        if self.explaining_cfg is None:
+            self.load_explaining_cfg()
+
+        name = self.explaining_cfg.metrics.accuracy.name
+        if self.explaining_cfg.dataset.name == "BASHAPES":
+            if name == "all":
+                all_metrics = [
+                    Accuracy(name=name) for name in all_accuracy
+                ]
+            elif isinstance(name, str):
+                if name in all_accuracy:
+                    all_metrics = [Accuracy(name=name)]
+                else:
+                    raise ValueError(f"This accuracy metric {name} is not supported yet. Supported are {all_accuracy}")
+            elif isinstance(name, list):
+                all_metrics = [Accuracy(name=name_) for name_ in name if name_ in all_accuracy]
+            elif name is None:
+                all_metrics = []
+            self.accuracies = all_metrics
+        elif name is None:
+            # No accuracy metric requested, nothing to load for this dataset.
+            self.accuracies = []
+        else:
+            raise ValueError("The provided dataset needs explanation ground truths to use accuracy metrics, e.g. the BASHAPES dataset")
+
+
     def load_metric(self):
         if self.cfg is None:
             self.load_cfg()
         if self.explaining_cfg is None:
             self.load_explaining_cfg()
+        if self.accuracies is None:
+            self.load_accuracy()
+        if self.sparsities is None:
+            self.load_sparsity()
+        if self.fidelities is None:
+            self.load_fidelity()
+
+        self.metrics = self.fidelities + self.accuracies + self.sparsities
-        name_ = self.explaining_cfg.metrics.name
-
-        if name_ == "all":
-            all_fid_metrics = [
-                Fidelity(name=name, model=self.model) for name in all_fidelity
-            ]
-            all_spa_metrics = [Sparsity(name) for name in all_sparsity]
-            self.metrics = all_spa_metrics + all_fid_metrics
-
-            if self.explaining_cfg.dataset.name == "BASHAPES":
-                all_acc_metrics = [Accuracy(name) for name in all_accuracy]
-                self.metrics = self.metrics + all_acc_metrics
-        elif name_ in all_fidelity:
-            self.metrics = [Fidelity(name=name_, model=self.model)]
-        elif name_ in all_sparsity:
-            self.metrics = [Sparsity(name_)]
-        elif name_ in all_accuracy:
-            if self.explaining_cfg.dataset.name == "BASHAPES":
-                self.metrics = [Accuracy(name_)]
-            else:
-                raise ValueError(
-                    f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
-                )
-        elif name_ is None:
-            self.metrics = []
-        else:
-            raise ValueError(f"{name_} Metric is not supported yet")
-
+
     def load_attack(self):
         if self.cfg is None:
             self.load_cfg()
         if self.explaining_cfg is None:
             self.load_explaining_cfg()
-        name_ = self.explaining_cfg.attack.name
-        if name_ == "all":
-            all_rob_metrics = [
-                Attack(name=name, model=self.model) for name in all_robust
+        name = self.explaining_cfg.attack.name
+        if name == "all":
+            all_metrics = [
+                Attack(name=name, model=self.model) for name in all_robust
             ]
-            self.attacks = all_rob_metrics
-        elif name_ in all_robust:
-            self.attacks = [Attack(name=name_, model=self.model)]
-        elif name_ is None:
-            self.attacks = []
+        elif isinstance(name, str):
+            if name in all_robust:
+                all_metrics = [Attack(name=name, model=self.model)]
+            else:
+                raise ValueError(f"This Attack metric {name} is not supported yet. Supported are {all_robust}")
+        elif isinstance(name, list):
+            all_metrics = [Attack(name=name_, model=self.model) for name_ in name if name_ in all_robust]
+        elif name is None:
+            all_metrics = []
+        self.attacks = all_metrics
+
+    def load_adjust(self):
+        if self.explaining_cfg is None:
+            self.load_explaining_cfg()
+        strategy = self.explaining_cfg.adjust.strategy
+        if strategy == "all":
+            all_adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
+        elif isinstance(strategy, str):
+            if strategy in all_adjusts_filters:
+                all_adjusts = [Adjust(strategy=strategy)]
+            else:
+                raise ValueError(f"This Adjust strategy {strategy} is not supported yet. Supported are {all_adjusts_filters}")
+        elif isinstance(strategy, list):
+            all_adjusts = [Adjust(strategy=strat) for strat in strategy if strat in all_adjusts_filters]
+        elif strategy is None:
+            all_adjusts = []
+        self.adjusts = all_adjusts
+
+    def load_threshold(self):
+        if self.explaining_cfg is None:
+            self.load_explaining_cfg()
+        threshold_type = self.explaining_cfg.threshold_config.type
+        if threshold_type == "all":
+            th_hard = [{"threshold_type": "hard", "value": th_value} for th_value in self.explaining_cfg.threshold.value.hard]
+            th_topk = [{"threshold_type": th_type, "value": th_value} for th_value in self.explaining_cfg.threshold.value.topk for th_type in all_threshold_type if "topk" in th_type]
+            all_threshold = th_hard + th_topk
+        elif isinstance(threshold_type, str):
+            if threshold_type in all_threshold_type:
+                if "topk" in threshold_type:
+                    all_threshold = [{
+                        "threshold_type": threshold_type,
+                        "value": threshold_value,
+                    } for threshold_value in self.explaining_cfg.threshold.value.topk]
+                elif threshold_type == "hard":
+                    all_threshold = [{
+                        "threshold_type": threshold_type,
+                        "value": threshold_value,
+                    } for threshold_value in self.explaining_cfg.threshold.value.hard]
+            else:
+                raise ValueError(f"This threshold type {threshold_type} is not supported yet. Supported are {all_threshold_type}")
+        elif isinstance(threshold_type, list):
+            all_threshold = []
+            for th_type in threshold_type:
+                if "topk" in th_type:
+                    all_threshold.extend([{
+                        "threshold_type": th_type,
+                        "value": threshold_value,
+                    } for threshold_value in self.explaining_cfg.threshold.value.topk])
+                elif th_type == "hard":
+                    all_threshold.extend([{
+                        "threshold_type": th_type,
+                        "value": threshold_value,
+                    } for threshold_value in self.explaining_cfg.threshold.value.hard])
+
+        elif threshold_type is None:
+            all_threshold = []
+        self.thresholds_configs = all_threshold
+
+    def set_explainer_threshold_config(self, threshold_config):
+        self.explainer.threshold_config = ThresholdConfig.cast(threshold_config)
+
+    def load_graphstat(self):
+        self.graphstat = GraphStat()
+
+    def get_explanation_(self, item: Data, path: str):
+        if is_exists(path):
+            if self.explaining_cfg.explainer.force:
+                explanation = get_explanation(self.explainer, item)
+            else:
+                explanation = load_explanation(path)
         else:
-            raise ValueError(f"{name_} is an Attack method that is not supported yet")
+            explanation = get_explanation(self.explainer, item)
+            save_explanation(explanation, path)
+        return explanation
+
+
+class Explaining(object):
+    def __init__(self, outline: ExplainingOutline):
+        self.outline = outline
+
+    def run(self):
+        pass
+
+    def explain(self):
+        # WIP: iterate over (item, index) pairs until the dataloader is exhausted.
+        item, index = self.outline.get_item()
+        while item is not None and index is not None:
+            item, index = self.outline.get_item()
+
diff --git a/explaining_framework/utils/explanation/adjust.py b/explaining_framework/utils/explanation/adjust.py
index 726e65e..b012811 100644
--- a/explaining_framework/utils/explanation/adjust.py
+++ b/explaining_framework/utils/explanation/adjust.py
@@ -9,37 +9,29 @@ from torch_geometric.explain.explanation import Explanation
 class Adjust(object):
     def __init__(
         self,
-        apply_relu: bool = True,
-        apply_normalize: bool = True,
-        apply_project: bool = True,
-        apply_absolute: bool = False,
+        strategy: str = "rpn",
     ):
-        self.apply_relu = apply_relu
-        self.apply_normalize = apply_normalize
-        self.apply_project = apply_project
-        self.apply_absolute = apply_absolute
-
-        if self.apply_absolute and self.apply_relu:
-            self.apply_relu = False
+        self.strategy = strategy
 
     def forward(self, exp: Explanation) -> Explanation:
         exp_ = copy.copy(exp)
         _store = exp_.to_dict()
         for k, v in _store.items():
             if "mask" in k:
-                if self.apply_relu:
-                    _store[k] = self.relu(v)
-                elif self.apply_absolute:
-                    _store[k] = self.absolute(v)
-                elif self.apply_project:
-                    if "edge" in k:
-                        pass
-                    else:
-                        _store[k] = self.project(v)
-                elif self.apply_normalize:
-                    _store[k] = self.normalize(v)
-                else:
-                    continue
+                for f_ in self.strategy:
+                    if f_ == "r":
+                        _store[k] = self.relu(v)
+                    elif f_ == "a":
+                        _store[k] = self.absolute(v)
+                    elif f_ == "p":
+                        if "edge" in k:
+                            # Do not project edge masks.
+                            pass
+                        else:
+                            _store[k] = self.project(v)
+                    elif f_ == "n":
+                        _store[k] = self.normalize(v)
         return exp_
diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py
index ab66334..efd3a84 100644
--- a/explaining_framework/utils/explanation/io.py
+++ b/explaining_framework/utils/explanation/io.py
@@ -7,6 +7,38 @@ from torch_geometric.data import Data
 from torch_geometric.explain.explanation import Explanation
 
 
+def get_explanation(explainer, item):
+    explanation = explainer(
+        x=item.x,
+        edge_index=item.edge_index,
+        index=int(item.y),
+        target=item.y,
+    )
+    # TODO: return None instead if the explanation is not valid
+    assert explanation_verification(explanation)
+    return explanation
+
+
+def is_empty_graph(data: Data) -> bool:
+    return data.x.shape[0] == 0
+
+
+def get_pred(explainer, explanation):
+    pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
+        0
+    ]
+    setattr(explanation, "pred", pred)
+    data = explanation.to_dict()
+    if data.get("node_mask") is not None or data.get("edge_mask") is not None:
+        pred_masked = explainer.get_masked_prediction(
+            x=explanation.x,
+            edge_index=explanation.edge_index,
+            node_mask=data.get("node_mask"),
+            edge_mask=data.get("edge_mask"),
+        )[0]
+        setattr(explanation, "pred_exp", pred_masked)
+
+
 def explanation_verification(exp: Explanation) -> bool:
     is_good = True
     masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
@@ -53,5 +85,4 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
         norm = torch.norm(input=data[k], p=p, dim=None).item()
         if norm.item() > 0:
             data[k] = data[k] / norm
-
     return exp
diff --git a/main.py b/main.py
index 99be02b..75d9f71 100644
--- a/main.py
+++ b/main.py
@@ -19,7 +19,8 @@ from explaining_framework.utils.explaining.cmd_args import parse_args
 from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
 from explaining_framework.utils.explanation.io import (
-    explanation_verification, load_explanation, save_explanation)
+    explanation_verification, get_explanation, get_pred, load_explanation,
+    save_explanation)
 from explaining_framework.utils.io import (is_exists, obj_config_to_str,
                                            read_json, write_json, write_yaml)
 
@@ -42,17 +43,6 @@ def get_pred(explainer, explanation):
         setattr(explanation, "pred_exp", pred_masked)
 
 
-def get_explanation(explainer, item):
-    explanation = explainer(
-        x=item.x,
-        edge_index=item.edge_index,
-        index=int(item.y),
-        target=item.y,
-    )
-    assert explanation_verification(explanation)
-    return explanation
-
-
 if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file)
@@ -68,6 +58,7 @@ if __name__ == "__main__":
     attacks = outline.attacks
     explainer_cfg = outline.explainer_cfg
     model_signature = outline.model_signature
+    # TODO: ADD INDEXES
 
     # Set seed
     seed_everything(explaining_cfg.seed)
@@ -77,6 +68,7 @@ if __name__ == "__main__":
     makedirs(global_path)
     write_yaml(cfg, os.path.join(global_path, "config.yaml"))
     write_json(model_info, os.path.join(global_path, "info.json"))
 
+    # SET RUN DIR
     global_path = os.path.join(
         global_path,
@@ -85,9 +77,11 @@ if __name__ == "__main__":
     makedirs(global_path)
     write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
     write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
+    # SET EXPLAIN_DIR
     global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
     makedirs(global_path)
 
+    # SET UP EXPLAINER
     explainer = Explainer(
         model=model,
         algorithm=explaining_algorithm,
@@ -102,6 +96,7 @@ if __name__ == "__main__":
             return_type=explaining_cfg.model_config.return_type,
         ),
     )
+    # LOAD ONTO THE GPU DIRECTLY
     if not explaining_cfg.dataset.specific_items is None:
         indexes = explaining_cfg.dataset.specific_items
     else: