From dd0406aaf469907186c8262eb42a8be2eceb0b5c Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Thu, 24 Nov 2022 11:02:48 +0100 Subject: [PATCH] Fix GPU clustering A variable which did not support GPU operations was errantly on GPU when cuda=True. Move to CPU. --- vamb/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vamb/cluster.py b/vamb/cluster.py index b06f5fb0..9ea8e560 100644 --- a/vamb/cluster.py +++ b/vamb/cluster.py @@ -349,7 +349,7 @@ def _smaller_indices(tensor, kept_mask, threshold, cuda): # If it's on GPU, we remove the already clustered points at this step. if cuda: - return _torch.nonzero((tensor <= threshold) & kept_mask).flatten() + return _torch.nonzero((tensor <= threshold) & kept_mask).flatten().cpu() else: arr = tensor.numpy() indices = (arr <= threshold).nonzero()[0]