
Commit

LoRA Llama 3.1 builders can accept a specified max_seq_len
akashc1 committed Dec 23, 2024
1 parent aa8f365 commit d2f581c
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torchtune/models/llama3_1/_model_builders.py
@@ -85,6 +85,7 @@ def lora_llama3_1_8b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 131072,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -125,7 +126,7 @@ def lora_llama3_1_8b(
         num_heads=32,
         num_kv_heads=8,
         embed_dim=4096,
-        max_seq_len=131072,
+        max_seq_len=max_seq_len,
         intermediate_dim=14336,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -142,6 +143,7 @@ def lora_llama3_1_70b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 131072,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -182,7 +184,7 @@ def lora_llama3_1_70b(
         num_heads=64,
         num_kv_heads=8,
         embed_dim=8192,
-        max_seq_len=131072,
+        max_seq_len=max_seq_len,
         intermediate_dim=28672,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -199,6 +201,7 @@ def lora_llama3_1_405b(
     lora_attn_modules: List[LORA_ATTN_MODULES],
     apply_lora_to_mlp: bool = False,
     apply_lora_to_output: bool = False,
+    max_seq_len: int = 8192,
     lora_rank: int = 8,
     lora_alpha: float = 16,
     lora_dropout: float = 0.0,
@@ -236,7 +239,7 @@ def lora_llama3_1_405b(
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=max_seq_len,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
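For reference, a minimal sketch of how one of the updated builders can be called with a custom sequence length after this change. The import path is assumed from the file shown above (the builder is expected to be re-exported from torchtune.models.llama3_1), and the lora_attn_modules value and the 8192-token override are illustrative choices, not values from the commit.

# Illustrative usage sketch; argument values are assumptions, not part of the commit.
from torchtune.models.llama3_1 import lora_llama3_1_8b

# Build a LoRA-wrapped Llama 3.1 8B with a shorter context window than the
# 131072-token default now exposed via the builder's max_seq_len parameter.
model = lora_llama3_1_8b(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    max_seq_len=8192,  # previously hard-coded to 131072 inside the builder
    lora_rank=8,
    lora_alpha=16,
)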
