Adding suitable iterator

This commit is contained in:
araison 2023-01-11 16:56:11 +01:00
parent bf559657d8
commit db04fbfaeb
1 changed files with 4 additions and 9 deletions

13
main.py
View File

@ -29,8 +29,8 @@ if __name__ == "__main__":
pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
item, index = outline.get_item() for item, index in zip(outline.dataset, outline.indexes):
while not (item is None or index is None): item = item.to(outline.cfg.accelerator)
for attack in outline.attacks: for attack in outline.attacks:
attack_path = os.path.join( attack_path = os.path.join(
outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack) outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack)
@ -43,11 +43,9 @@ if __name__ == "__main__":
if data_attack is None: if data_attack is None:
continue continue
item, index = outline.get_item()
outline.reload_dataloader() outline.reload_dataloader()
item, index = outline.get_item() for item, index in zip(outline.dataset, outline.indexes):
while not (item is None or index is None): item = item.to(outline.cfg.accelerator)
for attack in outline.attacks: for attack in outline.attacks:
attack_path_ = os.path.join( attack_path_ = os.path.join(
outline.explainer_path, outline.explainer_path,
@ -103,9 +101,6 @@ if __name__ == "__main__":
out_metric = outline.get_metric( out_metric = outline.get_metric(
metric=metric, item=exp_masked, path=metric_path metric=metric, item=exp_masked, path=metric_path
) )
item, index = outline.get_item()
with open(os.path.join(outline.out_dir, "done"), "w") as f: with open(os.path.join(outline.out_dir, "done"), "w") as f:
f.write("") f.write("")