Skip to content

Commit

Permalink
support stop tokens in gen and support ppl
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Apr 30, 2024
1 parent e81b110 commit b53ef44
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
1 change: 0 additions & 1 deletion wenet/LLM/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def forward(
target,
ignore_label=self.ignore_id)

# TODO: ppl
return {
"loss": loss,
"ppl": torch.exp(loss.detach()),
Expand Down
36 changes: 36 additions & 0 deletions wenet/LLM/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class Template:
# one turn :{system_format}{user_format}{assistant_format}
# multi turns:
# {system_format}{user_format}{assistant_format}{user_format}{assistant_format}...
system: Optional[str]
user: str
assistant: str

bos: str
eos: str


gemma = Template(
'',
'<start_of_turn>user\n{content}<end_of_turn>\n<start_of_turn>model\n',
'{content}<end_of_turn>\n',
'<bos>',
'<eos>',
)

llama3 = Template(
'<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>',
'<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n',
'{content}<|eot_id|>',
'<|begin_of_text|>',
'<|end_of_text|>',
)
WENET_LLM_Template = {
"gemma": gemma,
'llama3': llama3,
}

0 comments on commit b53ef44

Please sign in to comment.