From 4fbc6c01c697177cbbde1b1119b7c64a68acd792 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Tue, 16 Apr 2024 13:40:55 +0200 Subject: [PATCH] Mask out low bits of input float arrays Float32 has 23 bits of precision, which is more than necessary for Vamb. We can mask out the lower bits, probably with no loss of performance. This helps reducing the disk space usage of Vamb's output --- vamb/encode.py | 1 + vamb/parsebam.py | 2 ++ vamb/parsecontigs.py | 1 + vamb/vambtools.py | 9 +++++++++ 4 files changed, 13 insertions(+) diff --git a/vamb/encode.py b/vamb/encode.py index 5c2e3c11..21554993 100644 --- a/vamb/encode.py +++ b/vamb/encode.py @@ -466,6 +466,7 @@ def encode(self, data_loader) -> _np.ndarray: row += len(mu) assert row == length + _vambtools.mask_lower_bits(latent, 12) return latent def save(self, filehandle): diff --git a/vamb/parsebam.py b/vamb/parsebam.py index afa113f2..bd8f28fb 100644 --- a/vamb/parsebam.py +++ b/vamb/parsebam.py @@ -129,6 +129,7 @@ def from_files( comp_metadata.identifiers if verify_refhash else None, comp_metadata.mask, ) + vambtools.mask_lower_bits(matrix, 12) return cls(matrix, [str(p) for p in paths], minid, refhash) # Else, we load it in chunks, then assemble afterwards else: @@ -183,6 +184,7 @@ def chunkwise_loading( matrix = _np.empty((mask.sum(), len(paths)), dtype=_np.float32) for filename, (chunkstart, chunkstop) in zip(filenames, chunks): matrix[:, chunkstart:chunkstop] = vambtools.read_npz(filename) + vambtools.mask_lower_bits(matrix, 12) shutil.rmtree(cache_directory) diff --git a/vamb/parsecontigs.py b/vamb/parsecontigs.py index 5d127c81..fffef9b5 100644 --- a/vamb/parsecontigs.py +++ b/vamb/parsecontigs.py @@ -207,6 +207,7 @@ def from_file( # Convert rest of contigs Composition._convert(raw, projected) tnfs_arr = projected.take() + _vambtools.mask_lower_bits(tnfs_arr, 12) # Don't use reshape since it creates a new array object with shared memory tnfs_arr.shape = (len(tnfs_arr) // 103, 103) diff --git 
def mask_lower_bits(floats: _np.ndarray, bits: int) -> None:
    """Zero the `bits` lowest bits of every element of a float32 array, in-place.

    The array's memory is reinterpreted as 32-bit unsigned integers and the
    low `bits` bits of each element are cleared. For float32 these are the
    least significant mantissa bits, so each value loses a little precision
    while the sign, the exponent and the leading mantissa bits are untouched.
    Arrays treated this way hold less entropy and therefore take up less
    disk space once compressed.

    Parameters
    ----------
    floats : numpy.ndarray
        Array of native float32 values. It is modified in-place.
    bits : int
        Number of low bits to clear, between 0 and 23 inclusive.
        Passing 0 leaves the array unchanged.

    Raises
    ------
    ValueError
        If `bits` is outside the range 0..23.
    TypeError
        If `floats` is not of dtype float32.
    """
    if bits < 0 or bits > 23:
        raise ValueError("Must mask between 0 and 23 bits")

    # Viewing any other dtype as uint32 would clear the wrong bits of each
    # value (e.g. bits in the upper half of a float64) and silently corrupt it.
    if floats.dtype != _np.float32:
        raise TypeError(
            f"Can only mask the lower bits of float32 arrays, not {floats.dtype}"
        )

    # All ones except for the `bits` lowest positions.
    mask = ~_np.uint32((1 << bits) - 1)

    # The view shares memory with `floats`, so the in-place AND updates the
    # caller's array directly.
    as_uints = floats.view(_np.uint32)
    as_uints &= mask