ENH: Support new checkpointing format in baseline code
millanp95 committed Oct 17, 2024
1 parent 1e80d54 commit 843320f
Showing 13 changed files with 407 additions and 311 deletions.
6 changes: 1 addition & 5 deletions barcodebert/bzsl/feature_extraction/__init__.py
@@ -1,7 +1,3 @@
-from .utils import (
-    extract_clean_barcode_list,
-    extract_clean_barcode_list_for_aligned,
-    extract_dna_features,
-)
+from .utils import extract_clean_barcode_list, extract_clean_barcode_list_for_aligned, extract_dna_features

__all__ = ["extract_clean_barcode_list", "extract_clean_barcode_list_for_aligned", "extract_dna_features"]
6 changes: 1 addition & 5 deletions barcodebert/bzsl/genus_species/main.py
@@ -6,11 +6,7 @@
import numpy as np
import torch

-from barcodebert.bzsl.genus_species.bayesian_classifier import (
-    BayesianClassifier,
-    apply_pca,
-    calculate_priors,
-)
+from barcodebert.bzsl.genus_species.bayesian_classifier import BayesianClassifier, apply_pca, calculate_priors
from barcodebert.bzsl.genus_species.dataset import get_data_splits, load_data


8 changes: 1 addition & 7 deletions barcodebert/bzsl/models/dnabert/tokenization_utils.py
@@ -26,13 +26,7 @@

from tokenizers.implementations import BaseTokenizer

-from .file_utils import (
-    cached_path,
-    hf_bucket_url,
-    is_remote_url,
-    is_tf_available,
-    is_torch_available,
-)
+from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf
6 changes: 1 addition & 5 deletions barcodebert/bzsl/surrogate_species/bayesian_model.py
@@ -6,11 +6,7 @@
from scipy.spatial.distance import cdist
from scipy.special import gammaln

-from barcodebert.bzsl.surrogate_species.utils import (
-    DataLoader,
-    apply_pca,
-    perf_calc_acc,
-)
+from barcodebert.bzsl.surrogate_species.utils import DataLoader, apply_pca, perf_calc_acc


class Model:
187 changes: 67 additions & 120 deletions barcodebert/datasets.py
@@ -2,6 +2,7 @@
Datasets.
"""

+import os
from itertools import product

import numpy as np
@@ -40,16 +41,35 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]:
        return tokens, att_mask


-class DnaBertBPETokenizer(object):
-    def __init__(self, padding=False, max_tokenized_len=128):
+class BPETokenizer(object):
+    def __init__(self, padding=False, max_tokenized_len=128, bpe_path=None):
        self.padding = padding
        self.max_tokenized_len = max_tokenized_len
-        self.bpe = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
+
+        assert os.path.isdir(bpe_path), f"The bpe path does not exist: {bpe_path}"
+
+        self.bpe = AutoTokenizer.from_pretrained(bpe_path)
+
+        # root_folder = os.path.dirname(__file__)
+        # if bpe_type == "dnabert":
+        #     # self.bpe = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
+        #     bpe_folder = os.path.join(root_folder, "bpe_tokenizers", "bpe_dnabert2")
+        #     assert os.path.isdir(bpe_folder), f"Directory does not exist: {bpe_folder}"
+        #     self.bpe = AutoTokenizer.from_pretrained(f"{bpe_folder}/")
+        # elif bpe_type.__contains__("barcode"):
+        #     length = bpe_type.split("_")[-1]
+        #     bpe_folder = os.path.join(root_folder, "bpe_tokenizers", f"bpe_barcode_{length}")
+        #     assert os.path.isdir(bpe_folder), f"Directory does not exist: {bpe_folder}"
+        #     self.bpe = AutoTokenizer.from_pretrained(bpe_folder)
+        # else:
+        #     raise NotImplementedError(f"bpe_type {bpe_type} is not supported.")

    def __call__(self, dna_sequence, offset=0) -> tuple[list, list]:
        x = dna_sequence[offset:]
        tokens = self.bpe(x, padding=True, return_tensors="pt")["input_ids"]
-        tokens[tokens == 0] = 1
+        tokens[tokens == 2] = 3
+        tokens[tokens == 1] = 2
+        tokens[tokens == 0] = 1  # all the UNK + CLS have token of 1

        tokens = tokens[0].tolist()

@@ -61,6 +81,7 @@ def __call__(self, dna_sequence, offset=0) -> tuple[list, list]:
            tokens = tokens + [1] * (self.max_tokenized_len - len(tokens))

        att_mask = torch.tensor(att_mask, dtype=torch.int32)
+        tokens = torch.tensor(tokens, dtype=torch.int64)
        return tokens, att_mask
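The remapping in `__call__` above shifts the pretrained special-token ids in dependency order (2→3, then 1→2, then 0→1) so no value is clobbered, which lets id 1 double as the padding fill used later in the method. Below is a minimal sketch of the same remap-then-pad flow, assuming a HuggingFace-style tokenizer directory whose pretrained convention is [UNK]=0, [CLS]=1, [SEP]=2 (the DNABERT-2 convention); the path and sample sequence are illustrative, not part of this commit:

```python
import torch
from transformers import AutoTokenizer

bpe = AutoTokenizer.from_pretrained("bpe_tokenizers/bpe_barcode_4")  # hypothetical local dir
max_tokenized_len = 128

tokens = bpe("ACGTACGTACGTNN", padding=True, return_tensors="pt")["input_ids"]

# Remap special ids without collisions: SEP 2 -> 3, CLS 1 -> 2, UNK 0 -> 1.
tokens[tokens == 2] = 3
tokens[tokens == 1] = 2
tokens[tokens == 0] = 1

tokens = tokens[0].tolist()
att_mask = [1] * min(len(tokens), max_tokenized_len)
tokens = tokens[:max_tokenized_len]
att_mask += [0] * (max_tokenized_len - len(att_mask))
tokens += [1] * (max_tokenized_len - len(tokens))  # pad with the remapped id 1

tokens = torch.tensor(tokens, dtype=torch.int64)
att_mask = torch.tensor(att_mask, dtype=torch.int32)
```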


@@ -69,10 +90,11 @@ def __init__(
        self,
        file_path,
        k_mer=4,
-        stride=4,
+        stride=None,
        max_len=256,
        randomize_offset=False,
        tokenizer="kmer",
+        bpe_path=None,
        tokenize_n_nucleotide=False,
        dataset_format="CANADA-1.5M",
    ):
@@ -85,16 +107,15 @@ def __init__(
        if dataset_format not in ["CANADA-1.5M", "BIOSCAN-5M"]:
            raise NotImplementedError(f"Dataset {dataset_format} not supported.")

-        # Vocabulary
-        base_pairs = "ACGT"
-        self.special_tokens = ["[MASK]", "[UNK]"]  # ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[EOS]", "[UNK]"]
-        UNK_TOKEN = "[UNK]"
-
-        if tokenize_n_nucleotide:
-            # Encode kmers which contain N differently depending on where it is
-            base_pairs += "N"
-
        if tokenizer == "kmer":
+            # Vocabulary
+            base_pairs = "ACGT"
+            self.special_tokens = ["[MASK]", "[UNK]"]  # ["[MASK]", "[CLS]", "[SEP]", "[PAD]", "[EOS]", "[UNK]"]
+            UNK_TOKEN = "[UNK]"
+
+            if tokenize_n_nucleotide:
+                # Encode kmers which contain N differently depending on where it is
+                base_pairs += "N"
            kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=self.k_mer)]

            # Separate between good (idx < 4**k) and bad k-mers (idx > 4**k) for prediction
@@ -116,13 +137,14 @@ def __init__(
            self.tokenizer = KmerTokenizer(
                self.k_mer, self.vocab, stride=self.stride, padding=True, max_len=self.max_len
            )
-        elif tokenizer == "DnaBertBPE":
-            self.tokenizer = DnaBertBPETokenizer(padding=True, max_tokenized_len=self.max_len)
+        elif tokenizer == "bpe":
+            self.tokenizer = BPETokenizer(padding=True, max_tokenized_len=self.max_len, bpe_path=bpe_path)
+            self.vocab_size = self.tokenizer.bpe.vocab_size
        else:
            raise ValueError(f'Tokenizer "{tokenizer}" not recognized.')
        df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",", keep_default_na=False)
        self.barcodes = df["nucleotides"].to_list()
+
        if dataset_format == "CANADA-1.5M":
            self.labels, self.label_set = pd.factorize(df["species_name"], sort=True)
            self.num_labels = len(self.label_set)
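For comparison with the `bpe` branch, the `kmer` branch enumerates every k-length combination of the base alphabet and prepends the special tokens. A condensed, standalone sketch of that vocabulary construction (the good/bad split follows the comment in the hunk above; the exact ordering logic in the real class may differ):

```python
from itertools import product

k_mer = 4
special_tokens = ["[MASK]", "[UNK]"]
base_pairs = "ACGT" + "N"  # include N so k-mers containing it get their own ids

kmers = ["".join(kmer) for kmer in product(base_pairs, repeat=k_mer)]
# Stable sort: clean k-mers (no N) first, so they occupy indices below 4**k_mer.
kmers.sort(key=lambda kmer: "N" in kmer)

vocab = special_tokens + kmers
stoi = {token: idx for idx, token in enumerate(vocab)}
assert all("N" not in kmer for kmer in kmers[: 4**k_mer])
```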
@@ -144,117 +166,42 @@ def __getitem__(self, idx):
        return processed_barcode, label, att_mask


-def single_inference(model, tokenizer, barcode):
-    with torch.no_grad():
-        x, att_mask = tokenizer(barcode)
-
-        x = x.unsqueeze(0).to(model.device)
-        att_mask = att_mask.unsqueeze(0).to(model.device)
-        x = model(x, att_mask).hidden_states[-1]
-        # updated mean pooling to account for the attention mask and padding tokens
-        # sum the embeddings of the tokens (excluding padding tokens)
-        x = (x * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
-        # sum the attention mask (number of tokens in the sequence without considering the padding tokens)
-        sum_mask = att_mask.sum(1, keepdim=True)
-        # calculate the mean embeddings
-        x /= sum_mask  # (batch_size, hidden_size)
-        return x
-
-
-def representations_from_df(df, target_level, model, tokenizer):
+def representations_from_df(df, target_level, model, tokenizer, dataset_name):

    orders = df["order_name"].to_numpy()

-    _label_set, y = np.unique(df[target_level], return_inverse=True)
+    if dataset_name == "CANADA-1.5M":
+        _label_set, y = np.unique(df[target_level], return_inverse=True)
+    elif dataset_name == "BIOSCAN-5M":
+        # _label_set = np.unique(df[target_level])
+        y = df[target_level]
+    else:
+        raise NotImplementedError("Dataset format is not supported. Must be one of CANADA-1.5M or BIOSCAN-5M")

    dna_embeddings = []
-    for barcode in df["nucleotides"]:
-        x = single_inference(model, tokenizer, barcode)
-        dna_embeddings.append(x.cpu().numpy())

+    with torch.no_grad():
+        for barcode in df["nucleotides"]:
+            x, att_mask = tokenizer(barcode)
+
+            x = x.unsqueeze(0).to(model.device)
+            att_mask = att_mask.unsqueeze(0).to(model.device)
+            x = model(x, att_mask).hidden_states[-1]
+            # previous mean pooling
+            # x = x.mean(1)
+            # dna_embeddings.append(x.cpu().numpy())
+
+            # updated mean pooling to account for the attention mask and padding tokens
+            # sum the embeddings of the tokens (excluding padding tokens)
+            sum_embeddings = (x * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
+            # sum the attention mask (number of tokens in the sequence without considering the padding tokens)
+            sum_mask = att_mask.sum(1, keepdim=True)
+            # calculate the mean embeddings
+            mean_embeddings = sum_embeddings / sum_mask  # (batch_size, hidden_size)
+
+            dna_embeddings.append(mean_embeddings.cpu().numpy())

    print(f"There are {len(df)} points in the dataset")
    latent = np.array(dna_embeddings)
    latent = np.squeeze(latent, 1)
    print(latent.shape)
    return latent, y, orders
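The inlined loop keeps the mask-aware mean rather than the earlier plain `x.mean(1)`, so padding positions no longer dilute the pooled embedding. A self-contained numerical sketch of just that pooling step, on dummy tensors rather than real model outputs:

```python
import torch

batch_size, seq_len, hidden_size = 2, 6, 8
x = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for hidden_states[-1]
att_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

# Zero out the padding positions, then sum over the sequence dimension.
sum_embeddings = (x * att_mask.unsqueeze(-1)).sum(1)  # (batch_size, hidden_size)
# Count the real tokens in each sequence.
sum_mask = att_mask.sum(1, keepdim=True)  # (batch_size, 1)
mean_embeddings = sum_embeddings / sum_mask  # (batch_size, hidden_size)

# A plain mean over all positions would differ for row 0, which has padding.
assert not torch.allclose(mean_embeddings[0], x[0].mean(0))
assert torch.allclose(mean_embeddings[1], x[1].mean(0))
```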


-def inference_from_df(df, model, tokenizer):
-
-    assert "processid" in df.columns  # Check that processid column is present in your dataframe
-    assert "nucleotides" in df.columns  # Check that nucleotide column is present in your dataframe
-
-    dna_embeddings = {}
-
-    for _i, row in df.iterrows():
-        barcode = row["nucleotides"]
-        id = row["processid"]
-
-        x = single_inference(model, tokenizer, barcode)
-
-        dna_embeddings[id] = x.cpu().numpy()
-
-    return dna_embeddings


-def check_sequence(header, seq):
-    """
-    Adapted from VAMB: https://github.com/RasmussenLab/vamb
-    Check that there are no invalid characters or bad formatting in the file.
-    Note: the gaps ('-') introduced by alignment are considered valid characters.
-    """
-
-    if len(header) > 0 and (header[0] in (">", "#") or header[0].isspace()):
-        raise ValueError("Bad character in sequence header")
-    if "\t" in header:
-        raise ValueError("tab included in header")
-
-    basemask = bytearray.maketrans(b"acgtuUswkmyrbdhvnSWKMYRBDHV-", b"ACGTTTNNNNNNNNNNNNNNNNNNNNNN")
-
-    masked = seq.translate(basemask, b" \t\n\r")
-    stripped = masked.translate(None, b"ACGTN")
-    if len(stripped) > 0:
-        bad_character = chr(stripped[0])
-        msg = "Invalid DNA byte in sequence {}: '{}'"
-        raise ValueError(msg.format(header, bad_character))
-    return masked
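The removed `check_sequence` does its validation with two byte-level translations: the first canonicalizes soft-masked bases, U, IUPAC ambiguity codes, and alignment gaps to ACGTN while dropping whitespace; the second deletes all valid bytes so anything left over is an offender. A quick illustration on a made-up record:

```python
# 28-char source alphabet: lowercase acgt and u/U map to ACGT/T; every
# ambiguity code (and the alignment gap '-') maps to N.
basemask = bytearray.maketrans(b"acgtuUswkmyrbdhvnSWKMYRBDHV-", b"ACGTTTNNNNNNNNNNNNNNNNNNNNNN")

seq = bytearray(b"acgTU-swN\n")
masked = seq.translate(basemask, b" \t\n\r")  # canonicalize, dropping whitespace
print(masked)  # bytearray(b'ACGTTNNNN')

stripped = masked.translate(None, b"ACGTN")  # delete valid bytes; leftovers are invalid
assert len(stripped) == 0
```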


-def inference_from_fasta(fname, model, tokenizer):
-
-    dna_embeddings = {}
-    lines = []
-    seq_id = ""
-
-    for line in open(fname, "rb"):
-        if line.startswith(b"#"):
-            pass
-
-        elif line.startswith(b">"):
-            if seq_id != "":
-                seq = bytearray().join(lines)
-
-                # Check entry is valid
-                seq = check_sequence(seq_id, seq)
-
-                # Compute embedding
-                x = single_inference(model, tokenizer, seq.decode())
-
-                seq_id = line[1:-1].decode()  # Modify this according to your labels.
-                lines = []
-                dna_embeddings[seq_id] = x.cpu().numpy()
-            seq_id = line[1:-1].decode()
-        else:
-            lines += [line.strip()]
-
-    seq = bytearray().join(lines)
-    seq = check_sequence(seq_id, seq)
-    # Compute embedding
-    x = single_inference(model, tokenizer, seq.decode())
-    dna_embeddings[seq_id] = x.cpu().numpy()
-
-    return dna_embeddings
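For anyone reinstating FASTA support after this removal, the parsing loop above reduces to a small generator. This is a hypothetical restatement (`iter_fasta` is not part of the codebase) that yields one `(id, sequence)` pair per record and skips `#` comment lines, leaving validation and embedding to the caller:

```python
def iter_fasta(fname):
    """Yield (seq_id, raw_sequence) pairs from a FASTA file, skipping '#' comments."""
    seq_id, lines = "", []
    with open(fname, "rb") as handle:
        for line in handle:
            if line.startswith(b"#"):
                continue
            if line.startswith(b">"):
                if seq_id:
                    yield seq_id, bytearray().join(lines)
                seq_id = line[1:-1].decode()
                lines = []
            else:
                lines.append(line.strip())
    if seq_id:
        yield seq_id, bytearray().join(lines)
```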