diff --git a/main.py b/main.py index a8be4a9..00957d1 100644 --- a/main.py +++ b/main.py @@ -29,8 +29,8 @@ if __name__ == "__main__": pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) - item, index = outline.get_item() - while not (item is None or index is None): + for item, index in zip(outline.dataset, outline.indexes): + item = item.to(outline.cfg.accelerator) for attack in outline.attacks: attack_path = os.path.join( outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack) @@ -43,11 +43,9 @@ if __name__ == "__main__": if data_attack is None: continue - item, index = outline.get_item() - outline.reload_dataloader() - item, index = outline.get_item() - while not (item is None or index is None): + for item, index in zip(outline.dataset, outline.indexes): + item = item.to(outline.cfg.accelerator) for attack in outline.attacks: attack_path_ = os.path.join( outline.explainer_path, @@ -103,9 +101,6 @@ if __name__ == "__main__": out_metric = outline.get_metric( metric=metric, item=exp_masked, path=metric_path ) - - item, index = outline.get_item() - with open(os.path.join(outline.out_dir, "done"), "w") as f: f.write("")