Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : remove notion of CLS token #11064

Merged
merged 1 commit into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ class Tokenizer:
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
CLS_ID = "tokenizer.ggml.cls_token_id"
MASK_ID = "tokenizer.ggml.mask_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
Expand Down Expand Up @@ -1837,7 +1836,6 @@ def get_type(val: Any) -> GGUFValueType:
KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID
KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
Expand Down
3 changes: 0 additions & 3 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,6 @@ def add_sep_token_id(self, id: int) -> None:
def add_pad_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.PAD_ID, id)

def add_cls_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.CLS_ID, id)

def add_mask_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.MASK_ID, id)

Expand Down
5 changes: 4 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,6 @@ extern "C" {
LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab); // classification
LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
Expand Down Expand Up @@ -973,6 +972,10 @@ extern "C" {
DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");

// CLS is equivalent to BOS
DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification
"use llama_vocab_bos instead");

//
// Tokenization
//
Expand Down
26 changes: 8 additions & 18 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,6 @@ struct llama_vocab::impl {
llama_token special_unk_id = 0;
llama_token special_sep_id = LLAMA_TOKEN_NULL;
llama_token special_pad_id = LLAMA_TOKEN_NULL;
llama_token special_cls_id = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
llama_token special_mask_id = LLAMA_TOKEN_NULL;

llama_token linefeed_id = 13;
Expand Down Expand Up @@ -1352,7 +1351,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_unk_id = LLAMA_TOKEN_NULL;
special_sep_id = LLAMA_TOKEN_NULL;
special_pad_id = LLAMA_TOKEN_NULL;
special_cls_id = LLAMA_TOKEN_NULL;
special_mask_id = LLAMA_TOKEN_NULL;
linefeed_id = LLAMA_TOKEN_NULL;

Expand All @@ -1374,18 +1372,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_unk_id = 0;
special_sep_id = LLAMA_TOKEN_NULL;
special_pad_id = LLAMA_TOKEN_NULL;
special_cls_id = LLAMA_TOKEN_NULL;
special_mask_id = LLAMA_TOKEN_NULL;
} else if (tokenizer_model == "bert") {
type = LLAMA_VOCAB_TYPE_WPM;

// default special tokens
special_bos_id = LLAMA_TOKEN_NULL;
special_bos_id = 101;
special_eos_id = LLAMA_TOKEN_NULL;
special_unk_id = 100;
special_sep_id = 102;
special_pad_id = 0;
special_cls_id = 101;
special_mask_id = 103;
} else if (tokenizer_model == "gpt2") {
type = LLAMA_VOCAB_TYPE_BPE;
Expand Down Expand Up @@ -1420,7 +1416,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_unk_id = LLAMA_TOKEN_NULL;
special_sep_id = LLAMA_TOKEN_NULL;
special_pad_id = LLAMA_TOKEN_NULL;
special_cls_id = LLAMA_TOKEN_NULL;
special_mask_id = LLAMA_TOKEN_NULL;
} else if (tokenizer_model == "t5") {
type = LLAMA_VOCAB_TYPE_UGM;
Expand All @@ -1431,7 +1426,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
special_unk_id = 2;
special_sep_id = LLAMA_TOKEN_NULL;
special_pad_id = 0;
special_cls_id = LLAMA_TOKEN_NULL;
special_mask_id = LLAMA_TOKEN_NULL;

const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
Expand Down Expand Up @@ -1712,7 +1706,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
{ LLM_KV_TOKENIZER_UNK_ID, special_unk_id },
{ LLM_KV_TOKENIZER_SEP_ID, special_sep_id },
{ LLM_KV_TOKENIZER_PAD_ID, special_pad_id },
{ LLM_KV_TOKENIZER_CLS_ID, special_cls_id },
{ LLM_KV_TOKENIZER_MASK_ID, special_mask_id },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
Expand Down Expand Up @@ -2406,8 +2399,8 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
case LLAMA_VOCAB_TYPE_WPM:
{
if (add_special) {
GGML_ASSERT(special_cls_id != LLAMA_TOKEN_NULL);
output.push_back(special_cls_id);
GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
output.push_back(special_bos_id);
}

llm_tokenizer_wpm_session session(vocab);
Expand Down Expand Up @@ -2700,7 +2693,6 @@ void llama_vocab::impl::print_info() const {
if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token[special_unk_id].text.c_str() ); }
if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token[special_sep_id].text.c_str() ); }
if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token[special_pad_id].text.c_str() ); }
if (special_cls_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, special_cls_id, id_to_token[special_cls_id].text.c_str() ); }
if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token[special_mask_id].text.c_str() ); }

if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token[linefeed_id].text.c_str() ); }
Expand Down Expand Up @@ -2834,7 +2826,7 @@ llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
}

llama_token llama_vocab::token_bos() const {
return pimpl->type != LLAMA_VOCAB_TYPE_WPM ? pimpl->special_bos_id : pimpl->special_cls_id;
return pimpl->special_bos_id;
}

llama_token llama_vocab::token_eos() const {
Expand All @@ -2853,10 +2845,6 @@ llama_token llama_vocab::token_unk() const {
return pimpl->special_unk_id;
}

llama_token llama_vocab::token_cls() const {
return pimpl->special_cls_id;
}

llama_token llama_vocab::token_sep() const {
return pimpl->special_sep_id;
}
Expand Down Expand Up @@ -3069,8 +3057,9 @@ llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
return vocab->token_eot();
}

// deprecated
llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
return vocab->token_cls();
return vocab->token_bos();
}

llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
Expand Down Expand Up @@ -3159,7 +3148,8 @@ llama_token llama_token_eot(const struct llama_vocab * vocab) {

// deprecated
llama_token llama_token_cls(const struct llama_vocab * vocab) {
return llama_vocab_cls(vocab);
//return llama_vocab_cls(vocab);
return llama_vocab_bos(vocab); // avoid deprecation warning
}

// deprecated
Expand Down
1 change: 0 additions & 1 deletion src/llama-vocab.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ struct llama_vocab {
llama_token token_eot() const;
llama_token token_eom() const;
llama_token token_unk() const;
llama_token token_cls() const;
llama_token token_sep() const;
llama_token token_nl () const;
llama_token token_pad() const;
Expand Down
Loading