
Commit

Add stub register_objects for models
cbalioglu committed Sep 26, 2024
1 parent f2fc9b8 commit 9e8db09
Showing 17 changed files with 417 additions and 393 deletions.
7 changes: 6 additions & 1 deletion src/fairseq2/models/llama/__init__.py
@@ -23,4 +23,9 @@

# isort: split

-import fairseq2.models.llama.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.llama.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
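Note: the replaced module import means architectures are no longer registered as an import side effect; callers now invoke the hook explicitly. A minimal sketch of how these hooks could be wired together at startup — the register_models helper and the import aliases are illustrative only, not part of this commit:

from fairseq2.dependency import DependencyContainer

from fairseq2.models.llama import register_objects as register_llama_objects
from fairseq2.models.mistral import register_objects as register_mistral_objects


def register_models(container: DependencyContainer) -> None:
    # Each model package exposes the same hook. Today the stubs only call
    # register_archs(); the container argument presumably leaves room for
    # registering container-managed objects later.
    register_llama_objects(container)
    register_mistral_objects(container)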
171 changes: 81 additions & 90 deletions src/fairseq2/models/llama/archs.py
@@ -10,127 +10,118 @@
from fairseq2.models.llama.factory import LLaMAConfig, llama_arch


-@llama_arch("7b")
-def _7b() -> LLaMAConfig:
-    return LLaMAConfig()
-
-
-@llama_arch("13b")
-def _13b() -> LLaMAConfig:
-    config = _7b()
-
-    config.model_dim = 5120
-    config.num_attn_heads = 40
-    config.num_key_value_heads = 40
-    config.ffn_inner_dim = 5120 * 4
-
-    return config
-
-
-@llama_arch("33b")
-def _33b() -> LLaMAConfig:
-    config = _7b()
-
-    config.model_dim = 6656
-    config.num_layers = 60
-    config.num_attn_heads = 52
-    config.num_key_value_heads = 52
-    config.ffn_inner_dim = 6656 * 4
-
-    return config
-
-
-@llama_arch("65b")
-def _65b() -> LLaMAConfig:
-    config = _7b()
-
-    config.model_dim = 8192
-    config.num_layers = 80
-    config.num_attn_heads = 64
-    config.num_key_value_heads = 64
-    config.ffn_inner_dim = 8192 * 4
-
-    return config
-
-
-@llama_arch("llama2_7b")
-def _llama2_7b() -> LLaMAConfig:
-    config = _7b()
-
-    config.max_seq_len = 4096
-
-    return config
-
-
-@llama_arch("llama2_13b")
-def _llama2_13b() -> LLaMAConfig:
-    config = _13b()
-
-    config.max_seq_len = 4096
-
-    return config
-
-
-@llama_arch("llama2_70b")
-def _llama2_70b() -> LLaMAConfig:
-    config = _65b()
-
-    config.max_seq_len = 4096
-    config.num_key_value_heads = 8
-    config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in LLaMA 2
-    config.ffn_inner_dim_to_multiple = 4096
-
-    return config
-
-
-@llama_arch("llama3_8b")
-def _llama3_8b() -> LLaMAConfig:
-    config = _llama2_7b()
-
-    config.max_seq_len = 8192
-
-    config.vocab_info = VocabularyInfo(
-        size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
-    )
-
-    config.num_key_value_heads = 8
-    config.ffn_inner_dim = int(4096 * 4 * 1.3)
-    config.ffn_inner_dim_to_multiple = 1024
-    config.rope_theta = 500_000.0
-
-    return config
-
-
-@llama_arch("llama3_70b")
-def _llama3_70b() -> LLaMAConfig:
-    config = _llama2_70b()
-
-    config.max_seq_len = 8192
-
-    config.vocab_info = VocabularyInfo(
-        size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
-    )
-
-    config.rope_theta = 500_000.0
-
-    return config
-
-
-@llama_arch("llama3_1_8b")
-def _llama3_1_8b() -> LLaMAConfig:
-    config = _llama3_8b()
-
-    config.max_seq_len = 131_072
-    config.use_scaled_rope = True
-
-    return config
-
-
-@llama_arch("llama3_1_70b")
-def _llama3_1_70b() -> LLaMAConfig:
-    config = _llama3_70b()
-
-    config.max_seq_len = 131_072
-    config.use_scaled_rope = True
-
-    return config
+def register_archs() -> None:
+    @llama_arch("7b")
+    def _7b() -> LLaMAConfig:
+        return LLaMAConfig()
+
+    @llama_arch("13b")
+    def _13b() -> LLaMAConfig:
+        config = _7b()
+
+        config.model_dim = 5120
+        config.num_attn_heads = 40
+        config.num_key_value_heads = 40
+        config.ffn_inner_dim = 5120 * 4
+
+        return config
+
+    @llama_arch("33b")
+    def _33b() -> LLaMAConfig:
+        config = _7b()
+
+        config.model_dim = 6656
+        config.num_layers = 60
+        config.num_attn_heads = 52
+        config.num_key_value_heads = 52
+        config.ffn_inner_dim = 6656 * 4
+
+        return config
+
+    @llama_arch("65b")
+    def _65b() -> LLaMAConfig:
+        config = _7b()
+
+        config.model_dim = 8192
+        config.num_layers = 80
+        config.num_attn_heads = 64
+        config.num_key_value_heads = 64
+        config.ffn_inner_dim = 8192 * 4
+
+        return config
+
+    @llama_arch("llama2_7b")
+    def _llama2_7b() -> LLaMAConfig:
+        config = _7b()
+
+        config.max_seq_len = 4096
+
+        return config
+
+    @llama_arch("llama2_13b")
+    def _llama2_13b() -> LLaMAConfig:
+        config = _13b()
+
+        config.max_seq_len = 4096
+
+        return config
+
+    @llama_arch("llama2_70b")
+    def _llama2_70b() -> LLaMAConfig:
+        config = _65b()
+
+        config.max_seq_len = 4096
+        config.num_key_value_heads = 8
+        config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in LLaMA 2
+        config.ffn_inner_dim_to_multiple = 4096
+
+        return config
+
+    @llama_arch("llama3_8b")
+    def _llama3_8b() -> LLaMAConfig:
+        config = _llama2_7b()
+
+        config.max_seq_len = 8192
+
+        config.vocab_info = VocabularyInfo(
+            size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
+        )
+
+        config.num_key_value_heads = 8
+        config.ffn_inner_dim = int(4096 * 4 * 1.3)
+        config.ffn_inner_dim_to_multiple = 1024
+        config.rope_theta = 500_000.0
+
+        return config
+
+    @llama_arch("llama3_70b")
+    def _llama3_70b() -> LLaMAConfig:
+        config = _llama2_70b()
+
+        config.max_seq_len = 8192
+
+        config.vocab_info = VocabularyInfo(
+            size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
+        )
+
+        config.rope_theta = 500_000.0
+
+        return config
+
+    @llama_arch("llama3_1_8b")
+    def _llama3_1_8b() -> LLaMAConfig:
+        config = _llama3_8b()
+
+        config.max_seq_len = 131_072
+        config.use_scaled_rope = True
+
+        return config
+
+    @llama_arch("llama3_1_70b")
+    def _llama3_1_70b() -> LLaMAConfig:
+        config = _llama3_70b()
+
+        config.max_seq_len = 131_072
+        config.use_scaled_rope = True
+
+        return config
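Because registration now happens only when register_archs() runs, a consumer must trigger it (directly, or through register_objects) before looking a preset up. A hedged usage sketch, assuming the ConfigRegistry-style llama_archs and create_llama_model APIs in fairseq2.models.llama.factory:

from fairseq2.models.llama import register_archs
from fairseq2.models.llama.factory import create_llama_model, llama_archs  # assumed API

# With this commit, importing the package alone no longer populates the registry.
register_archs()

config = llama_archs.get("llama3_1_8b")

# Presets are plain mutable configs, so they can be adjusted before building.
config.max_seq_len = 8192

model = create_llama_model(config)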
7 changes: 6 additions & 1 deletion src/fairseq2/models/mistral/__init__.py
@@ -21,4 +21,9 @@

# isort: split

-import fairseq2.models.mistral.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.mistral.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
7 changes: 4 additions & 3 deletions src/fairseq2/models/mistral/archs.py
@@ -9,6 +9,7 @@
from fairseq2.models.mistral.factory import MistralConfig, mistral_arch


-@mistral_arch("7b")
-def _7b() -> MistralConfig:
-    return MistralConfig()
+def register_archs() -> None:
+    @mistral_arch("7b")
+    def _7b() -> MistralConfig:
+        return MistralConfig()
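The decorator call is what performs the registration, so nesting the decorated function inside register_archs() defers registration from import time to an explicit call. A self-contained toy version of the pattern (generic names, not the fairseq2 API):

from typing import Callable, Dict

_registry: Dict[str, Callable[[], dict]] = {}


def arch(name: str) -> Callable[[Callable[[], dict]], Callable[[], dict]]:
    def decorator(fn: Callable[[], dict]) -> Callable[[], dict]:
        _registry[name] = fn  # registration happens the moment the decorator runs
        return fn

    return decorator


def register_archs() -> None:
    # Nothing is added to _registry until this function is called.
    @arch("7b")
    def _7b() -> dict:
        return {"model_dim": 4096}


assert "7b" not in _registry  # defining register_archs registers nothing
register_archs()
assert "7b" in _registry  # the explicit call performs the registration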
7 changes: 6 additions & 1 deletion src/fairseq2/models/nllb/__init__.py
@@ -11,4 +11,9 @@

# isort: split

-import fairseq2.models.nllb.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.nllb.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
72 changes: 35 additions & 37 deletions src/fairseq2/models/nllb/archs.py
@@ -15,51 +15,49 @@
from fairseq2.nn.transformer import TransformerNormOrder


-@transformer_arch("nllb_dense_300m")
-def _dense_300m() -> TransformerConfig:
-    config = _dense_1b()
-
-    config.num_encoder_layers = 6
-    config.num_decoder_layers = 6
-    config.ffn_inner_dim = 1024 * 4
-    config.dropout_p = 0.3
-
-    return config
-
-
-@transformer_arch("nllb_dense_600m")
-def _dense_600m() -> TransformerConfig:
-    config = _dense_1b()
-
-    config.num_encoder_layers = 12
-    config.num_decoder_layers = 12
-    config.ffn_inner_dim = 1024 * 4
-
-    return config
-
-
-@transformer_arch("nllb_dense_1b")
-def _dense_1b() -> TransformerConfig:
-    config = transformer_archs.get("base")
-
-    config.model_dim = 1024
-    config.vocab_info = VocabularyInfo(
-        size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
-    )
-    config.num_encoder_layers = 24
-    config.num_decoder_layers = 24
-    config.num_encoder_attn_heads = 16
-    config.num_decoder_attn_heads = 16
-    config.ffn_inner_dim = 1024 * 8
-    config.norm_order = TransformerNormOrder.PRE
-
-    return config
-
-
-@transformer_arch("nllb_dense_3b")
-def _dense_3b() -> TransformerConfig:
-    config = _dense_1b()
-
-    config.model_dim = 2048
-
-    return config
+def register_archs() -> None:
+    @transformer_arch("nllb_dense_300m")
+    def _dense_300m() -> TransformerConfig:
+        config = _dense_1b()
+
+        config.num_encoder_layers = 6
+        config.num_decoder_layers = 6
+        config.ffn_inner_dim = 1024 * 4
+        config.dropout_p = 0.3
+
+        return config
+
+    @transformer_arch("nllb_dense_600m")
+    def _dense_600m() -> TransformerConfig:
+        config = _dense_1b()
+
+        config.num_encoder_layers = 12
+        config.num_decoder_layers = 12
+        config.ffn_inner_dim = 1024 * 4
+
+        return config
+
+    @transformer_arch("nllb_dense_1b")
+    def _dense_1b() -> TransformerConfig:
+        config = transformer_archs.get("base")
+
+        config.model_dim = 1024
+        config.vocab_info = VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        )
+        config.num_encoder_layers = 24
+        config.num_decoder_layers = 24
+        config.num_encoder_attn_heads = 16
+        config.num_decoder_attn_heads = 16
+        config.ffn_inner_dim = 1024 * 8
+        config.norm_order = TransformerNormOrder.PRE
+
+        return config
+
+    @transformer_arch("nllb_dense_3b")
+    def _dense_3b() -> TransformerConfig:
+        config = _dense_1b()
+
+        config.model_dim = 2048
+
+        return config
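The NLLB presets register into the shared Transformer architecture registry (the file itself reads the "base" preset via transformer_archs.get), and the 300M/600M/3B presets all derive from _dense_1b. A hedged lookup sketch; the transformer_archs import location is an assumption, not shown in this diff:

from fairseq2.models.nllb import register_archs
from fairseq2.models.transformer import transformer_archs  # assumed location

register_archs()

config = transformer_archs.get("nllb_dense_600m")

# The 600M preset inherits the 1B settings and then shrinks the stack:
assert config.model_dim == 1024
assert config.num_encoder_layers == 12
assert config.ffn_inner_dim == 1024 * 4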
7 changes: 6 additions & 1 deletion src/fairseq2/models/s2t_transformer/__init__.py
@@ -45,4 +45,9 @@

# isort: split

-import fairseq2.models.s2t_transformer.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.s2t_transformer.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
