diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 9e56b90..b493b85 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -3,6 +3,17 @@ import itertools from typing import Any from eixgnn.eixgnn import EiXGNN +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader + from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -18,15 +29,11 @@ from explaining_framework.stats.graph.graph_stat import GraphStat from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo, _load_ckpt) from explaining_framework.utils.explanation.adjust import Adjust -from scgnn.scgnn import SCGNN -from torch_geometric.data import Batch, Data -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader +from explaining_framework.utils.explanation.io import ( + _get_explanation, _load_explanation, _save_explanation, + explanation_verification, get_pred) +from explaining_framework.utils.io import (is_exists, obj_config_to_str, + read_json, write_json, write_yaml) all__captum = [ "LRP", @@ -88,10 +95,15 @@ all_robust = [ ] all_sparsity = ["l0"] -adjust_pattern = 'ranp' -all_adjusts_filters = [''.join(filters) for i in range(len(adjust_pattern)+1)for filters in itertools.permutations(adjust_pattern,i)] +adjust_pattern = "ranp" +all_adjusts_filters = [ + "".join(filters) + for i in range(len(adjust_pattern) + 1) + for filters in itertools.permutations(adjust_pattern, i) +] + +all_threshold_type = ["topk_hard", "hard", "topk"] -all_threshold_type = ['topk_hard','hard','topk'] class ExplainingOutline(object): def __init__(self, explaining_cfg_path: str): @@ -131,6 +143,8 @@ class ExplainingOutline(object): self.load_threshold() self.load_graphstat() + seed_everything(self.explaining_cfg.seed) + def load_model_to_hardware(self): auto_select_device() device = self.cfg.accelerator @@ -300,44 +314,50 @@ class ExplainingOutline(object): if self.explaining_cfg is None: self.load_explaining_cfg() name = self.explaining_cfg.metrics.fidelity.name - if name == 'all': + if name == "all": all_metrics = [ Fidelity(name=name, model=self.model) for name in all_fidelity ] - elif isinstance(name,str): + elif isinstance(name, str): if name in all_fidelity: all_metrics = [Fidelity(name=name, model=self.model)] else: - raise ValueError(f'This fidelity metric {name} is nor supported yet. Supported are {all_fidelity}') - elif isinstance(name,list): - all_metrics = [Fidelity(name=name, model=self.model) for name_ in name if name_ in all_fidelity] + raise ValueError( + f"This fidelity metric {name} is nor supported yet. Supported are {all_fidelity}" + ) + elif isinstance(name, list): + all_metrics = [ + Fidelity(name=name, model=self.model) + for name_ in name + if name_ in all_fidelity + ] elif name is None: all_metrics = [] self.fidelities = all_metrics - + def load_sparsity(self): if self.cfg is None: self.load_cfg() if self.explaining_cfg is None: self.load_explaining_cfg() name = self.explaining_cfg.metrics.sparsity.name - if name == 'all': - all_metrics = [ - Sparsity(name=name) for name in all_sparsity - ] - elif isinstance(name,str): + if name == "all": + all_metrics = [Sparsity(name=name) for name in all_sparsity] + elif isinstance(name, str): if name in all_sparsity: all_metrics = [Sparsity(name=name)] else: - raise ValueError(f'This sparsity metric {name} is nor supported yet. Supported are {all_sparsity}') - elif isinstance(name,list): - all_metrics = [Sparsity(name=name) for name_ in name if name_ in all_sparsity] + raise ValueError( + f"This sparsity metric {name} is nor supported yet. Supported are {all_sparsity}" + ) + elif isinstance(name, list): + all_metrics = [ + Sparsity(name=name) for name_ in name if name_ in all_sparsity + ] elif name is None: all_metrics = [] self.sparsities = all_metrics - - def load_accuracy(self): if self.cfg is None: self.load_cfg() @@ -346,24 +366,26 @@ class ExplainingOutline(object): if self.explaining_cfg.dataset.name == "BASHAPES": name = self.explaining_cfg.metrics.accuracy.name - if name == 'all': - all_metrics = [ - Accuracy(name=name) for name in all_accuracy - ] - elif isinstance(name,str): + if name == "all": + all_metrics = [Accuracy(name=name) for name in all_accuracy] + elif isinstance(name, str): if name in all_accuracy: all_metrics = [Accuracy(name=name)] else: - raise ValueError(f'This accuracy metric {name} is nor supported yet. Supported are {all_accuracy}') - elif isinstance(name,list): - all_metrics = [Accuracy(name=name) for name_ in name if name_ in all_accuracy] + raise ValueError( + f"This accuracy metric {name} is nor supported yet. Supported are {all_accuracy}" + ) + elif isinstance(name, list): + all_metrics = [ + Accuracy(name=name) for name_ in name if name_ in all_accuracy + ] elif name is None: all_metrics = [] self.accuraties = all_metrics else: - raise ValueError(f'Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset') - - + raise ValueError( + f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset" + ) def load_metric(self): if self.cfg is None: @@ -376,27 +398,30 @@ class ExplainingOutline(object): self.load_sparsity() if self.fidelities is None: self.load_fidelity() - - self.metrics = self.fidelities+self.accuraties+self.sparsities - + self.metrics = self.fidelities + self.accuraties + self.sparsities + def load_attack(self): if self.cfg is None: self.load_cfg() if self.explaining_cfg is None: self.load_explaining_cfg() name = self.explaining_cfg.attack.name - if name == 'all': - all_metrics = [ - Attack(name=name,model=self.model) for name in all_robust - ] - elif isinstance(name,str): + if name == "all": + all_metrics = [Attack(name=name, model=self.model) for name in all_robust] + elif isinstance(name, str): if name in all_robust: - all_metrics = [Attack(name=name,model=self.model)] + all_metrics = [Attack(name=name, model=self.model)] else: - raise ValueError(f'This Attack metric {name} is not supported yet. Supported are {all_robust}') - elif isinstance(name,list): - all_metrics = [Attack(name=name,model=self.model) for name_ in name if name_ in all_robust] + raise ValueError( + f"This Attack metric {name} is not supported yet. Supported are {all_robust}" + ) + elif isinstance(name, list): + all_metrics = [ + Attack(name=name, model=self.model) + for name_ in name + if name_ in all_robust + ] elif name is None: all_metrics = [] self.attacks = all_metrics @@ -407,13 +432,17 @@ class ExplainingOutline(object): strategy = self.explaining_cfg.adjust.strategy if strategy == "all": self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters] - elif isinstance(name,str): + elif isinstance(name, str): if name in all_adjusts_filters: all_metrics = [Adjust(strategy=name)] else: - raise ValueError(f'This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}') - elif isinstance(name,list): - all_metrics = [Adjust(strategy=name_) for name_ in name if name_ in all_robust] + raise ValueError( + f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}" + ) + elif isinstance(name, list): + all_metrics = [ + Adjust(strategy=name_) for name_ in name if name_ in all_robust + ] elif name is None: all_metrics = [] self.adjusts = all_metrics @@ -421,70 +450,90 @@ class ExplainingOutline(object): def load_threshold(self): if self.explaining_cfg is None: self.load_explaining_cfg() - threshold_type =self.explaining_cfg.threshold_config.type - if threshold_type == 'all': - th_hard = [{"threshold_type": 'hard',"value": th_value} for th_value in self.explaining_cfg.threshold.value.hard] - th_topk = [{"threshold_type": th_type,"value": th_value} for th_value in self.explaining_cfg.threshold.value.topk f or th_type in all_threshold_type if 'topk' in th_type] + threshold_type = self.explaining_cfg.threshold_config.type + if threshold_type == "all": + th_hard = [ + {"threshold_type": "hard", "value": th_value} + for th_value in self.explaining_cfg.threshold.value.hard + ] + th_topk = [ + {"threshold_type": th_type, "value": th_value} + for th_value in self.explaining_cfg.threshold.value.topk + for th_type in all_threshold_type + if "topk" in th_type + ] all_threshold = th_hard + th_topk - elif isinstance(threshold_type,str): + elif isinstance(threshold_type, str): if threshold_type in all_threshold_type: - if 'topk' in threshold_type: - all_threshold = [{ - "threshold_type": threshold_type, - "value": threshold_value, - } for threshold_value in self.explaining_cfg.threshold.value.topk] - elif threshold_type == 'hard': - all_threshold = [{ - "threshold_type": threshold_type, - "value": threshold_value, - } for threshold_value in self.explaining_cfg.threshold.value.hard] - elif isinstance(threshold_type,list): + if "topk" in threshold_type: + all_threshold = [ + { + "threshold_type": threshold_type, + "value": threshold_value, + } + for threshold_value in self.explaining_cfg.threshold.value.topk + ] + elif threshold_type == "hard": + all_threshold = [ + { + "threshold_type": threshold_type, + "value": threshold_value, + } + for threshold_value in self.explaining_cfg.threshold.value.hard + ] + elif isinstance(threshold_type, list): all_threshold = [] for tf_type in threshold_type: - if 'topk' in th_type: - all_threshold.expend([{ + if "topk" in th_type: + all_threshold.expend( + [ + { "threshold_type": threshold_type, "value": threshold_value, - } for threshold_value in self.explaining_cfg.threshold.value.topk]) - elif th_type == 'hard': - all_threshold.expend([{ + } + for threshold_value in self.explaining_cfg.threshold.value.topk + ] + ) + elif th_type == "hard": + all_threshold.expend( + [ + { "threshold_type": threshold_type, "value": threshold_value, - } for threshold_value in self.explaining_cfg.threshold.value.hard]) + } + for threshold_value in self.explaining_cfg.threshold.value.hard + ] + ) elif threshold_type is None: all_threshold = [] self.thresholds_configs = all_threshold - def set_explainer_threshold_config(self,threshold_config): + def set_explainer_threshold_config(self, threshold_config): self.explainer.threshold_config = ThresholdConfig.cast(threshold_config) def load_graphstat(self): self.graphstat = GraphStat() - def get_explanation_(self,item:Data,path:str): + def get_explanation(self, item: Data, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: - explanation = get_explanation(self.explainer, item) + explanation = _get_explanation(self.explainer, item) else: - explanation = load_explanation(path) + explanation = _load_explanation(path) else: - explanation = get_explanation(explainer, item) - save_explanation(explanation,path) + explanation = _get_explanation(self.explainer, item) + _save_explanation(explanation, path) + explanation = explanation.to(cfg.accelerator) + return explanation - -class Explaining(object): - def __init__(self,outline:ExplainingOutline): - self.outline = outline - - def run(self): - pass - - def explain(self): - item, index = self.get_item() - not_none = item is None or index is None - whœ - - while - + def get_stat(self, item: Data, path: str): + if self.graphstat is None: + self.load_graphstat() + if is_exists(path): + pass + else: + if item.num_nodes <= 500: + stat = self.graphstat(item) + write_json(stat, path) diff --git a/explaining_framework/utils/explanation/io.py b/explaining_framework/utils/explanation/io.py index efd3a84..d532493 100644 --- a/explaining_framework/utils/explanation/io.py +++ b/explaining_framework/utils/explanation/io.py @@ -7,16 +7,18 @@ from torch_geometric.data import Data from torch_geometric.explain.explanation import Explanation -def get_explanation(explainer, item): +def _get_explanation(explainer, item): explanation = explainer( x=item.x, edge_index=item.edge_index, index=int(item.y), target=item.y, ) - # TODO return None if pas bien plutot - assert explanation_verification(explanation) - return explanation + if not explanation_verification(explanation): + # WARNING + LOG + return None + else: + return explanation def is_empty_graph(data: Data) -> bool: @@ -55,7 +57,7 @@ def explanation_verification(exp: Explanation) -> bool: return is_good -def save_explanation(exp: Explanation, path: str) -> None: +def _save_explanation(exp: Explanation, path: str) -> None: data = copy.copy(exp).to_dict() for k, v in data.items(): if isinstance(v, torch.Tensor): @@ -65,7 +67,7 @@ def save_explanation(exp: Explanation, path: str) -> None: json.dump(data, f) -def load_explanation(path: str) -> Explanation: +def _load_explanation(path: str) -> Explanation: with open(path, "r") as f: data = json.load(f) for k, v in data.items(): @@ -77,12 +79,3 @@ def load_explanation(path: str) -> Explanation: return Explanation.from_dict(data) -def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation: - exp = copy.copy(exp) - data = exp.to_dict() - for k, v in data.items(): - if "_mask" in k and isinstance(v, torch.FloatTensor): - norm = torch.norm(input=data[k], p=p, dim=None).item() - if norm.item() > 0: - data[k] = data[k] / norm - return exp diff --git a/main.py b/main.py index 75d9f71..ee5f984 100644 --- a/main.py +++ b/main.py @@ -27,99 +27,33 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str, # inference, time, force, -def get_pred(explainer, explanation): - pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[ - 0 - ] - setattr(explanation, "pred", pred) - data = explanation.to_dict() - if not data.get("node_mask") is None or not data.get("edge_mask") is None: - pred_masked = explainer.get_masked_prediction( - x=explanation.x, - edge_index=explanation.edge_index, - node_mask=data.get("node_mask"), - edge_mask=data.get("edge_mask"), - )[0] - setattr(explanation, "pred_exp", pred_masked) - - if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file) - auto_select_device() # Load components - dataset = outline.dataset - model = outline.model.to(cfg.accelerator) - model = model.eval() - model_info = outline.model_info - metrics = outline.metrics - explaining_algorithm = outline.explaining_algorithm - attacks = outline.attacks - explainer_cfg = outline.explainer_cfg - model_signature = outline.model_signature # RAJOUTER INDEXES - # Set seed - seed_everything(explaining_cfg.seed) - # Global path - global_path = os.path.join(explaining_cfg.out_dir, model_signature) + global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm)) makedirs(global_path) write_yaml(cfg, os.path.join(global_path, "config.yaml")) write_json(model_info, os.path.join(global_path, "info.json")) - # SET RUN DIR - global_path = os.path.join( - global_path, - explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm), - ) makedirs(global_path) - write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) - write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) - # SET EXPLAIN_DIR + write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) + write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) - global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm)) + global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm)) makedirs(global_path) # SET UP EXPLAINER - explainer = Explainer( - model=model, - algorithm=explaining_algorithm, - explainer_config=dict( - explanation_type=explaining_cfg.explanation_type, - node_mask_type="object", - edge_mask_type="object", - ), - model_config=dict( - mode="regression", - task_level=cfg.dataset.task, - return_type=explaining_cfg.model_config.return_type, - ), - ) - # CHERGER SUR LE GPU DIRECT - if not explaining_cfg.dataset.specific_items is None: - indexes = explaining_cfg.dataset.specific_items - else: - indexes = range(len(dataset)) # Save explaining configuration - for index, item in zip(indexes, dataset): - item = item.to(cfg.accelerator) - save_raw_path = os.path.join(global_path, "raw") - makedirs(save_raw_path) + item,index = outline.get_item() + while not(item is None or index is None): + raw_path = os.path.join(global_path, "raw") + makedirs(raw_path) explanation_path = os.path.join(save_raw_path, f"{index}.json") - if is_exists(explanation_path): - if explaining_cfg.explainer.force: - explanation = get_explanation(explainer, item) - else: - explanation = load_explanation(explanation_path) - else: - explanation = get_explanation(explainer, item) - - explanation = explanation.to(cfg.accelerator) - get_pred(explainer=explainer, explanation=explanation) - - save_explanation(explanation, explanation_path) for apply_relu in [True, False]: for apply_absolute in [True, False]: adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)