From 7612296536053326c259eab11b7d306bb761fcac Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 17 Jul 2024 08:03:05 +0200 Subject: [PATCH] Save VAEVAE latent as compressed npz Also mask lower bits of VAE's latent, to save disk space. --- .github/workflows/cli_vamb.yml | 2 +- vamb/__main__.py | 4 ++-- vamb/semisupervised_encode.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cli_vamb.yml b/.github/workflows/cli_vamb.yml index 12146499..1f74ae2a 100644 --- a/.github/workflows/cli_vamb.yml +++ b/.github/workflows/cli_vamb.yml @@ -61,6 +61,6 @@ jobs: cat outdir_taxometer/log.txt - name: Run k-means reclustering run: | - vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npy --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000 + vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --hmmout_path markers_mock.hmmout --algorithm kmeans --minfasta 200000 ls -la outdir_recluster cat outdir_recluster/log.txt diff --git a/vamb/__main__.py b/vamb/__main__.py index 066e89d3..7c515873 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -1288,8 +1288,8 @@ def run_vaevae( latent_both = vae.VAEJoint.encode(dataloader_joint) logger.info(f"{latent_both.shape} embedding shape") - LATENT_PATH = vamb_options.out_dir.joinpath("vaevae_latent.npy") - np.save(LATENT_PATH, latent_both) + latent_path = vamb_options.out_dir.joinpath("vaevae_latent.npz") + vamb.vambtools.write_npz(latent_path, latent_both) # Cluster, save tsv file cluster_and_write_files( diff --git a/vamb/semisupervised_encode.py b/vamb/semisupervised_encode.py index ac81e58a..858faf19 100644 --- a/vamb/semisupervised_encode.py +++ b/vamb/semisupervised_encode.py @@ -16,6 +16,7 @@ from torch.utils.data import DataLoader as _DataLoader from torch.utils.data.dataset import TensorDataset as _TensorDataset import vamb.encode as _encode +from vamb.vambtools import mask_lower_bits if _torch.__version__ < "0.4": raise ImportError("PyTorch version must be 0.4 or newer") @@ -354,6 +355,7 @@ def encode(self, data_loader): row += len(mu) assert row == length + mask_lower_bits(latent, 12) return latent def trainmodel( @@ -700,6 +702,7 @@ def encode(self, data_loader): row += len(mu) assert row == length + mask_lower_bits(latent, 12) return latent