rope base float -> int
RdoubleA committed Nov 8, 2024
1 parent a36dde5 commit 4c37a9d
Showing 3 changed files with 37 additions and 21 deletions.
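
The commit swaps the default RoPE base in the Llama3 builders from the float literal 500000.0 to the integer 500_000. A minimal sketch below, using the standard rotary-embedding frequency formula (theta_i = base ** (-2i / dim)) rather than torchtune's internal implementation, illustrates that an integer and a float base of the same value produce identical frequencies, so the change is a type cleanup rather than a numerical one.

import torch

def rope_frequencies(dim: int, base) -> torch.Tensor:
    # One inverse frequency per pair of embedding dimensions: base ** (-2i / dim).
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Integer and float forms of the base yield the same rotary frequencies.
assert torch.allclose(rope_frequencies(128, 500_000), rope_frequencies(128, 500000.0))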
4 changes: 2 additions & 2 deletions torchtune/models/llama3/_component_builders.py
@@ -163,7 +163,7 @@ def lora_llama3(
intermediate_dim: Optional[int] = None,
attn_dropout: float = 0.0,
norm_eps: float = 1e-5,
-rope_base: float = 500000.0,
+rope_base: int = 500_000,
# LoRA args
lora_rank: int,
lora_alpha: float,
@@ -297,7 +297,7 @@ def lora_llama3_self_attention(
num_kv_heads: int,
max_seq_len: int,
attn_dropout: float = 0.0,
-rope_base: float = 500000.0,
+rope_base: int = 500_000,
# LoRA args
lora_rank: int,
lora_alpha: float,
46 changes: 31 additions & 15 deletions torchtune/models/llama3/_model_builders.py
@@ -3,17 +3,17 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import List, Optional
from functools import partial
+from typing import List, Optional

+from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType

from torchtune.models.llama3._component_builders import llama3, lora_llama3
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer

from torchtune.modules import TransformerDecoder
-from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
-from torchtune.data._prompt_templates import _TemplateType
-from torchtune.data._prompt_templates import _get_prompt_template


"""
@@ -40,7 +40,7 @@ def llama3_8b() -> TransformerDecoder:
intermediate_dim=14336,
attn_dropout=0.0,
norm_eps=1e-5,
-rope_base=500000.0,
+rope_base=500_000,
)


@@ -61,32 +61,48 @@ def llama3_70b() -> TransformerDecoder:
intermediate_dim=28672,
attn_dropout=0.0,
norm_eps=1e-5,
-rope_base=500000.0,
+rope_base=500_000,
)


-def llama3_tokenizer(path: str, special_tokens_path: Optional[str] = None, max_seq_len: Optional[int] = None, prompt_template: Optional[_TemplateType] = None) -> Llama3Tokenizer:
+def llama3_tokenizer(
+    path: str,
+    special_tokens_path: Optional[str] = None,
+    max_seq_len: Optional[int] = None,
+    prompt_template: Optional[_TemplateType] = None,
+) -> Llama3Tokenizer:
"""
Tokenizer for Llama3.
Args:
path (str): path to the tokenizer
special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face
model files that contains all registered special tokens, or a local json file
structured similarly. Default is None to use the canonical Llama3 special tokens.
max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages,
after which the input will be truncated. Default is None.
prompt_template (Optional[_TemplateType]): optional specified prompt template.
If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
prepend/append tags.
Returns:
Llama3Tokenizer: Instantiation of the Llama3 tokenizer
"""
-special_tokens = parse_hf_tokenizer_json(special_tokens_path) if special_tokens_path is not None else None
-template = _get_prompt_template(prompt_template) if prompt_template is not None else None
-return Llama3Tokenizer(path=path, special_tokens=special_tokens, max_seq_len=max_seq_len, prompt_template=template)
+special_tokens = (
+    parse_hf_tokenizer_json(special_tokens_path)
+    if special_tokens_path is not None
+    else None
+)
+template = (
+    _get_prompt_template(prompt_template) if prompt_template is not None else None
+)
+return Llama3Tokenizer(
+    path=path,
+    special_tokens=special_tokens,
+    max_seq_len=max_seq_len,
+    prompt_template=template,
+)


def lora_llama3_8b(
@@ -137,7 +153,7 @@ def lora_llama3_8b(
intermediate_dim=14336,
attn_dropout=0.0,
norm_eps=1e-5,
-rope_base=500000.0,
+rope_base=500_000,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
@@ -194,7 +210,7 @@ def lora_llama3_70b(
intermediate_dim=28672,
attn_dropout=0.0,
norm_eps=1e-5,
-rope_base=500000.0,
+rope_base=500_000,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
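For reference, a hypothetical usage sketch of the reformatted llama3_tokenizer builder above (not part of the commit). The tokenizer path is a placeholder, and the encode call assumes the tokenizer's usual encode(text, add_bos, add_eos) interface.

from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer(
    path="/tmp/Meta-Llama-3-8B/original/tokenizer.model",  # placeholder path
    max_seq_len=8192,
)
token_ids = tokenizer.encode("hello torchtune", add_bos=True, add_eos=True)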
8 changes: 4 additions & 4 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -111,7 +111,7 @@ def llama3_2_vision_11b(
embed_dim=4096,
max_seq_len=131_072,
encoder_max_seq_len=128_080, # 20*6404
-rope_base=500000.0,
+rope_base=500_000,
intermediate_dim=14336,
)
return DeepFusionModel(
@@ -209,7 +209,7 @@ def lora_llama3_2_vision_11b(
embed_dim=4096,
max_seq_len=131_072,
encoder_max_seq_len=128_080, # 20*6404
-rope_base=500000.0,
+rope_base=500_000,
intermediate_dim=14336,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
@@ -266,7 +266,7 @@ def llama3_2_vision_90b(
embed_dim=8192,
max_seq_len=131_072,
encoder_max_seq_len=128_080, # 20*6404
-rope_base=500000.0,
+rope_base=500_000,
intermediate_dim=28672,
)
return DeepFusionModel(
@@ -364,7 +364,7 @@ def lora_llama3_2_vision_90b(
embed_dim=8192,
max_seq_len=131_072,
encoder_max_seq_len=128_080, # 20*6404
-rope_base=500000.0,
+rope_base=500_000,
intermediate_dim=28672,
lora_rank=lora_rank,
lora_alpha=lora_alpha,
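Similarly, a hypothetical sketch (not from the commit) instantiating one of the touched builders; per the diffs above, llama3_8b takes no arguments and now passes rope_base=500_000 internally. Note that this builds the full 8B decoder in memory.

from torchtune.models.llama3 import llama3_8b

model = llama3_8b()  # TransformerDecoder configured with rope_base=500_000
print(sum(p.numel() for p in model.parameters()))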
