Skip to content

Commit

Permalink
Save VAEVAE latent as compressed npz
Browse files Browse the repository at this point in the history
Also mask the lower bits of the VAE's latent representation to save disk space.
  • Loading branch information
jakobnissen committed Jul 18, 2024
1 parent 8401be0 commit 7612296
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cli_vamb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ jobs:
cat outdir_taxometer/log.txt
- name: Run k-means reclustering
run: |
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000
ls -la outdir_recluster
cat outdir_recluster/log.txt
4 changes: 2 additions & 2 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,8 +1288,8 @@ def run_vaevae(
latent_both = vae.VAEJoint.encode(dataloader_joint)
logger.info(f"{latent_both.shape} embedding shape")

LATENT_PATH = vamb_options.out_dir.joinpath("vaevae_latent.npy")
np.save(LATENT_PATH, latent_both)
latent_path = vamb_options.out_dir.joinpath("vaevae_latent.npz")
vamb.vambtools.write_npz(latent_path, latent_both)

# Cluster, save tsv file
cluster_and_write_files(
Expand Down
3 changes: 3 additions & 0 deletions vamb/semisupervised_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.utils.data import DataLoader as _DataLoader
from torch.utils.data.dataset import TensorDataset as _TensorDataset
import vamb.encode as _encode
from vamb.vambtools import mask_lower_bits

if _torch.__version__ < "0.4":
raise ImportError("PyTorch version must be 0.4 or newer")
Expand Down Expand Up @@ -354,6 +355,7 @@ def encode(self, data_loader):
row += len(mu)

assert row == length
mask_lower_bits(latent, 12)
return latent

def trainmodel(
Expand Down Expand Up @@ -700,6 +702,7 @@ def encode(self, data_loader):
row += len(mu)

assert row == length
mask_lower_bits(latent, 12)
return latent


Expand Down

0 comments on commit 7612296

Please sign in to comment.