
Add Falcon3 support and fix issue #10875 (#10883)

Merged 7 commits on Dec 22, 2024

Conversation

mokeddembillel (Contributor):

This PR adds Falcon3 support and fixes issue #10875, which was caused by the previous PR #10864 (see #10864 for details).

Details of fixing issue #10875:

The issue is that, when using meta-llama/Llama-3.1-8B-Instruct, the <|begin_of_text|> token is added to every special token when doing token = tokenizer.decode(tokenizer.encode(token)).

The screenshot below shows the tokens before and after token = tokenizer.decode(tokenizer.encode(token)):
[screenshot omitted]

I'm fixing this by passing add_special_tokens=False to tokenizer.encode(). Here is the result after the fix:
[screenshot omitted]

To be extra safe, we apply token = tokenizer.decode(tokenizer.encode(token)) only if len(token) == 1, so that the fix still handles the case where \n is encoded as Ċ.
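
A minimal sketch of the round-trip, assuming a Hugging Face AutoTokenizer (the model name comes from the issue; this is illustrative only, not the convert script itself):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

def fix_added_token(token: str) -> str:
    # Only single-char added tokens (e.g. "Ċ" for "\n") need the encode/decode round-trip;
    # longer special tokens are left untouched, so no BOS token can leak into them.
    if len(token) == 1:
        # add_special_tokens=False prevents '<|begin_of_text|>' from being prepended
        return tok.decode(tok.encode(token, add_special_tokens=False))
    return token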

Generation before the fix:

Prompt: Once upon a time in a land far away,
there was a kingdom ruled by a wise and just king. The kingdom was known for its beauty and prosperity, and the people lived in peace and harmony.ĊĊOne day, a terrible drought struck the land, and the crops began to wither and die. The king, worried about the well-being of his people, called upon his wise council to find a solution. The council, after much deliberation, decided to send a group of brave knights to search for a magical spring that was said to have the power to bring rain to the kingdom.

Generation after the fix:

Prompt: Once upon a time in a land far away,
there was a kingdom ruled by a wise and just king. The kingdom was known for its beauty and prosperity, and the people lived in peace and harmony.

One day, a terrible drought struck the land, and the crops began to wither and die. The king, worried about the well-being of his people, called upon his wise council to find a solution. The council, after much deliberation, decided to send a group of brave knights to search for a magical spring that was said to have the power to bring rain to the kingdom.

@ggerganov @compilade @slaren

github-actions bot added the python (python script changes) label on Dec 18, 2024
@@ -525,6 +525,11 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
            else:
                token: str = reverse_vocab[i]
                if token in added_vocab:
                    # We need to manually encode and decode the added tokens in case special characters
                    # used for `\n` / `\t` have been manually added in the added tokens
                    if len(token) == 1:
Contributor (review suggestion):

Suggested change:
-                    if len(token) == 1:
+                    # To avoid unexpected issues - we make sure to encode single-char tokens
+                    if len(token) == 1:

ggerganov (Owner):

I'm looking at the Falcon tokenizer and I don't see any added tokens that have \n or \t: https://huggingface.co/tiiuae/Falcon3-7B-Instruct/raw/main/tokenizer.json

For which tokens does this change make a difference?

Maybe also add some logs to know when this path is being triggered so we can spot any potential problems with other models.

younesbelkada (Contributor) commented Dec 18, 2024:

Chiming in here! The added token is:

    {
      "id": 12,
      "content": "Ċ",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": false
    }

(\t is id 13.) The only way to convert it properly to \n is to encode/decode it using the tokenizer.
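
For reference, a small sketch of how those added tokens can be inspected with the Hugging Face API (repo name taken from this thread; illustrative only):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tiiuae/Falcon3-7B-Instruct")

# ids 12 and 13 hold the byte-level forms of "\n" and "\t", per the entry above
print(tok.added_tokens_decoder[12])  # AddedToken for "Ċ", normalized=False
print(repr(tok.decode([12])))        # '\n' once run back through the tokenizer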

mokeddembillel (Contributor, Author):

I just added a log message inside the if statement.

                    # used for `\n` / `\t` have been manually added in the added tokens
                    # To avoid unexpected issues - we make sure to encode single-char tokens
                    if len(token) == 1:
                        logger.info("Encode-Decode special characters using AutoTokenizer")
ggerganov (Owner):

I was thinking about comparing the token before and after the encoding and printing the log only if there is a difference.

mokeddembillel (Contributor, Author) commented Dec 18, 2024:

That's a good idea. Done!

INFO:hf-to-gguf:'Ċ' is encoded and decoded back to '\n' using AutoTokenizer
INFO:hf-to-gguf:'ĉ' is encoded and decoded back to '\t' using AutoTokenizer
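
A minimal sketch of that compare-and-log step, assuming the surrounding variables from get_vocab_base (the final diff appears later in this thread):

if len(token) == 1:
    previous_token = token
    token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
    if previous_token != token:
        logger.info(f"{previous_token!r} is encoded and decoded back to {token!r} using AutoTokenizer")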

ggerganov (Owner) left a comment:

Seems OK to me, but I am not sure about the full implications of this change for all other models. Want to wait for some feedback from the community.

The alternative is to find a way to apply this logic only inside the class FalconModel.

@ggerganov ggerganov requested a review from compilade December 18, 2024 08:41
mokeddembillel (Contributor, Author) commented Dec 18, 2024:

Actually, there's no FalconModel class, and our model type is llama, so we can't use that check. The only solution I see is to wait for feedback from the community; if any error related to this comes up, I'll be happy to address and fix it quickly.

compilade (Collaborator) commented Dec 18, 2024:

> but I am not sure about the full implications of this change for all other models.

This can be tested by converting all tokenizers fetched by convert_hf_to_gguf_update.py and comparing the hashes when converted before and after this change (which I can't do right now, but will when I can).

I think what would solve variations of this problem for other models in the future (in another PR) would be one of two things. Either normalize all added tokens that are marked "normalized": false, since added tokens are internally assumed to be pre-normalized; this is the same problem that #8228 attempted to solve, but that fix wasn't general enough (it only normalizes "▁" to " ", which solved the problem for Gemma). Or handle non-normalized added tokens internally by adding a token attribute for them, though this would depend on proper support for token attributes stored in GGUF files, which isn't complete yet: per-token attributes were added in #7685, but they aren't stored in GGUF models, and LLAMA_TOKEN_ATTR_NORMALIZED isn't really handled.
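
A minimal sketch of the first option, assuming the Hugging Face added_tokens_decoder API; the variable names are placeholders, and the actual patch shows up later in this thread:

for i, added in tokenizer.added_tokens_decoder.items():
    token = added.content
    if not added.normalized:
        # round-trip non-normalized added tokens so e.g. "Ċ" becomes "\n"
        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
    tokens[i] = token  # 'tokens' stands in for the vocab list being built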

younesbelkada (Contributor):

Thanks @compilade!
There might be an easier solution: I am about to manually modify the normalized characters in the tokenizer file (since the problem is only for \n and \t, which have been explicitly added as special tokens) and push the normalized tokenizer to all repos. Then we can convert this PR to simply adding the falcon3 pre-tokenizer. What do you think?

compilade (Collaborator):

> There might be an easier solution: I am about to manually modify the normalized characters in the tokenizer file (since the problem is only for \n and \t, which have been explicitly added as special tokens) and push the normalized tokenizer to all repos. Then we can convert this PR to simply adding the falcon3 pre-tokenizer. What do you think?

@younesbelkada

That could also work, as long as it's done correctly. The added tokens are in both tokenizer.json and tokenizer_config.json. If you do this, make sure that it doesn't have unintended consequences.

This is otherwise a nice edge case I think the convert scripts should have handled correctly, so part of me wants to keep the tokenizers the same.

younesbelkada (Contributor):

Perfect, thanks! Will test that out and update here.

younesbelkada (Contributor) commented Dec 18, 2024:

@compilade I just did some tests and I think we can't go with the solution I suggested above, mainly for backward compatibility reasons. Before the manual changes:

(Pdb) tok.encode("ĉ")
[13]

After the fix I suggested:

(Pdb) fixed_tokenizer.encode("ĉ")
[2150, 2237]

For the same token we now get different encodings. As all Falcon3 series models have been trained with that tokenizer, even if the probability that this token appears in text is low, I am afraid it's too risky a breaking change to introduce. I also tried setting "normalized": true on these tokens and converted the model with this PR, and I still get Ċ printed all over the place for line breaks.

Perhaps we can test that existing tokenizers are not affected by this PR, what do you think? Happy to help you on this as well.

xhatz commented Dec 22, 2024:

Hey, I'm still getting this error:

error loading model vocabulary: unknown pre-tokenizer type: 'falcon3'

Is this normal? I've read that a previous llama.cpp version was updated to handle it. Thanks!

mokeddembillel (Contributor, Author):

@xhatz we need to wait until this pull request is merged.

@compilade is there any way we can help to merge this pull request faster? Thanks

xhatz commented Dec 22, 2024:

Oh alright, thanks! I had read it was already merged in b4341, that's why I was confused, haha.

ggerganov (Owner):

> This can be tested by converting all tokenizers fetched by convert_hf_to_gguf_update.py and comparing the hashes when converted before and after this change (which I can't do right now, but will when I can).

Which hashes do you have in mind? The hashes produced by convert_hf_to_gguf_update.py are generated by the original Python tokenizer - it is not affected by this change, so they would be the same.

I reconverted all tokenizers that I have access to:

python3.11 convert_hf_to_gguf.py models/tokenizers/llama-spm/ --outfile models/ggml-vocab-llama-spm.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/llama-bpe/ --outfile models/ggml-vocab-llama-bpe.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/phi-3/ --outfile models/ggml-vocab-phi-3.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/deepseek-llm/ --outfile models/ggml-vocab-deepseek-llm.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/deepseek-coder/ --outfile models/ggml-vocab-deepseek-coder.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/falcon/ --outfile models/ggml-vocab-falcon.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/bert-bge/ --outfile models/ggml-vocab-bert-bge.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/falcon3/ --outfile models/ggml-vocab-falcon3.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/bert-bge-large/ --outfile models/ggml-vocab-bert-bge-large.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/mpt/ --outfile models/ggml-vocab-mpt.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/starcoder/ --outfile models/ggml-vocab-starcoder.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/gpt-2/ --outfile models/ggml-vocab-gpt-2.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/stablelm2/ --outfile models/ggml-vocab-stablelm2.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/refact/ --outfile models/ggml-vocab-refact.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/command-r/ --outfile models/ggml-vocab-command-r.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/qwen2/ --outfile models/ggml-vocab-qwen2.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/olmo/ --outfile models/ggml-vocab-olmo.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/dbrx/ --outfile models/ggml-vocab-dbrx.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/jina-v1-en/ --outfile models/ggml-vocab-jina-v1-en.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/jina-v2-en/ --outfile models/ggml-vocab-jina-v2-en.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/jina-v2-es/ --outfile models/ggml-vocab-jina-v2-es.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/jina-v2-de/ --outfile models/ggml-vocab-jina-v2-de.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/smaug-bpe/ --outfile models/ggml-vocab-smaug-bpe.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/poro-chat/ --outfile models/ggml-vocab-poro-chat.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/jina-v2-code/ --outfile models/ggml-vocab-jina-v2-code.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/viking/ --outfile models/ggml-vocab-viking.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/gemma/ --outfile models/ggml-vocab-gemma.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/gemma-2/ --outfile models/ggml-vocab-gemma-2.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/t5/ --outfile models/ggml-vocab-t5.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/codeshell/ --outfile models/ggml-vocab-codeshell.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/tekken/ --outfile models/ggml-vocab-tekken.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/smollm/ --outfile models/ggml-vocab-smollm.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/bloom/ --outfile models/ggml-vocab-bloom.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/gpt3-finnish/ --outfile models/ggml-vocab-gpt3-finnish.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/phi-2/ --outfile models/ggml-vocab-phi-2.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/minerva-7b/ --outfile models/ggml-vocab-minerva-7b.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/roberta-bpe/ --outfile models/ggml-vocab-roberta-bpe.gguf --vocab-only
python3.11 convert_hf_to_gguf.py models/tokenizers/gigachat/ --outfile models/ggml-vocab-gigachat.gguf --vocab-only

The new log message was triggered for the following tokenizers:

INFO:hf-to-gguf:Loading model: deepseek-llm
INFO:hf-to-gguf:'ø' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ö' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ú' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ÿ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'õ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'÷' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'û' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ý' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'À' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ù' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'Á' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'þ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ü' is encoded and decoded back to '�' using AutoTokenizer

INFO:hf-to-gguf:Loading model: deepseek-coder
INFO:hf-to-gguf:'õ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'÷' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'Á' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ý' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'À' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ÿ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ø' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ú' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'þ' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ü' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ù' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'ö' is encoded and decoded back to '�' using AutoTokenizer
INFO:hf-to-gguf:'û' is encoded and decoded back to '�' using AutoTokenizer

INFO:hf-to-gguf:Loading model: falcon3
INFO:hf-to-gguf:'Ċ' is encoded and decoded back to '\n' using AutoTokenizer
INFO:hf-to-gguf:'ĉ' is encoded and decoded back to '\t' using AutoTokenizer

So it seems Deepseek models will be affected and I cannot tell if the change is good or bad. Any ideas?

Also, there might be more models affected for which I haven't requested access yet: jais, exaone, chameleon.

mokeddembillel (Contributor, Author):

I'm looking into it. Thanks!

compilade (Collaborator) commented Dec 22, 2024:

>> This can be tested by converting all tokenizers fetched by convert_hf_to_gguf_update.py and comparing the hashes when converted before and after this change (which I can't do right now, but will when I can).
>
> Which hashes do you have in mind? The hashes produced by convert_hf_to_gguf_update.py are generated by the original Python tokenizer - it is not affected by this change, so they would be the same.

I meant the hashes of the GGUF vocab-only models made when reconverting the tokenizers with convert_hf_to_gguf.py.

I've done this hash comparison test for the below suggested change.
$ # Comparing after vs before, because it was more convenient
$ sha256sum --check tokenizers-after.SHA256SUMS 
models/ggml-vocab-bert-bge.gguf: OK
models/ggml-vocab-bert-bge-large.gguf: OK
models/ggml-vocab-bloom.gguf: OK
models/ggml-vocab-codeshell.gguf: OK
models/ggml-vocab-command-r.gguf: OK
models/ggml-vocab-dbrx.gguf: OK
models/ggml-vocab-deepseek-coder.gguf: OK
models/ggml-vocab-deepseek-llm.gguf: OK
models/ggml-vocab-falcon3.gguf: FAILED
models/ggml-vocab-falcon.gguf: OK
models/ggml-vocab-gemma-2.gguf: OK
models/ggml-vocab-gemma.gguf: OK
models/ggml-vocab-gigachat.gguf: OK
models/ggml-vocab-gpt-2.gguf: OK
models/ggml-vocab-gpt3-finnish.gguf: OK
models/ggml-vocab-jais.gguf: OK
models/ggml-vocab-jina-v1-en.gguf: OK
models/ggml-vocab-jina-v2-code.gguf: OK
models/ggml-vocab-jina-v2-de.gguf: OK
models/ggml-vocab-jina-v2-en.gguf: OK
models/ggml-vocab-jina-v2-es.gguf: OK
models/ggml-vocab-llama-bpe.gguf: OK
models/ggml-vocab-llama-spm.gguf: OK
models/ggml-vocab-minerva-7b.gguf: OK
models/ggml-vocab-mpt.gguf: OK
models/ggml-vocab-olmo.gguf: OK
models/ggml-vocab-phi-2.gguf: OK
models/ggml-vocab-phi-3.gguf: OK
models/ggml-vocab-poro-chat.gguf: OK
models/ggml-vocab-qwen2.gguf: OK
models/ggml-vocab-refact.gguf: OK
models/ggml-vocab-roberta-bpe.gguf: OK
models/ggml-vocab-smaug-bpe.gguf: OK
models/ggml-vocab-smollm.gguf: OK
models/ggml-vocab-stablelm2.gguf: OK
models/ggml-vocab-starcoder.gguf: OK
models/ggml-vocab-t5.gguf: OK
models/ggml-vocab-tekken.gguf: OK
models/ggml-vocab-viking.gguf: OK
sha256sum: WARNING: 1 computed checksum did NOT match

$ cat tokenizers-after.SHA256SUMS
1c0527faffc570709debe95200853c7f603be409786fda3d4f0aa6bc3a0bd9e2  models/ggml-vocab-bert-bge.gguf
2be2b6512d25a7ac34c44c4de6e0594ccdd23db0964e3a2792dd07f72f5001b3  models/ggml-vocab-bert-bge-large.gguf
91c99e3e5a9816e8e7494e290deecac4adec7fa29548f2ce5f969a720ee222cb  models/ggml-vocab-bloom.gguf
b4755bac7dde130e0b537db32dc2e47abc8acafd7ad55330f822728787280f94  models/ggml-vocab-codeshell.gguf
e7961be1843e319ea1f50d633a708f6e7b87358aa75ee51d1ec817c2bdc64b1b  models/ggml-vocab-command-r.gguf
ae85dca38568f521ef166d3fd7c36bf42c310520e1b14cdbbe002fc63f55fece  models/ggml-vocab-dbrx.gguf
18eda84b19d8e26cd2628dc20a6e272dd228d8201218ef33ff401ec92b97a1ee  models/ggml-vocab-deepseek-coder.gguf
aab19f2513d2a9a74ea6b7d98b0fa4292bc18d36a90d82f392d485b727704688  models/ggml-vocab-deepseek-llm.gguf
d7004d54f98e1be46b16535d782aa61314bc24e041a90b34421c324e85b8d062  models/ggml-vocab-falcon3.gguf
2ce892822f0aed14c47ef86c7881600002edff1bc1ecaf85e49a468ebeb45296  models/ggml-vocab-falcon.gguf
894c712b2ab13280bd7d98530ec1ffb90b82cea840bba2156c01e3bf1e8f690e  models/ggml-vocab-gemma-2.gguf
384f6fc4138182c0b0fd7e555b217286ee03cab7c0683c58894c2847dd082f04  models/ggml-vocab-gemma.gguf
8726d0a2bf0f3b3f11493167daa4fc4bcdf2fbbc9453614ac01dd41b121991e0  models/ggml-vocab-gigachat.gguf
458a70ee49a538a345687b3741a3be97065da250ca809528535cd41b1521837a  models/ggml-vocab-gpt-2.gguf
f4467c52a30fd90c3af8427e54569860e6ba0d9cf2a27b0d891024689da95975  models/ggml-vocab-gpt3-finnish.gguf
0a29abf5ad3fc3e9610fb7e0b431635ac4be0e7932f35565b55bba0c36f5b1d0  models/ggml-vocab-jais.gguf
453bf7aba167509a2bd4540cce634c3295839623da05212ea6f5a0129c403280  models/ggml-vocab-jina-v1-en.gguf
a42605bfb68ef166460c392cc158dc0999bbae2a19a18038d798b445ea6f173d  models/ggml-vocab-jina-v2-code.gguf
498b52df737a57a46e020fe949d92eaf81bbadf05f5bce85c7de1e76feed317c  models/ggml-vocab-jina-v2-de.gguf
813424b56165f02454e4dc253fc8fd156cf901c50e181283964e35494fe8efc2  models/ggml-vocab-jina-v2-en.gguf
dcdeda4eda3c26aac31f12035303cb42197a97ed7a16354f2525008ee042121e  models/ggml-vocab-jina-v2-es.gguf
696362e92fe6d64608c191f7ba58a64e84ab89824b1df2ecd72afe0ed4a5495e  models/ggml-vocab-llama-bpe.gguf
277e4666dcaa27ce5cf037468c8249313b3aa4fbb60ded553759f2aa8ebd183d  models/ggml-vocab-llama-spm.gguf
cf222edd862bdc98f96bd8bdd5193d866d27e8c907beb59b47d11c0e5dd54ca9  models/ggml-vocab-minerva-7b.gguf
f6d3ba0807aeebda44115be07b0f3fd428e799a3cc7c540aec8e78afda2927c9  models/ggml-vocab-mpt.gguf
e24719416b0b56a3c7113aefd4f9a247fd8f97122de1b33002ad80854b150123  models/ggml-vocab-olmo.gguf
a8d83c8e39ea6422a5afc4e46d7e38d69ea2b55073164ea5b11bdc52cfeabe1a  models/ggml-vocab-phi-2.gguf
61d4c3e859e7dc3c147be882675f1e1e75f8f5e656a1166d944919abee81d2f7  models/ggml-vocab-phi-3.gguf
c9de0ca0375fca5407039efa54b6729a42c3d73d0396f0ec4f26a933b46271c4  models/ggml-vocab-poro-chat.gguf
fcea91caaa7db69a7dba2e0e492a8211f7179f1d48860d6ab8df763e2e14e9e3  models/ggml-vocab-qwen2.gguf
4fe61fdab2b3dc7dda164c8b96f7065be0bff59ea8de2162d79485abf327e984  models/ggml-vocab-refact.gguf
41a36dcf7a625af366c73773d09b542cce776f4feae0c9915172e77b6dd15e20  models/ggml-vocab-roberta-bpe.gguf
49b329ee6ed85a1118dc32f64d7327f8ea44faf897dd4815521a47d0768481f3  models/ggml-vocab-smaug-bpe.gguf
5ddf2bf51313df873df36de94f288535ba727c1bbe77a54a9eea083b66d36afc  models/ggml-vocab-smollm.gguf
db0db920f3d7680d6521d0bd9e5013ab18ababadfb5ecc44049623b2ddc1c30e  models/ggml-vocab-stablelm2.gguf
3ac97c89ff19b7fa72c47af756b8126d557956ff73759fc15b78ef834d7208a1  models/ggml-vocab-starcoder.gguf
2538d6b3c16f2df514e3e6558cdc446e71e7436f46b128b3de98fe7b4918612f  models/ggml-vocab-t5.gguf
28c153c354066ee9bde889a1c50ff3cb9baf3f8319028f318d4bb048c182304f  models/ggml-vocab-tekken.gguf
0907fbd3cc45102604ee941c9b5bb1ed31aa267aab4507d7f77ba1f51702c9a8  models/ggml-vocab-viking.gguf

$ sha256sum models/ggml-vocab-falcon3.gguf # before the change
1c4d5a81280c46000ef5ea8d449fc58c51659f57a8128886eec9b437cb84a5b1  models/ggml-vocab-falcon3.gguf

And only falcon3 doesn't match, which is expected. (NOTE: I've tested the change first, so the SHA256SUMS file contains the hashes after the change)


> I think what would solve variations of this problem for other models in the future (for another PR) would be to either normalize all added tokens which are marked "normalized": false,

I've implemented this suggestion (which turned out to be a single line of code changed), and I think it works correctly.

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 06e3016c..894c3e4d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -525,10 +525,9 @@ class Model:
             else:
                 token: str = reverse_vocab[i]
                 if token in added_vocab:
-                    # We need to manually encode and decode the added tokens in case special characters
-                    # used for `\n` / `\t` have been manually added in the added tokens
-                    # To avoid unexpected issues - we make sure to encode single-char tokens
-                    if len(token) == 1:
+                    # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
+                    # To avoid unexpected issues - we make sure to normalize non-normalized tokens
+                    if not tokenizer.added_tokens_decoder[i].normalized:
                         previous_token = token
                         token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
                         if previous_token != token:
@@ -537,6 +536,8 @@ class Model:
                     if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
                         toktypes.append(gguf.TokenType.CONTROL)
                     else:
+                        # NOTE: this was added for Gemma.
+                        # Encoding and decoding the tokens above isn't sufficient for this case.
                         token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
                         toktypes.append(gguf.TokenType.USER_DEFINED)
                 else:

The deepseek models are not affected anymore, since they properly mark their tokens as normalized.
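
As a hedged sketch, this can be checked against the local tokenizer directory used in the conversion commands above:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("models/tokenizers/deepseek-llm/")

for tok_id, added in tok.added_tokens_decoder.items():
    if len(added.content) == 1:
        # for deepseek these single-char added tokens report normalized=True,
        # so the new check skips the encode/decode round-trip for them
        print(tok_id, repr(added.content), "normalized =", added.normalized)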

ggerganov (Owner):

Nice. @mokeddembillel After applying the patch by @compilade we can merge.

mokeddembillel (Contributor, Author):

Thanks a lot for the help @compilade and @ggerganov. Really appreciated🙏🏼

ggerganov (Owner):

No problem - do you want me to apply the patch, or will you do it?

mokeddembillel (Contributor, Author):

Oh yes for sure. I will do it as soon as I get home. Thanks 🙏🏼

src/llama.cpp (Outdated), comment on lines 6478 to 6482:
        } else if (
                tokenizer_pre == "falcon3") {
            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
            vocab.tokenizer_ignore_merges = true;
            vocab.tokenizer_add_bos = true;
ggerganov (Owner):

Let's simplify this by moving the check to the LLAMA3 branch above:

                    tokenizer_pre == "llama3"    ||
                    tokenizer_pre == "llama-v3"  ||
                    tokenizer_pre == "llama-bpe" ||
                    tokenizer_pre == "falcon3") {

mokeddembillel (Contributor, Author):

Done.

ggerganov merged commit 7ae33a6 into ggerganov:master on Dec 22, 2024 (51 checks passed).