From bff9690c2a58192e6e167f978b040ed06d4b2e94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 10 Jan 2025 08:03:34 -0800 Subject: [PATCH] Support single unified BPE tokenizer for Canary2 Signed-off-by: Piotr Zelasko --- nemo/collections/common/prompts/canary2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nemo/collections/common/prompts/canary2.py b/nemo/collections/common/prompts/canary2.py index 3aed7a3bfa10..2aa657d294cc 100644 --- a/nemo/collections/common/prompts/canary2.py +++ b/nemo/collections/common/prompts/canary2.py @@ -26,6 +26,7 @@ CANARY_BOS, CANARY_EOS, CANARY_SPECIAL_TOKENIZER, + CanaryTokenizer, ) @@ -196,8 +197,13 @@ def canary2(cut: Cut, prompt: Canary2PromptFormatter) -> dict[str, torch.Tensor] ), ) ans = prompt.encode_dialog(turns) + if isinstance(prompt.tokenizer, CanaryTokenizer): + eos = prompt.tokenizer.eos + else: # SPE + eos = prompt.tokenizer.token_to_id(CANARY_EOS) + assert eos > -1, "Invalid tokenizer: tokenizer.token_to_id('{CANARY_EOS}') returned {eos}" assert ( - ans["answer_ids"][-1].item() == prompt.tokenizer.eos + ans["answer_ids"][-1].item() == eos ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}" ans["answer_ids"] = ans["answer_ids"][:-1] # Strip Canary's EOS return ans