From db04fbfaeb18e8e9ce266eeb3b41cb7c0f6cf83f Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 11 Jan 2023 16:56:11 +0100 Subject: [PATCH] Adding suitable iterator --- main.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index a8be4a9..00957d1 100644 --- a/main.py +++ b/main.py @@ -29,8 +29,8 @@ if __name__ == "__main__": pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) - item, index = outline.get_item() - while not (item is None or index is None): + for item, index in zip(outline.dataset, outline.indexes): + item = item.to(outline.cfg.accelerator) for attack in outline.attacks: attack_path = os.path.join( outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack) @@ -43,11 +43,9 @@ if __name__ == "__main__": if data_attack is None: continue - item, index = outline.get_item() - outline.reload_dataloader() - item, index = outline.get_item() - while not (item is None or index is None): + for item, index in zip(outline.dataset, outline.indexes): + item = item.to(outline.cfg.accelerator) for attack in outline.attacks: attack_path_ = os.path.join( outline.explainer_path, @@ -103,9 +101,6 @@ if __name__ == "__main__": out_metric = outline.get_metric( metric=metric, item=exp_masked, path=metric_path ) - - item, index = outline.get_item() - with open(os.path.join(outline.out_dir, "done"), "w") as f: f.write("")