Support single unified BPE tokenizer for Canary2
Signed-off-by: Piotr Zelasko <[email protected]>
pzelasko committed Jan 10, 2025
1 parent d8f4dc7 commit bff9690
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nemo/collections/common/prompts/canary2.py
@@ -26,6 +26,7 @@
     CANARY_BOS,
     CANARY_EOS,
     CANARY_SPECIAL_TOKENIZER,
+    CanaryTokenizer,
 )


@@ -196,8 +197,13 @@ def canary2(cut: Cut, prompt: Canary2PromptFormatter) -> dict[str, torch.Tensor]
         ),
     )
     ans = prompt.encode_dialog(turns)
+    if isinstance(prompt.tokenizer, CanaryTokenizer):
+        eos = prompt.tokenizer.eos
+    else:  # SPE
+        eos = prompt.tokenizer.token_to_id(CANARY_EOS)
+    assert eos > -1, f"Invalid tokenizer: tokenizer.token_to_id('{CANARY_EOS}') returned {eos}"
     assert (
-        ans["answer_ids"][-1].item() == prompt.tokenizer.eos
+        ans["answer_ids"][-1].item() == eos
     ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}"
     ans["answer_ids"] = ans["answer_ids"][:-1]  # Strip Canary's EOS
     return ans
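
The EOS lookup added in this hunk can be exercised in isolation. The sketch below is a hypothetical stand-in, not NeMo code: `FakeCanaryTokenizer`, `FakeSpeTokenizer`, `resolve_eos_id`, `strip_eos`, and the placeholder value assigned to `CANARY_EOS` are all assumptions made for illustration, and plain Python lists replace the tensors used in the real function.

```python
# Placeholder value for illustration; the real constant is imported from NeMo.
CANARY_EOS = "<|endoftext|>"


class FakeCanaryTokenizer:
    """Hypothetical stand-in for a tokenizer that exposes `eos` directly."""

    eos = 3


class FakeSpeTokenizer:
    """Hypothetical stand-in for a single unified BPE tokenizer."""

    def __init__(self, vocab):
        self._vocab = dict(vocab)

    def token_to_id(self, token):
        # Assumption: unknown tokens map to a negative ID, which is rejected below.
        return self._vocab.get(token, -1)


def resolve_eos_id(tokenizer):
    """Return the EOS ID, mirroring the two branches added in the commit."""
    if isinstance(tokenizer, FakeCanaryTokenizer):
        eos = tokenizer.eos
    else:  # unified BPE tokenizer: look the EOS token up in the vocabulary
        eos = tokenizer.token_to_id(CANARY_EOS)
    if eos < 0:
        raise ValueError(f"Invalid tokenizer: token_to_id({CANARY_EOS!r}) returned {eos}")
    return eos


def strip_eos(answer_ids, tokenizer):
    """Check that the sequence ends with EOS and return it without that token."""
    eos = resolve_eos_id(tokenizer)
    if not answer_ids or answer_ids[-1] != eos:
        raise ValueError(f"Expected the last token to be EOS ({eos}), got {answer_ids}")
    return answer_ids[:-1]
```

The sketch raises `ValueError` instead of using `assert` so that the checks still run under `python -O`; the commit itself uses assertions.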
