Skip to content

Commit

Permalink
Mask out low bits of input float arrays
Browse files Browse the repository at this point in the history
Float32 has 23 bits of precision, which is more than necessary for Vamb.
We can mask out the lower bits, probably with no loss of performance.
This helps reduce the disk space usage of Vamb's output.
  • Loading branch information
jakobnissen committed Apr 16, 2024
1 parent 8a9d061 commit 4fbc6c0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 0 deletions.
1 change: 1 addition & 0 deletions vamb/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def encode(self, data_loader) -> _np.ndarray:
row += len(mu)

assert row == length
_vambtools.mask_lower_bits(latent, 12)
return latent

def save(self, filehandle):
Expand Down
2 changes: 2 additions & 0 deletions vamb/parsebam.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def from_files(
comp_metadata.identifiers if verify_refhash else None,
comp_metadata.mask,
)
vambtools.mask_lower_bits(matrix, 12)
return cls(matrix, [str(p) for p in paths], minid, refhash)
# Else, we load it in chunks, then assemble afterwards
else:
Expand Down Expand Up @@ -183,6 +184,7 @@ def chunkwise_loading(
matrix = _np.empty((mask.sum(), len(paths)), dtype=_np.float32)
for filename, (chunkstart, chunkstop) in zip(filenames, chunks):
matrix[:, chunkstart:chunkstop] = vambtools.read_npz(filename)
vambtools.mask_lower_bits(matrix, 12)

shutil.rmtree(cache_directory)

Expand Down
1 change: 1 addition & 0 deletions vamb/parsecontigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def from_file(
# Convert rest of contigs
Composition._convert(raw, projected)
tnfs_arr = projected.take()
_vambtools.mask_lower_bits(tnfs_arr, 12)

# Don't use reshape since it creates a new array object with shared memory
tnfs_arr.shape = (len(tnfs_arr) // 103, 103)
Expand Down
9 changes: 9 additions & 0 deletions vamb/vambtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,15 @@ def torch_inplace_maskarray(array, mask):
return array


def mask_lower_bits(floats: _np.ndarray, bits: int) -> None:
    """Zero the `bits` least significant bits of each float32 in `floats`, in place.

    The lowest 23 bits of an IEEE 754 single-precision float hold the stored
    mantissa, so clearing some of them discards only low-order precision
    while leaving the sign and exponent untouched.

    Args:
        floats: Array of native float32 values. It is modified in place.
        bits: Number of low bits to clear in every element, from 0 to 23
            inclusive. Passing 0 leaves the array unchanged.

    Raises:
        ValueError: If `bits` is outside 0..23, or if `floats` does not have
            dtype float32.
    """
    if bits < 0 or bits > 23:
        raise ValueError("Must mask between 0 and 23 bits")

    # The uint32 reinterpretation below is only meaningful for native float32.
    # With e.g. float64 it would clear bits in both 32-bit halves of every
    # value and silently corrupt the middle of the mantissa.
    if floats.dtype != _np.float32:
        raise ValueError(
            f"Can only mask bits of float32 arrays, not dtype {floats.dtype}"
        )

    mask = ~_np.uint32(2**bits - 1)
    # `view` shares memory with `floats`, so the in-place AND edits the input.
    u = floats.view(_np.uint32)
    u &= mask


class Reader:
"""Use this instead of `open` to open files which are either plain text,
gzipped, bzip2'd or zipped with LZMA.
Expand Down

0 comments on commit 4fbc6c0

Please sign in to comment.