Adding new features
This commit is contained in:
parent
26fa51e2de
commit
e2d47af072
@ -159,6 +159,9 @@ class ExplainingOutline(object):
|
||||
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||
)
|
||||
self.dataset = create_dataset()
|
||||
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
||||
ind = self.explaining_cfg.dataset.specific_items
|
||||
self.dataset = self.dataset[ind : ind + 1]
|
||||
|
||||
def load_explainer(self):
|
||||
self.load_explainer_cfg()
|
||||
@ -201,11 +204,11 @@ class ExplainingOutline(object):
|
||||
self.load_explaining_cfg()
|
||||
|
||||
if self.explaining_cfg.metrics.type == "all":
|
||||
if self.explaining_cfg.dataset.name == 'BASHAPES':
|
||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||
|
||||
|
||||
def load_attack(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
|
Loading…
Reference in New Issue
Block a user