Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask out low bits of input float arrays #301

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vamb/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def encode(self, data_loader) -> _np.ndarray:
row += len(mu)

assert row == length
_vambtools.mask_lower_bits(latent, 12)
return latent

def save(self, filehandle):
Expand Down
2 changes: 2 additions & 0 deletions vamb/parsebam.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def from_files(
comp_metadata.identifiers if verify_refhash else None,
comp_metadata.mask,
)
vambtools.mask_lower_bits(matrix, 12)
return cls(matrix, [str(p) for p in paths], minid, refhash)
# Else, we load it in chunks, then assemble afterwards
else:
Expand Down Expand Up @@ -183,6 +184,7 @@ def chunkwise_loading(
matrix = _np.empty((mask.sum(), len(paths)), dtype=_np.float32)
for filename, (chunkstart, chunkstop) in zip(filenames, chunks):
matrix[:, chunkstart:chunkstop] = vambtools.read_npz(filename)
vambtools.mask_lower_bits(matrix, 12)

shutil.rmtree(cache_directory)

Expand Down
1 change: 1 addition & 0 deletions vamb/parsecontigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def from_file(
# Convert rest of contigs
Composition._convert(raw, projected)
tnfs_arr = projected.take()
_vambtools.mask_lower_bits(tnfs_arr, 12)

# Don't use reshape since it creates a new array object with shared memory
tnfs_arr.shape = (len(tnfs_arr) // 103, 103)
Expand Down
9 changes: 9 additions & 0 deletions vamb/vambtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,15 @@ def torch_inplace_maskarray(array, mask):
return array


def mask_lower_bits(floats: _np.ndarray, bits: int) -> None:
    """Zero the `bits` least significant bits of every float32 in `floats`, in-place.

    The array's memory is reinterpreted as 32-bit unsigned integers and the
    low bits of each element are cleared with a bitwise AND. For float32
    these are the low bits of the 23-bit stored mantissa, so at most 23 bits
    may be masked; the sign and exponent fields are never touched.

    Parameters
    ----------
    floats
        Array of native-endian float32 values. It is modified in-place.
    bits
        Number of low bits to clear per element, between 0 and 23 inclusive.
        With 0 the array is left unchanged.

    Raises
    ------
    ValueError
        If `bits` is outside the range 0..23.
    TypeError
        If `floats` is not of dtype float32. Reinterpreting any other dtype
        as uint32 would mask the wrong bits (e.g. for float64 the upper
        mantissa bits) and silently corrupt the values.
    """
    if bits < 0 or bits > 23:
        raise ValueError("Must mask between 0 and 23 bits")

    # Validate before touching the buffer so a rejected array is left intact.
    if floats.dtype != _np.float32:
        raise TypeError(
            f"Can only mask lower bits of float32 arrays, not {floats.dtype}"
        )

    # All ones except for the lowest `bits` bits.
    mask = ~_np.uint32(2**bits - 1)
    # The view shares memory with `floats`, so the in-place AND mutates it.
    u = floats.view(_np.uint32)
    u &= mask


class Reader:
"""Use this instead of `open` to open files which are either plain text,
gzipped, bzip2'd or zipped with LZMA.
Expand Down
Loading