diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py
index b6439b2eb2..d3ddd3cedd 100644
--- a/torchtune/models/llama3_1/_model_builders.py
+++ b/torchtune/models/llama3_1/_model_builders.py
@@ -85,6 +85,7 @@ def lora_llama3_1_8b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 131072,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -125,7 +126,7 @@ def lora_llama3_1_8b(
         num_heads=32,
         num_kv_heads=8,
         embed_dim=4096,
-        max_seq_len=131072,
+        max_seq_len=max_seq_len,
         intermediate_dim=14336,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -142,6 +143,7 @@ def lora_llama3_1_70b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 131072,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -182,7 +184,7 @@ def lora_llama3_1_70b(
         num_heads=64,
         num_kv_heads=8,
         embed_dim=8192,
-        max_seq_len=131072,
+        max_seq_len=max_seq_len,
         intermediate_dim=28672,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -199,6 +201,7 @@ def lora_llama3_1_405b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 8192,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -236,7 +239,7 @@ def lora_llama3_1_405b(
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=max_seq_len,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
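
For context, the diff threads a caller-supplied max_seq_len through the LoRA Llama 3.1 builders instead of hard-coding it in the underlying lora_llama3_1 call, with the defaults preserving the previous behavior (131072 for 8B/70B, 8192 for 405B). A minimal usage sketch under assumptions: the builder is re-exported from torchtune.models.llama3_1 (only the private _model_builders.py path appears in the diff), and the lora_attn_modules values and the 2048 context length below are illustrative, not taken from this change.

from torchtune.models.llama3_1 import lora_llama3_1_8b

# Build an 8B LoRA model with a shorter context window than the default
# 131072; max_seq_len is the new parameter introduced by this diff.
model = lora_llama3_1_8b(
    lora_attn_modules=["q_proj", "v_proj"],  # example LoRA target modules
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    max_seq_len=2048,  # example value; previously fixed inside the builder
    lora_rank=8,
    lora_alpha=16,
)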