Skip to content

Commit

Permalink
try fix preds gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Nov 30, 2023
1 parent 8432792 commit 77312d1
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def __init__(self, output_dir, write_interval):
def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
# this will create N (num processes) files in `output_dir` each containing
# the predictions of its respective rank
predictions = [(x.cpu(), y.cpu()) for (x, y) in predictions]
print("len(predictions):", len(predictions))
print("predictions[0][0].shape:", predictions[0][0].shape)
torch.save(
predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt")
)
Expand Down

0 comments on commit 77312d1

Please sign in to comment.