diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index af70e19..45d502d 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -159,6 +159,9 @@ class ExplainingOutline(object): f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." ) self.dataset = create_dataset() + if isinstance(self.explaining_cfg.dataset.specific_items, int): + ind = self.explaining_cfg.dataset.specific_items + self.dataset = self.dataset[ind : ind + 1] def load_explainer(self): self.load_explainer_cfg() @@ -201,11 +204,11 @@ class ExplainingOutline(object): self.load_explaining_cfg() if self.explaining_cfg.metrics.type == "all": - if self.explaining_cfg.dataset.name == 'BASHAPES': + if self.explaining_cfg.dataset.name == "BASHAPES": all_acc_metrics = [Accuracy(name) for name in all_accuracy] all_fid_metrics = [Fidelity(name) for name in all_fidelity] all_spa_metrics = [Sparsity(name) for name in all_sparsity] - + def load_attack(self): if self.cfg is None: self.load_cfg()