From 843320ffaf4d7a1122afe4b113fe8e95ffc89509 Mon Sep 17 00:00:00 2001 From: Pablo Date: Wed, 16 Oct 2024 20:52:37 -0400 Subject: [PATCH] ENH: Support new checkpointing formatin baseline code --- .../bzsl/feature_extraction/__init__.py | 6 +- barcodebert/bzsl/genus_species/main.py | 6 +- .../bzsl/models/dnabert/tokenization_utils.py | 8 +- .../bzsl/surrogate_species/bayesian_model.py | 6 +- barcodebert/datasets.py | 187 +++++++----------- barcodebert/knn_probing.py | 64 +++--- barcodebert/pretraining.py | 30 +-- baselines/datasets.py | 144 ++++++++++---- baselines/embedders.py | 103 +++++++--- baselines/finetuning.py | 103 ++++++---- baselines/io.py | 30 ++- baselines/knn_probing.py | 23 ++- baselines/models/dnabert2.py | 8 +- 13 files changed, 407 insertions(+), 311 deletions(-) diff --git a/barcodebert/bzsl/feature_extraction/__init__.py b/barcodebert/bzsl/feature_extraction/__init__.py index f297058..930f3de 100644 --- a/barcodebert/bzsl/feature_extraction/__init__.py +++ b/barcodebert/bzsl/feature_extraction/__init__.py @@ -1,7 +1,3 @@ -from .utils import ( - extract_clean_barcode_list, - extract_clean_barcode_list_for_aligned, - extract_dna_features, -) +from .utils import extract_clean_barcode_list, extract_clean_barcode_list_for_aligned, extract_dna_features __all__ = ["extract_clean_barcode_list", "extract_clean_barcode_list_for_aligned", "extract_dna_features"] diff --git a/barcodebert/bzsl/genus_species/main.py b/barcodebert/bzsl/genus_species/main.py index efe85a8..1437e7d 100644 --- a/barcodebert/bzsl/genus_species/main.py +++ b/barcodebert/bzsl/genus_species/main.py @@ -6,11 +6,7 @@ import numpy as np import torch -from barcodebert.bzsl.genus_species.bayesian_classifier import ( - BayesianClassifier, - apply_pca, - calculate_priors, -) +from barcodebert.bzsl.genus_species.bayesian_classifier import BayesianClassifier, apply_pca, calculate_priors from barcodebert.bzsl.genus_species.dataset import get_data_splits, load_data diff --git a/barcodebert/bzsl/models/dnabert/tokenization_utils.py b/barcodebert/bzsl/models/dnabert/tokenization_utils.py index 7acc66d..1158027 100644 --- a/barcodebert/bzsl/models/dnabert/tokenization_utils.py +++ b/barcodebert/bzsl/models/dnabert/tokenization_utils.py @@ -26,13 +26,7 @@ from tokenizers.implementations import BaseTokenizer -from .file_utils import ( - cached_path, - hf_bucket_url, - is_remote_url, - is_tf_available, - is_torch_available, -) +from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf diff --git a/barcodebert/bzsl/surrogate_species/bayesian_model.py b/barcodebert/bzsl/surrogate_species/bayesian_model.py index 3890709..d4bf815 100644 --- a/barcodebert/bzsl/surrogate_species/bayesian_model.py +++ b/barcodebert/bzsl/surrogate_species/bayesian_model.py @@ -6,11 +6,7 @@ from scipy.spatial.distance import cdist from scipy.special import gammaln -from barcodebert.bzsl.surrogate_species.utils import ( - DataLoader, - apply_pca, - perf_calc_acc, -) +from barcodebert.bzsl.surrogate_species.utils import DataLoader, apply_pca, perf_calc_acc class Model: diff --git a/barcodebert/datasets.py b/barcodebert/datasets.py index 000e5d9..4217059 100644 --- a/barcodebert/datasets.py +++ b/barcodebert/datasets.py @@ -2,6 +2,7 @@ Datasets. 
""" +import os from itertools import product import numpy as np @@ -40,16 +41,35 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]: return tokens, att_mask -class DnaBertBPETokenizer(object): - def __init__(self, padding=False, max_tokenized_len=128): +class BPETokenizer(object): + def __init__(self, padding=False, max_tokenized_len=128, bpe_path=None): self.padding = padding self.max_tokenized_len = max_tokenized_len - self.bpe = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True) + + assert os.path.isdir(bpe_path), f"The bpe path does not exist: {bpe_path}" + + self.bpe = AutoTokenizer.from_pretrained(bpe_path) + + # root_folder = os.path.dirname(__file__) + # if bpe_type == "dnabert": + # # self.bpe = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True) + # bpe_folder = os.path.join(root_folder, "bpe_tokenizers", "bpe_dnabert2") + # assert os.path.isdir(bpe_folder), f"Directory does not exist: {bpe_folder}" + # self.bpe = AutoTokenizer.from_pretrained(f"{bpe_folder}/") + # elif bpe_type.__contains__("barcode"): + # length = bpe_type.split("_")[-1] + # bpe_folder = os.path.join(root_folder, "bpe_tokenizers", f"bpe_barcode_{length}") + # assert os.path.isdir(bpe_folder), f"Directory does not exist: {bpe_folder}" + # self.bpe = AutoTokenizer.from_pretrained(bpe_folder) + # else: + # raise NotImplementedError(f"bpe_type {bpe_type} is not supported.") def __call__(self, dna_sequence, offset=0) -> tuple[list, list]: x = dna_sequence[offset:] tokens = self.bpe(x, padding=True, return_tensors="pt")["input_ids"] - tokens[tokens == 0] = 1 + tokens[tokens == 2] = 3 + tokens[tokens == 1] = 2 + tokens[tokens == 0] = 1 # all the UNK + CLS have token of 1 tokens = tokens[0].tolist() @@ -61,6 +81,7 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]: tokens = tokens + [1] * (self.max_tokenized_len - len(tokens)) att_mask = torch.tensor(att_mask, dtype=torch.int32) + tokens = torch.tensor(tokens, dtype=torch.int64) return tokens, att_mask @@ -69,10 +90,11 @@ def __init__( self, file_path, k_mer=4, - stride=4, + stride=None, max_len=256, randomize_offset=False, tokenizer="kmer", + bpe_path=None, tokenize_n_nucleotide=False, dataset_format="CANADA-1.5M", ): @@ -85,16 +107,15 @@ def __init__( if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M"]: raise NotImplementedError(f"Dataset {dataset_format} not supported.") - # Vocabulary - base_pairs = "ACGT" - self.special_tokens = ["[MASK]", "[UNK]"] # ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[EOS]", "[UNK]"] - UNK_TOKEN = "[UNK]" - - if tokenize_n_nucleotide: - # Encode kmers which contain N differently depending on where it is - base_pairs += "N" - if tokenizer == "kmer": + # Vocabulary + base_pairs = "ACGT" + self.special_tokens = ["[MASK]", "[UNK]"] # ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[EOS]", "[UNK]"] + UNK_TOKEN = "[UNK]" + + if tokenize_n_nucleotide: + # Encode kmers which contain N differently depending on where it is + base_pairs += "N" kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=self.k_mer)] # Separate between good (idx < 4**k) and bad k-mers (idx > 4**k) for prediction @@ -116,13 +137,14 @@ def __init__( self.tokenizer = KmerTokenizer( self.k_mer, self.vocab, stride=self.stride, padding=True, max_len=self.max_len ) - elif tokenizer == "DnaBertBPE": - self.tokenizer = DnaBertBPETokenizer(padding=True, max_tokenized_len=self.max_len) + elif tokenizer == "bpe": + self.tokenizer = BPETokenizer(padding=True, max_tokenized_len=self.max_len, 
bpe_path=bpe_path) self.vocab_size = self.tokenizer.bpe.vocab_size else: raise ValueError(f'Tokenizer "{tokenizer}" not recognized.') df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",", keep_default_na=False) self.barcodes = df["nucleotides"].to_list() + if dataset_format == "CANADA-1.5M": self.labels, self.label_set = pd.factorize(df["species_name"], sort=True) self.num_labels = len(self.label_set) @@ -144,117 +166,42 @@ def __getitem__(self, idx): return processed_barcode, label, att_mask -def single_inference(model, tokenizer, barcode): - with torch.no_grad(): - x, att_mask = tokenizer(barcode) - - x = x.unsqueeze(0).to(model.device) - att_mask = att_mask.unsqueeze(0).to(model.device) - x = model(x, att_mask).hidden_states[-1] - # updated mean pooling to account for the attention mask and padding tokens - # sum the embeddings of the tokens (excluding padding tokens) - x = (x * att_mask.unsqueeze(-1)).sum(1) # (batch_size, hidden_size) - # sum the attention mask (number of tokens in the sequence without considering the padding tokens) - sum_mask = att_mask.sum(1, keepdim=True) - # calculate the mean embeddings - x /= sum_mask # (batch_size, hidden_size) - return x - - -def representations_from_df(df, target_level, model, tokenizer): +def representations_from_df(df, target_level, model, tokenizer, dataset_name): orders = df["order_name"].to_numpy() - - _label_set, y = np.unique(df[target_level], return_inverse=True) + if dataset_name == "CANADA-1.5M": + _label_set, y = np.unique(df[target_level], return_inverse=True) + elif dataset_name == "BIOSCAN-5M": + # _label_set = np.unique(df[target_level]) + y = df[target_level] + else: + raise NotImplementedError("Dataset format is not supported. Must be one of CANADA-1.5M or BIOSCAN-5M") dna_embeddings = [] - for barcode in df["nucleotides"]: - x = single_inference(model, tokenizer, barcode) - dna_embeddings.append(x.cpu().numpy()) + + with torch.no_grad(): + for barcode in df["nucleotides"]: + x, att_mask = tokenizer(barcode) + + x = x.unsqueeze(0).to(model.device) + att_mask = att_mask.unsqueeze(0).to(model.device) + x = model(x, att_mask).hidden_states[-1] + # previous mean pooling + # x = x.mean(1) + # dna_embeddings.append(x.cpu().numpy()) + + # updated mean pooling to account for the attention mask and padding tokens + # sum the embeddings of the tokens (excluding padding tokens) + sum_embeddings = (x * att_mask.unsqueeze(-1)).sum(1) # (batch_size, hidden_size) + # sum the attention mask (number of tokens in the sequence without considering the padding tokens) + sum_mask = att_mask.sum(1, keepdim=True) + # calculate the mean embeddings + mean_embeddings = sum_embeddings / sum_mask # (batch_size, hidden_size) + + dna_embeddings.append(mean_embeddings.cpu().numpy()) print(f"There are {len(df)} points in the dataset") latent = np.array(dna_embeddings) latent = np.squeeze(latent, 1) print(latent.shape) return latent, y, orders - - -def inference_from_df(df, model, tokenizer): - - assert "processid" in df.columns # Check that processid column is present in your dataframe - assert "nucleotides" in df.columns # Check that nucleotide column is present in your dataframe - - dna_embeddings = {} - - for _i, row in df.iterrows(): - barcode = row["nucleotides"] - id = row["processid"] - - x = single_inference(model, tokenizer, barcode) - - dna_embeddings[id] = x.cpu().numpy() - - return dna_embeddings - - -def check_sequence(header, seq): - """ - Adapted from VAMB: https://github.com/RasmussenLab/vamb - - Check that there're no 
invalid characters or bad format - in the file. - - Note: The GAPS ('-') that are introduced from alignment - are considered valid characters. - """ - - if len(header) > 0 and (header[0] in (">", "#") or header[0].isspace()): - raise ValueError("Bad character in sequence header") - if "\t" in header: - raise ValueError("tab included in header") - - basemask = bytearray.maketrans(b"acgtuUswkmyrbdhvnSWKMYRBDHV-", b"ACGTTTNNNNNNNNNNNNNNNNNNNNNN") - - masked = seq.translate(basemask, b" \t\n\r") - stripped = masked.translate(None, b"ACGTN") - if len(stripped) > 0: - bad_character = chr(stripped[0]) - msg = "Invalid DNA byte in sequence {}: '{}'" - raise ValueError(msg.format(header, bad_character)) - return masked - - -def inference_from_fasta(fname, model, tokenizer): - - dna_embeddings = {} - lines = [] - seq_id = "" - - for line in open(fname, "rb"): - if line.startswith(b"#"): - pass - - elif line.startswith(b">"): - if seq_id != "": - seq = bytearray().join(lines) - - # Check entry is valid - seq = check_sequence(seq_id, seq) - - # Compute embedding - x = single_inference(model, tokenizer, seq.decode()) - - seq_id = line[1:-1].decode() # Modify this according to your labels. - lines = [] - dna_embeddings[seq_id] = x.cpu().numpy() - seq_id = line[1:-1].decode() - else: - lines += [line.strip()] - - seq = bytearray().join(lines) - seq = check_sequence(seq_id, seq) - # Compute embedding - x = single_inference(model, tokenizer, seq.decode()) - dna_embeddings[seq_id] = x.cpu().numpy() - - return dna_embeddings diff --git a/barcodebert/knn_probing.py b/barcodebert/knn_probing.py index f31f442..e4802e9 100755 --- a/barcodebert/knn_probing.py +++ b/barcodebert/knn_probing.py @@ -14,8 +14,8 @@ from torchtext.vocab import vocab as build_vocab_from_dict from barcodebert import utils -from barcodebert.datasets import KmerTokenizer, representations_from_df -from barcodebert.io import get_project_root, load_pretrained_model +from barcodebert.datasets import BPETokenizer, KmerTokenizer, representations_from_df +from barcodebert.io import load_pretrained_model def run(config): @@ -63,8 +63,11 @@ def run(config): "stride", "max_len", "tokenizer", - "use_unk_token", + "bpe_path", "tokenize_n_nucleotide", + "predict_n_nucleotide", + "pretrain_levenshtein", + "levenshtein_vectorized", "n_layers", "n_heads", "dataset_name", @@ -73,7 +76,7 @@ def run(config): for key in keys_to_reuse: if not hasattr(config, key) or getattr(config, key) == getattr(pre_checkpoint["config"], key): pass - elif getattr(config, key) == default_kwargs[key]: + elif getattr(config, key) is None or getattr(config, key) == default_kwargs[key]: print( f" Overriding default config value {key}={getattr(config, key)}" f" with {getattr(pre_checkpoint['config'], key)} from pretained checkpoint." 
@@ -90,35 +93,36 @@ def run(config): # DATASET ================================================================= - base_pairs = "ACGT" - # specials = ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[UNK]"] - specials = ["[MASK]", "[UNK]"] - UNK_TOKEN = "[UNK]" + if config.tokenizer == "kmer": + base_pairs = "ACGT" + # specials = ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[UNK]"] + specials = ["[MASK]", "[UNK]"] + UNK_TOKEN = "[UNK]" - if config.tokenize_n_nucleotide: - # Encode kmers which contain N differently depending on where it is - base_pairs += "N" + if config.tokenize_n_nucleotide: + # Encode kmers which contain N differently depending on where it is + base_pairs += "N" - kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=config.k_mer)] + kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=config.k_mer)] - if config.tokenize_n_nucleotide: - prediction_kmers = [] - other_kmers = [] - for kmer in kmers: - if "N" in kmer: - other_kmers.append(kmer) - else: - prediction_kmers.append(kmer) + if config.tokenize_n_nucleotide: + prediction_kmers = [] + other_kmers = [] + for kmer in kmers: + if "N" in kmer: + other_kmers.append(kmer) + else: + prediction_kmers.append(kmer) - kmers = prediction_kmers + other_kmers + kmers = prediction_kmers + other_kmers - kmer_dict = dict.fromkeys(kmers, 1) - vocab = build_vocab_from_dict(kmer_dict, specials=specials) - vocab.set_default_index(vocab[UNK_TOKEN]) - tokenizer = KmerTokenizer(config.k_mer, vocab, stride=config.k_mer, padding=True, max_len=config.max_len) + kmer_dict = dict.fromkeys(kmers, 1) + vocab = build_vocab_from_dict(kmer_dict, specials=specials) + vocab.set_default_index(vocab[UNK_TOKEN]) + tokenizer = KmerTokenizer(config.k_mer, vocab, stride=config.k_mer, padding=True, max_len=config.max_len) - if config.data_dir is None: - config.data_dir = os.path.join(get_project_root(), "data") + elif config.tokenizer == "bpe": + tokenizer = BPETokenizer(padding=True, max_tokenized_len=config.max_len, bpe_path=config.bpe_path) df_train = pd.read_csv(os.path.join(config.data_dir, "supervised_train.csv")) df_test = pd.read_csv(os.path.join(config.data_dir, "unseen.csv")) @@ -140,9 +144,11 @@ def run(config): t_start_embed = time.time() # Generate emebddings for the training and test sets print("Generating embeddings for test set", flush=True) - X_unseen, y_unseen, orders = representations_from_df(df_test, config.target_level, model, tokenizer) + X_unseen, y_unseen, orders = representations_from_df( + df_test, config.target_level, model, tokenizer, config.dataset_name + ) print("Generating embeddings for train set", flush=True) - X, y, train_orders = representations_from_df(df_train, config.target_level, model, tokenizer) + X, y, train_orders = representations_from_df(df_train, config.target_level, model, tokenizer, config.dataset_name) timing_stats["embed"] = time.time() - t_start_embed c = 0 diff --git a/barcodebert/pretraining.py b/barcodebert/pretraining.py index 2f5b88b..b70116e 100755 --- a/barcodebert/pretraining.py +++ b/barcodebert/pretraining.py @@ -48,6 +48,7 @@ def run(config): utils.setup_slurm_distributed() config.world_size = int(os.environ.get("WORLD_SIZE", 1)) config.distributed = utils.check_is_distributed() + if config.world_size > 1 and not config.distributed: raise EnvironmentError( f"WORLD_SIZE is {config.world_size}, but not all other required" @@ -426,7 +427,7 @@ def print_pass(*args, **kwargs): total_step=total_step, n_samples_seen=n_samples_seen, distance_table=distance_table, - 
n_special_tokens=len(dataset_train.special_tokens) + n_special_tokens=len(dataset_train.special_tokens), ) t_end_train = time.time() @@ -639,7 +640,7 @@ def train_one_epoch( # t_start_masking = time.time() # Create a mask for allowed tokens i.e. that excludes all special tokens [, ] - special_tokens_mask = (sequences > (n_special_tokens - 1)) + special_tokens_mask = sequences > (n_special_tokens - 1) if config.tokenize_n_nucleotide: # Either exlude the last token [N..N] if config.predict_n_nucleotide == True @@ -647,7 +648,6 @@ def train_one_epoch( # is greater than 4**k special_tokens_mask &= sequences < (n_special_tokens + 4**config.k_mer - 1) - special_tokens_mask = special_tokens_mask.to(device) masked_input = sequences.clone() random_mask = torch.rand(sequences.shape, device=device) @@ -663,13 +663,13 @@ def train_one_epoch( ct_forward = torch.cuda.Event(enable_timing=True) ct_forward.record() # Perform the forward pass through the model - out = model(masked_input , attention_mask=att_mask) + out = model(masked_input, attention_mask=att_mask) targets = sequences - n_special_tokens * (sequences > (n_special_tokens - 1)) # Measure loss loss = criterion( - out.logits.view(-1, 4**config.k_mer)[special_tokens_mask.view(-1)], - targets.view(-1)[special_tokens_mask.view(-1)], - ) + out.logits.view(-1, 4**config.k_mer)[special_tokens_mask.view(-1)], + targets.view(-1)[special_tokens_mask.view(-1)], + ) # Backward pass ------------------------------------------------------- # Reset gradients @@ -894,7 +894,7 @@ def evaluate( # Build the masking on the fly ------------------------------------ # t_start_masking = time.time() # Create a mask for allowed tokens i.e. that excludes all special tokens [, ] - special_tokens_mask = (sequences > (n_special_tokens - 1)) + special_tokens_mask = sequences > (n_special_tokens - 1) if config.tokenize_n_nucleotide: # Either exlude the last token [N..N] if config.predict_n_nucleotide == True @@ -902,7 +902,6 @@ def evaluate( # is greater than 4**k special_tokens_mask &= sequences < (n_special_tokens + 4**config.k_mer - 1) - special_tokens_mask = special_tokens_mask.to(device) masked_input = sequences.clone() random_mask = torch.rand(masked_input.shape, generator=rng, device=device) @@ -911,13 +910,13 @@ def evaluate( masked_input[input_maskout] = 0 # Forward pass ---------------------------------------------------- - out = model(masked_input , attention_mask=att_mask) + out = model(masked_input, attention_mask=att_mask) # Measure loss targets = sequences - n_special_tokens * (sequences > (n_special_tokens - 1)) loss = criterion( - out.logits.view(-1, 4**config.k_mer)[special_tokens_mask.view(-1)], - targets.view(-1)[special_tokens_mask.view(-1)], - ) + out.logits.view(-1, 4**config.k_mer)[special_tokens_mask.view(-1)], + targets.view(-1)[special_tokens_mask.view(-1)], + ) # Metrics --------------------------------------------------------- # Update the total loss for the epoch @@ -1264,8 +1263,9 @@ def cli(): r"""Command-line interface for model training.""" parser = get_parser() config = parser.parse_args() - #If stride value is ommited., then it is equal to k_mer - if not config.stride: config.stride = config.k_mer + # If stride value is ommited., then it is equal to k_mer + if not config.stride: + config.stride = config.k_mer # Handle disable_wandb overriding log_wandb and forcing it to be disabled. 
if config.disable_wandb: config.log_wandb = False diff --git a/baselines/datasets.py b/baselines/datasets.py index ac66948..63a6586 100644 --- a/baselines/datasets.py +++ b/baselines/datasets.py @@ -4,27 +4,37 @@ import os import pickle +from itertools import product import numpy as np import pandas as pd import torch from torch.utils.data import Dataset +from torchtext.vocab import build_vocab_from_iterator from tqdm.auto import tqdm +from transformers import AutoTokenizer class DNADataset(Dataset): - def __init__(self, file_path, embedder, randomize_offset=False, max_length=660): + def __init__(self, file_path, embedder, randomize_offset=False, max_length=660, dataset_format="CANADA-1.5M"): self.randomize_offset = randomize_offset - df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",", keep_default_na=False) + df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",") self.barcodes = df["nucleotides"].to_list() - self.ids = df["species_index"].to_list() # ideally, this should be process id self.tokenizer = embedder.tokenizer self.backbone_name = embedder.name self.max_len = max_length + self.dataset_format = dataset_format - self.num_labels = 22_622 + if dataset_format == "CANADA-1.5M": + self.labels, self.label_set = pd.factorize(df["species_name"], sort=True) + self.ids = df["species_name"].to_list() # ideally, this should be process id + self.num_labels = len(self.label_set) + else: + self.num_labels = 22_622 + self.ids = df["species_index"].to_list() # ideally, this should be process id + self.labels = self.ids def __len__(self): return len(self.barcodes) @@ -37,31 +47,79 @@ def __getitem__(self, idx): x = self.barcodes[idx] if len(x) > self.max_len: - x = x[: self.max_len] - else: - x = x + "N" * (self.max_len - len(x)) + x = x[: self.max_len] # Truncate, but do not force the max_len, let the model tokenize handle it. 
if self.backbone_name == "BarcodeBERT": - processed_barcode, _ = self.tokenizer(x, offset=offset) + processed_barcode, att_mask = self.tokenizer(x, offset=offset) + + elif self.backbone_name == "Hyena_DNA": + encoding_info = self.tokenizer( + x, + return_tensors="pt", + return_attention_mask=True, + return_token_type_ids=False, + max_length=self.max_len, + padding="max_length", + truncation=True, + add_special_tokens=False, + ) + + processed_barcode = encoding_info["input_ids"] + # print(processed_barcode.shape) + att_mask = encoding_info["attention_mask"] + + elif self.backbone_name == "DNABERT": + k = 6 + kmer = [x[i : i + k] for i in range(len(x) + 1 - k)] + kmers = " ".join(kmer) + encoding_info = self.tokenizer.encode_plus( + kmers, + sentence_b=None, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=512, + return_attention_mask=True, + truncation=True, + ) + processed_barcode = encoding_info["input_ids"] + # print(processed_barcode.shape) + att_mask = encoding_info["attention_mask"] + else: - processed_barcode = self.tokenizer( + encoding_info = self.tokenizer( x, return_tensors="pt", return_attention_mask=True, return_token_type_ids=False, max_length=512, + add_special_tokens=False, padding="max_length", - )["input_ids"].int() + truncation=True, + ) + + processed_barcode = encoding_info["input_ids"] + # print(processed_barcode.shape) + att_mask = encoding_info["attention_mask"] + + label = torch.tensor(self.labels[idx], dtype=torch.int64) - label = self.ids[idx] - return processed_barcode, label + return processed_barcode, label, att_mask -def representations_from_df(filename, embedder, batch_size=128): +def representations_from_df( + filename, + embedder, + batch_size=128, + save_embeddings=True, + dataset="BIOSCAN-5M", + embeddings_folder="/scratch/ssd004/scratch/pmillana/embeddings/embeddings", +): # create embeddings folder - if not os.path.isdir("embeddings"): - os.mkdir("embeddings") + if save_embeddings: + embeddings_path = f"{embeddings_folder}/{dataset}" + os.makedirs(embeddings_path, exist_ok=True) backbone = embedder.name @@ -71,7 +129,7 @@ def representations_from_df(filename, embedder, batch_size=128): print(f"Calculating embeddings for {backbone}") # create a folder for a specific backbone within embeddings - backbone_folder = os.path.join("embeddings", backbone) + backbone_folder = os.path.join(embeddings_path, backbone) if not os.path.isdir(backbone_folder): os.mkdir(backbone_folder) @@ -82,7 +140,7 @@ def representations_from_df(filename, embedder, batch_size=128): if os.path.exists(out_fname): print(f"We found the file {out_fname}. It seems that we have computed the embeddings ... 
\n") - print("Loading the embeddings from that file") + print(f"Loading the embeddings from that file") with open(out_fname, "rb") as handle: embeddings = pickle.load(handle) @@ -91,7 +149,9 @@ def representations_from_df(filename, embedder, batch_size=128): else: - dataset_val = DNADataset(file_path=filename, embedder=embedder, randomize_offset=False, max_length=660) + dataset_val = DNADataset( + file_path=filename, embedder=embedder, randomize_offset=False, max_length=660, dataset_format=dataset + ) dl_val_kwargs = { "batch_size": batch_size, @@ -105,13 +165,11 @@ def representations_from_df(filename, embedder, batch_size=128): embeddings_list = [] id_list = [] with torch.no_grad(): - for _batch_idx, (sequences, _id) in tqdm(enumerate(dataloader_val)): + for batch_idx, (sequences, _id, att_mask) in tqdm(enumerate(dataloader_val)): sequences = sequences.view(-1, sequences.shape[-1]).to(device) + att_mask = att_mask.view(-1, att_mask.shape[-1]).to(device) # print(sequences.shape) - att_mask = sequences != 1 - - # TODO: The first token is always [CLS] - n_embeddings = att_mask.sum(axis=1) + # att_mask = (sequences != 1) # print(n_embeddings.shape) @@ -122,28 +180,30 @@ def representations_from_df(filename, embedder, batch_size=128): elif backbone == "Hyena_DNA": out = embedder.model(sequences) - elif backbone in ["DNABERT-2", "DNABERT-S"]: + elif backbone in ["DNABERT", "DNABERT-2", "DNABERT-S"]: out = embedder.model(sequences)[0] elif backbone == "BarcodeBERT": - out = embedder.model(sequences).hidden_states[-1] - - if backbone != "BarcodeBERT": - # print(out.shape) - att_mask = att_mask.unsqueeze(2).expand(-1, -1, embedder.hidden_size) - # print(att_mask.shape) - out = out * att_mask - # print(out.shape) - out = out.sum(axis=1) - # print(out.shape) - out = torch.div(out.t(), n_embeddings) - # print(out.shape) - - # Move embeddings back to CPU and convert to numpy array - embeddings = out.t().cpu().numpy() - - else: - embeddings = out.mean(1).cpu().numpy() + out = embedder.model(sequences, att_mask).hidden_states[-1] + + # if backbone != "BarcodeBERT": + # print(out.shape) + + n_embeddings = att_mask.sum(axis=1) + # print(n_embeddings.shape) + + att_mask = att_mask.unsqueeze(2).expand(-1, -1, embedder.hidden_size) + # print(att_mask.shape) + + out = out * att_mask + # print(out.shape) + out = out.sum(axis=1) + # print(out.shape) + out = torch.div(out.t(), n_embeddings) + # print(out.shape) + + # Move embeddings back to CPU and convert to numpy array + embeddings = out.t().cpu().numpy() # Collect embeddings embeddings_list.append(embeddings) diff --git a/baselines/embedders.py b/baselines/embedders.py index 2bb2101..116392f 100644 --- a/baselines/embedders.py +++ b/baselines/embedders.py @@ -16,8 +16,7 @@ # Adapted from https://github.com/frederikkemarin/BEND/blob/main/bend/models/dnabert2_padding.py # Which was adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py -# Which was adapted from: -# https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py +# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py import os @@ -28,7 +27,10 @@ import numpy as np import torch +from sklearn.preprocessing import LabelEncoder +from torch import nn from torchtext.vocab import build_vocab_from_iterator +from torchtext.vocab import vocab as build_vocab_from_dict from tqdm.auto import tqdm from transformers import ( 
AutoModel, @@ -37,6 +39,7 @@ BertConfig, BertModel, BertTokenizer, + BigBirdModel, logging, ) @@ -100,10 +103,12 @@ def __call__(self, sequence: str, *args, **kwargs): The embedding of the sequence. """ return self.embed([sequence], *args, disable_tqdm=True, **kwargs)[0] + return embeddings -# DNABERT https://doi.org/10.1093/bioinformatics/btab083 -# Download from https://github.com/jerryji1993/DNABERT +## +## DNABert https://doi.org/10.1093/bioinformatics/btab083 +## Download from https://github.com/jerryji1993/DNABERT class DNABertEmbedder(BaseEmbedder): @@ -116,8 +121,7 @@ def load_model(self, model_path: str = "../../external-models/DNABERT/", kmer: i ---------- model_path : str The path to the model directory. Defaults to "../../external-models/DNABERT/". - The DNABERT models need to be downloaded manually as indicated in the DNABERT repository at: - https://github.com/jerryji1993/DNABERT. + The DNABERT models need to be downloaded manually as indicated in the DNABERT repository at https://github.com/jerryji1993/DNABERT. kmer : int The kmer size of the model. Defaults to 6. @@ -129,8 +133,7 @@ def load_model(self, model_path: str = "../../external-models/DNABERT/", kmer: i if not os.path.exists(dnabert_path): print( - f"Path {dnabert_path} does not exists, check if the wrong path was given. \ - If not download from https://github.com/jerryji1993/DNABERT" + f"Path {dnabert_path} does not exists, check if the wrong path was given. If not download from https://github.com/jerryji1993/DNABERT" ) config = BertConfig.from_pretrained(dnabert_path) @@ -328,16 +331,17 @@ def embed( """ self.model.eval() + cls_tokens = [] embeddings = [] with torch.no_grad(): - for _n, s in enumerate(tqdm(sequences, disable=disable_tqdm)): + for n, s in enumerate(tqdm(sequences, disable=disable_tqdm)): # print('sequence', n) s_chunks = [ s[chunk : chunk + self.max_seq_len] for chunk in range(0, len(s), self.max_seq_len) ] # split into chunks embedded_seq = [] - for _n_chunk, chunk in enumerate(s_chunks): # embed each chunk + for n_chunk, chunk in enumerate(s_chunks): # embed each chunk tokens_ids = self.tokenizer(chunk, return_tensors="pt")["input_ids"].int().to(device) if len(tokens_ids[0]) > self.max_tokens: # too long to fit into the model split = torch.split(tokens_ids, self.max_tokens, dim=-1) @@ -365,8 +369,7 @@ def embed( if upsample_embeddings: outs = self._repeat_embedding_vectors(self.tokenizer.convert_ids_to_tokens(tokens_ids[0]), outs) embedded_seq.append(outs[:, 1:] if remove_special_tokens else outs) - # print('chunk', n_chunk, 'chunk length', len(chunk), 'tokens length', len(tokens_ids[0]), ... - # 'chunk embedded shape', outs.shape) + # print('chunk', n_chunk, 'chunk length', len(chunk), 'tokens length', len(tokens_ids[0]), 'chunk embedded shape', outs.shape) embeddings.append(np.concatenate(embedded_seq, axis=1)) return embeddings @@ -410,10 +413,8 @@ def load_model(self, model_path="pretrained_models/hyenadna/hyenadna-tiny-1k-seq ---------- model_path : str, optional Path to the model checkpoint. Defaults to 'pretrained_models/hyenadna/hyenadna-tiny-1k-seqlen'. - If the path does not exist, the model will be downloaded from HuggingFace. Rather than just - downloading the model, - HyenaDNA's `from_pretrained` method relies on cloning the HuggingFace-hosted repository, - and using git lfs to download the model. + If the path does not exist, the model will be downloaded from HuggingFace. 
Rather than just downloading the model, + HyenaDNA's `from_pretrained` method relies on cloning the HuggingFace-hosted repository, and using git lfs to download the model. This requires git lfs to be installed on your system, and will fail if it is not. @@ -432,9 +433,9 @@ def load_model(self, model_path="pretrained_models/hyenadna/hyenadna-tiny-1k-seq # all these settings are copied directly from huggingface.py # data settings: - # use_padding = True - # rc_aug = False # reverse complement augmentation - # add_eos = False # add end of sentence token + use_padding = True + rc_aug = False # reverse complement augmentation + add_eos = False # add end of sentence token # we need these for the decoder head, if using use_head = False @@ -501,6 +502,10 @@ def embed( List of embeddings. """ + # # prep model and forward + # model.to(device) + # with torch.inference_mode(): + embeddings = [] with torch.inference_mode(): for s in tqdm(sequences, disable=disable_tqdm): @@ -508,8 +513,9 @@ def embed( s[chunk : chunk + self.max_length] for chunk in range(0, len(s), self.max_length) ] # split into chunks embedded_chunks = [] - for _n_chunk, chunk in enumerate(chunks): - # Single embedding example + for n_chunk, chunk in enumerate(chunks): + #### Single embedding example #### + # create a sample 450k long, prepare # sequence = 'ACTG' * int(self.max_length/4) tok_seq = self.tokenizer(chunk) # adds CLS and SEP tokens @@ -792,7 +798,9 @@ class BarcodeBERTEmbedder(BaseEmbedder): Embed using the DNABERTS model https://arxiv.org/abs/2402.08777 """ - def load_model(self, checkpoint_path=None, from_paper=False, k_mer=8, n_heads=4, n_layers=4, **kwargs): + def load_model( + self, checkpoint_path=None, from_paper=False, k_mer=8, n_heads=4, n_layers=4, new_vocab=False, **kwargs + ): """ Load a pretrained model it can be downloaded or it can be from a checkpoint file. 
@@ -815,7 +823,7 @@ def load_model(self, checkpoint_path=None, from_paper=False, k_mer=8, n_heads=4, if not from_paper: model, ckpt = load_pretrained_model(checkpoint_path, device=device) else: - model = load_old_pretrained_model(checkpoint_path, k_mer, device=device) + model = load_old_pretrained_model(checkpoint_path, config, device=device) else: arch = f"{k_mer}_{n_heads}_{n_layers}" @@ -856,12 +864,47 @@ def load_model(self, checkpoint_path=None, from_paper=False, k_mer=8, n_heads=4, self.model.to(device) # tokenizer: - kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k_mer)) - if from_paper: - vocab = build_vocab_from_iterator(kmer_iter, specials=["", "", ""]) + if new_vocab: + if not ckpt: + raise NotImplementedError(f"New vocab requires an updated checkpoint structure") + else: + # Vocabulary + base_pairs = "ACGT" + special_tokens = ["[MASK]", "[UNK]"] # ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[EOS]", "[UNK]"] + UNK_TOKEN = "[UNK]" + k_mer = ckpt["config"].k_mer + tokenize_n_nucleotide = False # ckpt['config'].tokenize_n_nucleotide: + + if tokenize_n_nucleotide: + # Encode kmers which contain N differently depending on where it is + base_pairs += "N" + kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=k_mer)] + + # Separate between good (idx < 4**k) and bad k-mers (idx > 4**k) for prediction + if tokenize_n_nucleotide: + prediction_kmers = [] + other_kmers = [] + for kmer in kmers: + if "N" in kmer: + other_kmers.append(kmer) + else: + prediction_kmers.append(kmer) + + kmers = prediction_kmers + other_kmers + + kmer_dict = dict.fromkeys(kmers, 1) + vocab = build_vocab_from_dict(kmer_dict, specials=special_tokens) + vocab.set_default_index(vocab[UNK_TOKEN]) + vocab_size = len(vocab) + self.tokenizer = KmerTokenizer(k_mer, vocab, stride=ckpt["config"].stride, padding=True, max_len=660) + else: - vocab = build_vocab_from_iterator(kmer_iter, specials=["", ""]) - vocab.set_default_index(vocab[""]) # is necessary in the hard case + kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k_mer)) + if from_paper: + vocab = build_vocab_from_iterator(kmer_iter, specials=["", "", ""]) + else: + vocab = build_vocab_from_iterator(kmer_iter, specials=["", ""]) + vocab.set_default_index(vocab[""]) # is necessary in the hard case - tokenizer = KmerTokenizer(k_mer, vocab, stride=k_mer, padding=True, max_len=660) - self.tokenizer = tokenizer + tokenizer = KmerTokenizer(k_mer, vocab, stride=k_mer, padding=True, max_len=660) + self.tokenizer = tokenizer diff --git a/baselines/finetuning.py b/baselines/finetuning.py index 1825f32..d0133ef 100755 --- a/baselines/finetuning.py +++ b/baselines/finetuning.py @@ -20,6 +20,7 @@ sys.path.append(".") from barcodebert import utils +from barcodebert.io import safe_save_model from baselines.datasets import DNADataset from baselines.io import load_baseline_model @@ -39,23 +40,41 @@ def __init__(self, embedder, num_labels): self.hidden_size = embedder.hidden_size self.classifier = nn.Linear(self.hidden_size, self.num_labels) - def forward(self, input_ids=None, mask=None, labels=None): + def forward(self, sequences=None, mask=None, labels=None): # Getting the embeddings # call each model's wrapper if self.backbone == "NT": - out = self.base_model(input_ids, output_hidden_states=True)["hidden_states"][-1] + out = self.base_model(sequences, attention_mask=mask, output_hidden_states=True)["hidden_states"][-1] elif self.backbone == "Hyena_DNA": - out = self.base_model(input_ids) + out = self.base_model(sequences) - elif self.backbone in 
["DNABERT-2", "DNABERT-S"]: - out = self.base_model(input_ids)[0] + elif self.backbone in ["DNABERT", "DNABERT-2", "DNABERT-S"]: + out = self.base_model(sequences, attention_mask=mask)[0] elif self.backbone == "BarcodeBERT": - out = self.base_model(input_ids).hidden_states[-1] + out = self.base_model(sequences, att_mask).hidden_states[-1] + + # if backbone != "BarcodeBERT": + # print(out.shape) + + n_embeddings = mask.sum(axis=1) + # print(n_embeddings.shape) + + att_mask = mask.unsqueeze(2).expand(-1, -1, self.hidden_size) + # print(att_mask.shape) + + out = out * att_mask + # print(out.shape) + out = out.sum(axis=1) + # print(out.shape) + out = torch.div(out.t(), n_embeddings) + # print(out.shape) + + # Transpose back GAP embeddings + GAP_embeddings = out.t() - GAP_embeddings = out.mean(1) # TODO: Swap between GAP and CLS # calculate losses logits = self.classifier(GAP_embeddings.view(-1, self.hidden_size)) loss = None @@ -101,12 +120,13 @@ def evaluate( y_pred_all = [] xent_all = [] - for sequences, y_true in dataloader: + for sequences, y_true, att_mask in dataloader: sequences = sequences.view(-1, sequences.shape[-1]).to(device) + att_mask = att_mask.view(-1, att_mask.shape[-1]).to(device) y_true = y_true.to(device) with torch.no_grad(): - logits = model(sequences, labels=y_true).logits + logits = model(sequences, labels=y_true, mask=att_mask).logits xent = F.cross_entropy(logits, y_true, reduction="none") y_pred = torch.argmax(logits, dim=-1) @@ -181,7 +201,8 @@ def run(config): # Setup for distributed training utils.setup_slurm_distributed() config.world_size = int(os.environ.get("WORLD_SIZE", 1)) - config.distributed = utils.check_is_distributed() + # config.distributed = utils.check_is_distributed() + config.distributed = False if config.world_size > 1 and not config.distributed: raise EnvironmentError( f"WORLD_SIZE is {config.world_size}, but not all other required" @@ -273,7 +294,7 @@ def print_pass(*args, **kwargs): # DATASET ================================================================= - if config.dataset_name not in ["CANADA_1.5M", "BIOSCAN-5M"]: + if config.dataset_name not in ["CANADA-1.5M", "BIOSCAN-5M"]: raise NotImplementedError(f"Dataset {config.dataset_name} not supported.") # Handle default stride dynamically set to equal k-mer size @@ -281,17 +302,22 @@ def print_pass(*args, **kwargs): config.stride = config.k_mer dataset_train = DNADataset( - file_path=os.path.join(config.data_dir, "supervised_train.csv"), embedder=embedder, randomize_offset=False + file_path=os.path.join(config.data_dir, "supervised_train.csv"), + embedder=embedder, + randomize_offset=False, + dataset_format=config.dataset_name, ) dataset_val = DNADataset( file_path=os.path.join(config.data_dir, "supervised_val.csv"), embedder=embedder, randomize_offset=False, + dataset_format=config.dataset_name, ) dataset_test = DNADataset( file_path=os.path.join(config.data_dir, "supervised_test.csv"), embedder=embedder, randomize_offset=False, + dataset_format=config.dataset_name, ) distinct_val_test = True @@ -442,7 +468,7 @@ def print_pass(*args, **kwargs): config.model_output_dir = os.path.join( config.models_dir, config.dataset_name, - f"{config.run_name}__{config.run_id}", + config.backbone, ) config.checkpoint_path = os.path.join(config.model_output_dir, "checkpoint_finetune.pt") if config.log_wandb and config.global_rank == 0: @@ -451,6 +477,7 @@ def print_pass(*args, **kwargs): if config.checkpoint_path is None: print("Model will not be saved.") else: + os.makedirs(config.model_output_dir, 
exist_ok=True) print(f"Model will be saved to '{config.checkpoint_path}'") # RESUME ================================================================== @@ -593,25 +620,24 @@ def print_pass(*args, **kwargs): # Save model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ t_start_save = time.time() - # if config.model_output_dir and (not config.distributed or config.global_rank == 0): - # safe_save_model( - # { - # "model": model, - # "optimizer": optimizer, - # "scheduler": scheduler, - # }, - # config.checkpoint_path, - # config=config, - # epoch=epoch, - # total_step=total_step, - # n_samples_seen=n_samples_seen, - # bert_config=pre_checkpoint["bert_config"], - # **best_stats, - # ) - # if config.save_best_model and best_stats["best_epoch"] == epoch: - # ckpt_path_best = os.path.join(config.model_output_dir, "best_finetune.pt") - # print(f"Copying model to {ckpt_path_best}") - # shutil.copyfile(config.checkpoint_path, ckpt_path_best) + if config.model_output_dir and (not config.distributed or config.global_rank == 0): + safe_save_model( + { + "model": model, + "optimizer": optimizer, + "scheduler": scheduler, + }, + config.checkpoint_path, + config=config, + epoch=epoch, + total_step=total_step, + n_samples_seen=n_samples_seen, + bert_config=None, + ) + if config.save_best_model and best_stats["best_epoch"] == epoch: + ckpt_path_best = os.path.join(config.model_output_dir, "best_finetune.pt") + print(f"Copying model to {ckpt_path_best}") + shutil.copyfile(config.checkpoint_path, ckpt_path_best) t_end_save = time.time() timing_stats["saving"] = t_end_save - t_start_save @@ -780,13 +806,14 @@ def train_one_epoch( t_end_batch = time.time() t_start_wandb = t_end_wandb = None - for batch_idx, (sequences, y_true) in enumerate(dataloader): + for batch_idx, (sequences, y_true, att_mask) in enumerate(dataloader): t_start_batch = time.time() batch_size_this_gpu = sequences.shape[0] # Move training inputs and targets to the GPU # sequences = sequences.to(device) sequences = sequences.view(-1, sequences.shape[-1]).to(device) + att_mask = att_mask.view(-1, sequences.shape[-1]).to(device) y_true = y_true.to(device) # Forward pass -------------------------------------------------------- @@ -795,7 +822,7 @@ def train_one_epoch( ct_forward = torch.cuda.Event(enable_timing=True) ct_forward.record() # Perform the forward pass through the model - out = model(sequences, labels=y_true) + out = model(sequences, labels=y_true, mask=att_mask) loss = out.loss # Backward pass ------------------------------------------------------- @@ -846,6 +873,7 @@ def train_one_epoch( print("loss.shape =", loss.shape) # Debugging intensifies print("sequences[0] =", sequences[0]) + print("att_mask[0] =", att_mask[0]) print("y_true[0] =", y_true[0]) print("y_pred[0] =", y_pred[0]) print("logits[0] =", out.logits[0]) @@ -982,6 +1010,13 @@ def get_parser(): help="Architecture of the Encoder one of [DNABERT-2, Hyena_DNA, DNABERT-S, \ BarcodeBERT, NT]", ) + + group.add_argument( + "--dataset_name", + default="BIOSCAN-5M", + type=str, + help="Dataset format %(default)s", + ) return parser diff --git a/baselines/io.py b/baselines/io.py index 726cf5a..f131056 100644 --- a/baselines/io.py +++ b/baselines/io.py @@ -2,16 +2,28 @@ Input/output utilities. 
""" +import os +from inspect import getsourcefile + import torch +from transformers import BertConfig, BertForMaskedLM, BertForTokenClassification from baselines.embedders import ( BarcodeBERTEmbedder, DNABert2Embedder, + DNABertEmbedder, DNABertSEmbedder, HyenaDNAEmbedder, NucleotideTransformerEmbedder, ) +# PACKAGE_DIR = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0))) + + +# def get_project_root() -> str: +# return os.path.dirname(PACKAGE_DIR) + + device = "cuda" if torch.cuda.is_available() else "cpu" print("Using device:", device) @@ -24,27 +36,33 @@ def load_baseline_model(backbone_name, *args, **kwargs): "DNABERT-2": DNABert2Embedder, "DNABERT-S": DNABertSEmbedder, "BarcodeBERT": BarcodeBERTEmbedder, + "DNABERT": DNABertEmbedder, } # Positional arguments as a list # Keyword arguments as a dictionary checkpoints = { - "NT": (["InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"], {}), - "Hyena_DNA": (["pretrained_models/hyenadna-tiny-1k-seqlen"], {}), - "DNABERT-2": (["zhihan1996/DNABERT-2-117M"], {}), - "DNABERT-S": (["zhihan1996/DNABERT-S"], {}), - "BarcodeBERT": ([None], {}), + "NT": (["InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"], kwargs), + "Hyena_DNA": ( + ["/h/pmillana/projects/BIOSCAN_5M_DNA_experiments/pretrained_models/hyenadna-tiny-1k-seqlen"], + kwargs, + ), + "DNABERT-2": (["zhihan1996/DNABERT-2-117M"], kwargs), + "DNABERT-S": (["zhihan1996/DNABERT-S"], kwargs), + "DNABERT": (["/scratch/ssd004/scratch/pmillana/checkpoints/dnabert/6-new-12w-0"], kwargs), + "BarcodeBERT": ([], kwargs), } out_dimensions = { "NT": 512, "Hyena_DNA": 128, "DNABERT-2": 768, + "DNABERT": 768, "DNABERT-S": 768, "BarcodeBERT": 768, } positional_args, keyword_args = checkpoints[backbone_name] - embedder = backbones[backbone_name](*positional_args, **kwargs) + embedder = backbones[backbone_name](*positional_args, **keyword_args) embedder.hidden_size = out_dimensions[backbone_name] return embedder diff --git a/baselines/knn_probing.py b/baselines/knn_probing.py index 0486b97..f62b40c 100755 --- a/baselines/knn_probing.py +++ b/baselines/knn_probing.py @@ -61,9 +61,14 @@ def run(config): # DATASET ================================================================= if config.taxon.lower() == "bin": - target_level = "bin_uri" + config.target_level = "bin_uri" else: - target_level = f"{config.taxon}_index" + if config.dataset_name == "CANADA-1.5M": + target_level = config.taxon + "_name" + elif config.dataset_name == "BIOSCAN-5M": + target_level = config.taxon + "_index" + else: + raise NotImplementedError("Dataset format is not supported. 
Must be one of CANADA-1.5M or BIOSCAN-5M") timing_stats["preamble"] = time.time() - t_start t_start_embed = time.time() @@ -80,13 +85,13 @@ def run(config): # Generate emebddings for the training and test sets print("Generating embeddings for test set", flush=True) - X_unseen = representations_from_df(test_filename, embedder, batch_size=128) - y_unseen = labels_from_df(test_filename, f"{config.taxon}_index", label_pipeline) + X_unseen = representations_from_df(test_filename, embedder, batch_size=128, dataset=config.dataset_name) + y_unseen = labels_from_df(test_filename, target_level, label_pipeline) print(X_unseen.shape, y_unseen.shape) print("Generating embeddings for train set", flush=True) - X = representations_from_df(train_filename, embedder, batch_size=128) - y = labels_from_df(train_filename, f"{config.taxon}_index", label_pipeline) + X = representations_from_df(train_filename, embedder, batch_size=128, dataset=config.dataset_name) + y = labels_from_df(train_filename, target_level, label_pipeline) print(X.shape, y.shape) timing_stats["embed"] = time.time() - t_start_embed @@ -223,6 +228,12 @@ def get_parser(): help="Architecture of the Encoder one of [DNABERT-2, HyenaDNA, DNABERT-S, \ BarcodeBERT, NT]", ) + group.add_argument( + "--dataset_name", + default="BIOSCAN-5M", + type=str, + help="Dataset format %(default)s", + ) # kNN args ---------------------------------------------------------------- group = parser.add_argument_group("kNN parameters") group.add_argument( diff --git a/baselines/models/dnabert2.py b/baselines/models/dnabert2.py index 4561bcf..b735cf3 100644 --- a/baselines/models/dnabert2.py +++ b/baselines/models/dnabert2.py @@ -19,13 +19,7 @@ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput from transformers.models.bert.modeling_bert import BertPreTrainedModel -from .dnabert2_padding import ( - index_first_axis, - index_put_first_axis, - pad_input, - unpad_input, - unpad_input_only, -) +from .dnabert2_padding import index_first_axis, index_put_first_axis, pad_input, unpad_input, unpad_input_only try: from .flash_attn_triton import flash_attn_qkvpacked_func
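
Usage sketch (illustrative, not part of the patch): the hunks above replace the Hub-downloaded DnaBertBPETokenizer with a BPETokenizer that loads a tokenizer from a local directory via bpe_path, thread a dataset_format/dataset_name argument through DNADataset and representations_from_df, and switch the embedders to attention-mask-aware mean pooling. A minimal sketch of driving the updated pieces, assuming placeholder paths for the BPE vocabulary directory, the data CSVs, and the pretraining checkpoint:

# Illustrative sketch only; the paths and the dataset choice below are placeholders,
# while the call signatures mirror the code added in this patch.
import pandas as pd

from barcodebert.datasets import BPETokenizer, DNADataset, representations_from_df
from barcodebert.io import load_pretrained_model

# Tokenize with a locally stored BPE vocabulary instead of pulling
# zhihan1996/DNABERT-2-117M from the Hub at construction time.
tokenizer = BPETokenizer(padding=True, max_tokenized_len=256, bpe_path="bpe_tokenizers/bpe_barcode_8")
tokens, att_mask = tokenizer("ACGT" * 64)

# The dataset now selects its tokenizer by name ("kmer" or "bpe") and is told
# which CSV schema to expect via dataset_format.
dataset = DNADataset(
    file_path="data/supervised_train.csv",
    tokenizer="bpe",
    bpe_path="bpe_tokenizers/bpe_barcode_8",
    max_len=256,
    dataset_format="CANADA-1.5M",
)

# Mean-pooled embeddings with padding excluded via the attention mask, as
# computed by the updated representations_from_df (the dataset name is now required).
model, ckpt = load_pretrained_model("model_checkpoints/checkpoint_pretraining.pt", device="cuda")
df = pd.read_csv("data/unseen.csv")
latent, y, orders = representations_from_df(df, "species_name", model, tokenizer, "CANADA-1.5M")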