diff --git a/src/fairseq2/models/llama/__init__.py b/src/fairseq2/models/llama/__init__.py
index eddf9ba1e..1a5c29aa8 100644
--- a/src/fairseq2/models/llama/__init__.py
+++ b/src/fairseq2/models/llama/__init__.py
@@ -23,4 +23,9 @@
 
 # isort: split
 
-import fairseq2.models.llama.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.llama.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/llama/archs.py b/src/fairseq2/models/llama/archs.py
index 103ae58b8..9db189202 100644
--- a/src/fairseq2/models/llama/archs.py
+++ b/src/fairseq2/models/llama/archs.py
@@ -10,127 +10,118 @@
 from fairseq2.models.llama.factory import LLaMAConfig, llama_arch
 
-@llama_arch("7b")
-def _7b() -> LLaMAConfig:
-    return LLaMAConfig()
+def register_archs() -> None:
+    @llama_arch("7b")
+    def _7b() -> LLaMAConfig:
+        return LLaMAConfig()
+    @llama_arch("13b")
+    def _13b() -> LLaMAConfig:
+        config = _7b()
-@llama_arch("13b")
-def _13b() -> LLaMAConfig:
-    config = _7b()
+        config.model_dim = 5120
+        config.num_attn_heads = 40
+        config.num_key_value_heads = 40
+        config.ffn_inner_dim = 5120 * 4
-    config.model_dim = 5120
-    config.num_attn_heads = 40
-    config.num_key_value_heads = 40
-    config.ffn_inner_dim = 5120 * 4
+        return config
-    return config
+    @llama_arch("33b")
+    def _33b() -> LLaMAConfig:
+        config = _7b()
+        config.model_dim = 6656
+        config.num_layers = 60
+        config.num_attn_heads = 52
+        config.num_key_value_heads = 52
+        config.ffn_inner_dim = 6656 * 4
-@llama_arch("33b")
-def _33b() -> LLaMAConfig:
-    config = _7b()
+        return config
-    config.model_dim = 6656
-    config.num_layers = 60
-    config.num_attn_heads = 52
-    config.num_key_value_heads = 52
-    config.ffn_inner_dim = 6656 * 4
+    @llama_arch("65b")
+    def _65b() -> LLaMAConfig:
+        config = _7b()
-    return config
+        config.model_dim = 8192
+        config.num_layers = 80
+        config.num_attn_heads = 64
+        config.num_key_value_heads = 64
+        config.ffn_inner_dim = 8192 * 4
+        return config
-@llama_arch("65b")
-def _65b() -> LLaMAConfig:
-    config = _7b()
+    @llama_arch("llama2_7b")
+    def _llama2_7b() -> LLaMAConfig:
+        config = _7b()
-    config.model_dim = 8192
-    config.num_layers = 80
-    config.num_attn_heads = 64
-    config.num_key_value_heads = 64
-    config.ffn_inner_dim = 8192 * 4
+        config.max_seq_len = 4096
-    return config
+        return config
+    @llama_arch("llama2_13b")
+    def _llama2_13b() -> LLaMAConfig:
+        config = _13b()
-@llama_arch("llama2_7b")
-def _llama2_7b() -> LLaMAConfig:
-    config = _7b()
+        config.max_seq_len = 4096
-    config.max_seq_len = 4096
+        return config
-    return config
+    @llama_arch("llama2_70b")
+    def _llama2_70b() -> LLaMAConfig:
+        config = _65b()
+        config.max_seq_len = 4096
+        config.num_key_value_heads = 8
+        config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in LLaMA 2
+        config.ffn_inner_dim_to_multiple = 4096
-@llama_arch("llama2_13b")
-def _llama2_13b() -> LLaMAConfig:
-    config = _13b()
+        return config
-    config.max_seq_len = 4096
+    @llama_arch("llama3_8b")
+    def _llama3_8b() -> LLaMAConfig:
+        config = _llama2_7b()
-    return config
+        config.max_seq_len = 8192
+        config.vocab_info = VocabularyInfo(
+            size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
+        )
-@llama_arch("llama2_70b")
-def _llama2_70b() -> LLaMAConfig:
-    config = _65b()
+        config.num_key_value_heads = 8
+        config.ffn_inner_dim = int(4096 * 4 * 1.3)
+        config.ffn_inner_dim_to_multiple = 1024
+        config.rope_theta = 500_000.0
-    config.max_seq_len = 4096
-    config.num_key_value_heads = 8
-    config.ffn_inner_dim = int(8192 * 4 * 1.3)  # See A.2.1 in LLaMA 2
-    config.ffn_inner_dim_to_multiple = 4096
+        return config
-    return config
+    @llama_arch("llama3_70b")
+    def _llama3_70b() -> LLaMAConfig:
+        config = _llama2_70b()
+        config.max_seq_len = 8192
-@llama_arch("llama3_8b")
-def _llama3_8b() -> LLaMAConfig:
-    config = _llama2_7b()
+        config.vocab_info = VocabularyInfo(
+            size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
+        )
-    config.max_seq_len = 8192
+        config.rope_theta = 500_000.0
-    config.vocab_info = VocabularyInfo(
-        size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
-    )
+        return config
-    config.num_key_value_heads = 8
-    config.ffn_inner_dim = int(4096 * 4 * 1.3)
-    config.ffn_inner_dim_to_multiple = 1024
-    config.rope_theta = 500_000.0
+    @llama_arch("llama3_1_8b")
+    def _llama3_1_8b() -> LLaMAConfig:
+        config = _llama3_8b()
-    return config
+        config.max_seq_len = 131_072
+        config.use_scaled_rope = True
+        return config
-@llama_arch("llama3_70b")
-def _llama3_70b() -> LLaMAConfig:
-    config = _llama2_70b()
+    @llama_arch("llama3_1_70b")
+    def _llama3_1_70b() -> LLaMAConfig:
+        config = _llama3_70b()
-    config.max_seq_len = 8192
+        config.max_seq_len = 131_072
+        config.use_scaled_rope = True
-    config.vocab_info = VocabularyInfo(
-        size=128_256, unk_idx=None, bos_idx=128_000, eos_idx=128_001, pad_idx=None
-    )
-
-    config.rope_theta = 500_000.0
-
-    return config
-
-
-@llama_arch("llama3_1_8b")
-def _llama3_1_8b() -> LLaMAConfig:
-    config = _llama3_8b()
-
-    config.max_seq_len = 131_072
-    config.use_scaled_rope = True
-
-    return config
-
-
-@llama_arch("llama3_1_70b")
-def _llama3_1_70b() -> LLaMAConfig:
-    config = _llama3_70b()
-
-    config.max_seq_len = 131_072
-    config.use_scaled_rope = True
-
-    return config
+        return config
diff --git a/src/fairseq2/models/mistral/__init__.py b/src/fairseq2/models/mistral/__init__.py
index a1a634983..78da98531 100644
--- a/src/fairseq2/models/mistral/__init__.py
+++ b/src/fairseq2/models/mistral/__init__.py
@@ -21,4 +21,9 @@
 
 # isort: split
 
-import fairseq2.models.mistral.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.mistral.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/mistral/archs.py b/src/fairseq2/models/mistral/archs.py
index 88ef2f799..d039c3cdd 100644
--- a/src/fairseq2/models/mistral/archs.py
+++ b/src/fairseq2/models/mistral/archs.py
@@ -9,6 +9,7 @@
 from fairseq2.models.mistral.factory import MistralConfig, mistral_arch
 
-@mistral_arch("7b")
-def _7b() -> MistralConfig:
-    return MistralConfig()
+def register_archs() -> None:
+    @mistral_arch("7b")
+    def _7b() -> MistralConfig:
+        return MistralConfig()
diff --git a/src/fairseq2/models/nllb/__init__.py b/src/fairseq2/models/nllb/__init__.py
index b5a0ac7ed..9d95a0819 100644
--- a/src/fairseq2/models/nllb/__init__.py
+++ b/src/fairseq2/models/nllb/__init__.py
@@ -11,4 +11,9 @@
 
 # isort: split
 
-import fairseq2.models.nllb.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.nllb.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/nllb/archs.py b/src/fairseq2/models/nllb/archs.py
index 11750082d..983852d5d 100644
--- a/src/fairseq2/models/nllb/archs.py
+++ b/src/fairseq2/models/nllb/archs.py
@@ -15,51 +15,49 @@
 from fairseq2.nn.transformer import TransformerNormOrder
 
-@transformer_arch("nllb_dense_300m")
-def _dense_300m() -> TransformerConfig:
-    config = _dense_1b()
+def register_archs() -> None:
+    @transformer_arch("nllb_dense_300m")
+    def _dense_300m() -> TransformerConfig:
+        config = _dense_1b()
-    config.num_encoder_layers = 6
-    config.num_decoder_layers = 6
-    config.ffn_inner_dim = 1024 * 4
-    config.dropout_p = 0.3
+        config.num_encoder_layers = 6
+        config.num_decoder_layers = 6
+        config.ffn_inner_dim = 1024 * 4
+        config.dropout_p = 0.3
-    return config
+        return config
+    @transformer_arch("nllb_dense_600m")
+    def _dense_600m() -> TransformerConfig:
+        config = _dense_1b()
-@transformer_arch("nllb_dense_600m")
-def _dense_600m() -> TransformerConfig:
-    config = _dense_1b()
+        config.num_encoder_layers = 12
+        config.num_decoder_layers = 12
+        config.ffn_inner_dim = 1024 * 4
-    config.num_encoder_layers = 12
-    config.num_decoder_layers = 12
-    config.ffn_inner_dim = 1024 * 4
+        return config
-    return config
+    @transformer_arch("nllb_dense_1b")
+    def _dense_1b() -> TransformerConfig:
+        config = transformer_archs.get("base")
+        config.model_dim = 1024
+        config.vocab_info = VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        )
+        config.num_encoder_layers = 24
+        config.num_decoder_layers = 24
+        config.num_encoder_attn_heads = 16
+        config.num_decoder_attn_heads = 16
+        config.ffn_inner_dim = 1024 * 8
+        config.norm_order = TransformerNormOrder.PRE
-@transformer_arch("nllb_dense_1b")
-def _dense_1b() -> TransformerConfig:
-    config = transformer_archs.get("base")
+        return config
-    config.model_dim = 1024
-    config.vocab_info = VocabularyInfo(
-        size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
-    )
-    config.num_encoder_layers = 24
-    config.num_decoder_layers = 24
-    config.num_encoder_attn_heads = 16
-    config.num_decoder_attn_heads = 16
-    config.ffn_inner_dim = 1024 * 8
-    config.norm_order = TransformerNormOrder.PRE
+    @transformer_arch("nllb_dense_3b")
+    def _dense_3b() -> TransformerConfig:
+        config = _dense_1b()
-    return config
+        config.model_dim = 2048
-
-@transformer_arch("nllb_dense_3b")
-def _dense_3b() -> TransformerConfig:
-    config = _dense_1b()
-
-    config.model_dim = 2048
-
-    return config
+        return config
diff --git a/src/fairseq2/models/s2t_transformer/__init__.py b/src/fairseq2/models/s2t_transformer/__init__.py
index 38ed516da..31946daf0 100644
--- a/src/fairseq2/models/s2t_transformer/__init__.py
+++ b/src/fairseq2/models/s2t_transformer/__init__.py
@@ -45,4 +45,9 @@
 
 # isort: split
 
-import fairseq2.models.s2t_transformer.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.s2t_transformer.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/s2t_transformer/archs.py b/src/fairseq2/models/s2t_transformer/archs.py
index 195f85ffb..8e2599bb9 100644
--- a/src/fairseq2/models/s2t_transformer/archs.py
+++ b/src/fairseq2/models/s2t_transformer/archs.py
@@ -13,69 +13,66 @@
 )
 
-@s2t_transformer_arch("tiny")
-def _tiny() -> S2TTransformerConfig:
-    config = _medium()
-
-    config.model_dim = 256
-    config.num_encoder_layers = 6
-    config.num_decoder_layers = 3
-    config.num_encoder_attn_heads = 4
-    config.num_decoder_attn_heads = 4
-    config.ffn_inner_dim = 256 * 4
-    config.dropout_p = 0.3
-
-    return config
-
-
-@s2t_transformer_arch("small")
-def _small() -> S2TTransformerConfig:
-    config = _medium()
-
-    config.model_dim = 256
-    config.num_encoder_attn_heads = 4
-    config.num_decoder_attn_heads = 4
-    config.ffn_inner_dim = 256 * 8
-    config.dropout_p = 0.1
-
-    return config
-
-
-@s2t_transformer_arch("medium")
-def _medium() -> S2TTransformerConfig:
-    return S2TTransformerConfig()
-
-
-@s2t_transformer_arch("large")
-def _large() -> S2TTransformerConfig:
-    config = _medium()
-
-    config.model_dim = 1024
-    config.num_encoder_attn_heads = 16
-    config.num_decoder_attn_heads = 16
-    config.ffn_inner_dim = 1024 * 4
-    config.dropout_p = 0.2
-
-    return config
-
-
-@s2t_transformer_arch("conformer_medium")
-def _conformer_medium() -> S2TTransformerConfig:
-    return S2TTransformerConfig(
-        model_dim=256,
-        max_source_seq_len=6000,
-        num_fbank_channels=80,
-        max_target_seq_len=1024,
-        target_vocab_info=VocabularyInfo(
-            size=181, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
-        ),
-        use_relative_pos=False,
-        use_conformer=True,
-        num_encoder_layers=12,
-        num_decoder_layers=6,
-        num_encoder_attn_heads=4,
-        num_decoder_attn_heads=8,
-        ffn_inner_dim=512 * 4,
-        dropout_p=0.1,
-        depthwise_conv_kernel_size=31,
-    )
+def register_archs() -> None:
+    @s2t_transformer_arch("tiny")
+    def _tiny() -> S2TTransformerConfig:
+        config = _medium()
+
+        config.model_dim = 256
+        config.num_encoder_layers = 6
+        config.num_decoder_layers = 3
+        config.num_encoder_attn_heads = 4
+        config.num_decoder_attn_heads = 4
+        config.ffn_inner_dim = 256 * 4
+        config.dropout_p = 0.3
+
+        return config
+
+    @s2t_transformer_arch("small")
+    def _small() -> S2TTransformerConfig:
+        config = _medium()
+
+        config.model_dim = 256
+        config.num_encoder_attn_heads = 4
+        config.num_decoder_attn_heads = 4
+        config.ffn_inner_dim = 256 * 8
+        config.dropout_p = 0.1
+
+        return config
+
+    @s2t_transformer_arch("medium")
+    def _medium() -> S2TTransformerConfig:
+        return S2TTransformerConfig()
+
+    @s2t_transformer_arch("large")
+    def _large() -> S2TTransformerConfig:
+        config = _medium()
+
+        config.model_dim = 1024
+        config.num_encoder_attn_heads = 16
+        config.num_decoder_attn_heads = 16
+        config.ffn_inner_dim = 1024 * 4
+        config.dropout_p = 0.2
+
+        return config
+
+    @s2t_transformer_arch("conformer_medium")
+    def _conformer_medium() -> S2TTransformerConfig:
+        return S2TTransformerConfig(
+            model_dim=256,
+            max_source_seq_len=6000,
+            num_fbank_channels=80,
+            max_target_seq_len=1024,
+            target_vocab_info=VocabularyInfo(
+                size=181, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+            ),
+            use_relative_pos=False,
+            use_conformer=True,
+            num_encoder_layers=12,
+            num_decoder_layers=6,
+            num_encoder_attn_heads=4,
+            num_decoder_attn_heads=8,
+            ffn_inner_dim=512 * 4,
+            dropout_p=0.1,
+            depthwise_conv_kernel_size=31,
+        )
diff --git a/src/fairseq2/models/transformer/__init__.py b/src/fairseq2/models/transformer/__init__.py
index 287fe9404..8df94623a 100644
--- a/src/fairseq2/models/transformer/__init__.py
+++ b/src/fairseq2/models/transformer/__init__.py
@@ -39,4 +39,9 @@
 
 # isort: split
 
-import fairseq2.models.transformer.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.transformer.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/transformer/archs.py b/src/fairseq2/models/transformer/archs.py
index a7fff374a..a96d2b8e2 100644
--- a/src/fairseq2/models/transformer/archs.py
+++ b/src/fairseq2/models/transformer/archs.py
@@ -9,19 +9,19 @@
 from fairseq2.models.transformer.factory import TransformerConfig, transformer_arch
 
-@transformer_arch("base")
-def _base() -> TransformerConfig:
-    return TransformerConfig()
-
-
-@transformer_arch("big")
-def _big() -> TransformerConfig:
-    config = TransformerConfig()
-
-    config.model_dim = 1024
-    config.num_encoder_attn_heads = 16
-    config.num_decoder_attn_heads = 16
-    config.ffn_inner_dim = 4096
-    config.dropout_p = 0.3
-
-    return config
+def register_archs() -> None:
+    @transformer_arch("base")
+    def _base() -> TransformerConfig:
+        return TransformerConfig()
+
+    @transformer_arch("big")
+    def _big() -> TransformerConfig:
+        config = TransformerConfig()
+
+        config.model_dim = 1024
+        config.num_encoder_attn_heads = 16
+        config.num_decoder_attn_heads = 16
+        config.ffn_inner_dim = 4096
+        config.dropout_p = 0.3
+
+        return config
diff --git a/src/fairseq2/models/w2vbert/__init__.py b/src/fairseq2/models/w2vbert/__init__.py
index 2ca877d6a..f343d0f8e 100644
--- a/src/fairseq2/models/w2vbert/__init__.py
+++ b/src/fairseq2/models/w2vbert/__init__.py
@@ -20,4 +20,9 @@
 
 # isort: split
 
-import fairseq2.models.w2vbert.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.w2vbert.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/w2vbert/archs.py b/src/fairseq2/models/w2vbert/archs.py
index 2bc63bec9..c77110537 100644
--- a/src/fairseq2/models/w2vbert/archs.py
+++ b/src/fairseq2/models/w2vbert/archs.py
@@ -10,31 +10,29 @@
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig, wav2vec2_encoder_arch
 
-@w2vbert_arch("600m")
-def _600m() -> W2VBertConfig:
-    return W2VBertConfig()
+def register_archs() -> None:
+    @w2vbert_arch("600m")
+    def _600m() -> W2VBertConfig:
+        return W2VBertConfig()
+    @w2vbert_arch("300m")
+    def _300m() -> W2VBertConfig:
+        config = _600m()
-@w2vbert_arch("300m")
-def _300m() -> W2VBertConfig:
-    config = _600m()
+        config.w2v2_config.encoder_config.num_encoder_layers = 12
-    config.w2v2_config.encoder_config.num_encoder_layers = 12
+        config.num_bert_encoder_layers = 8
-    config.num_bert_encoder_layers = 8
+        return config
-    return config
+    @wav2vec2_encoder_arch("bert_600m")
+    def _600m_encoder() -> Wav2Vec2EncoderConfig:
+        config = _600m()
+        return config.w2v2_config.encoder_config
-@wav2vec2_encoder_arch("bert_600m")
-def _600m_encoder() -> Wav2Vec2EncoderConfig:
-    config = _600m()
+    @wav2vec2_encoder_arch("bert_300m")
+    def _300m_encoder() -> Wav2Vec2EncoderConfig:
+        config = _300m()
-    return config.w2v2_config.encoder_config
-
-
-@wav2vec2_encoder_arch("bert_300m")
-def _300m_encoder() -> Wav2Vec2EncoderConfig:
-    config = _300m()
-
-    return config.w2v2_config.encoder_config
+        return config.w2v2_config.encoder_config
diff --git a/src/fairseq2/models/wav2vec2/__init__.py b/src/fairseq2/models/wav2vec2/__init__.py
index f66277142..75aefa9a2 100644
--- a/src/fairseq2/models/wav2vec2/__init__.py
+++ b/src/fairseq2/models/wav2vec2/__init__.py
@@ -49,4 +49,9 @@
 
 # isort: split
 
-import fairseq2.models.wav2vec2.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.wav2vec2.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/wav2vec2/archs.py b/src/fairseq2/models/wav2vec2/archs.py
index 10b7c8de6..2cc686312 100644
--- a/src/fairseq2/models/wav2vec2/archs.py
+++ b/src/fairseq2/models/wav2vec2/archs.py
@@ -15,107 +15,102 @@
 from fairseq2.nn.transformer import TransformerNormOrder
 
-@wav2vec2_arch("base")
-def _base() -> Wav2Vec2Config:
-    return Wav2Vec2Config()
-
-
-@wav2vec2_arch("large")
-def _large() -> Wav2Vec2Config:
-    config = _base()
-
-    config.encoder_config.model_dim = 1024
-    config.encoder_config.num_encoder_layers = 24
-    config.encoder_config.num_encoder_attn_heads = 16
-    config.encoder_config.ffn_inner_dim = 4096
-    config.encoder_config.dropout_p = 0.0
-    config.encoder_config.layer_drop_p = 0.2
-    config.quantized_dim = 768
-    config.final_dim = 768
-
-    return config
-
-
-@wav2vec2_arch("large_lv60k")  # LibriVox 60k
-def _large_lv60k() -> Wav2Vec2Config:
-    config = _large()
-
-    config.encoder_config.layer_norm_features = False
-    config.encoder_config.feature_extractor_bias = True
-    config.encoder_config.feature_extractor_layer_norm_convs = True
-    config.encoder_config.layer_drop_p = 0.0
-    config.encoder_config.norm_order = TransformerNormOrder.PRE
-    config.codebook_sampling_temperature = (2.0, 0.1, 0.999995)
-
-    return config
-
-
-@wav2vec2_arch("pseudo_dinosr_base")
-def _pseudo_dinosr_base() -> Wav2Vec2Config:
-    layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 3
-
-    encoder_config = Wav2Vec2EncoderConfig(
-        model_dim=768,
-        max_seq_len=100000,
-        feature_dim=512,
-        use_fbank=False,
-        first_pass_dropout_p=0.0,
-        layer_norm_features=True,
-        feature_extractor_layer_descs=layer_descs,
-        feature_extractor_bias=False,
-        feature_extractor_layer_norm_convs=True,
-        feature_gradient_scale=0.1,
-        num_fbank_channels=0,
-        fbank_stride=0,
-        sample_fbank_every_k=0,
-        pos_encoder_type="conv",
-        pos_encoder_depth=5,
-        pos_conv_kernel_size=95,
-        num_pos_conv_groups=16,
-        use_conformer=False,
-        num_encoder_layers=12,
-        num_encoder_attn_heads=12,
-        ffn_inner_dim=3072,
-        dropout_p=0.1,
-        attn_dropout_p=0.1,
-        layer_drop_p=0.0,
-        norm_order=TransformerNormOrder.POST,
-        depthwise_conv_kernel_size=31,
-    )
-
-    return Wav2Vec2Config(
-        encoder_config=encoder_config,
-        final_dim=256,
-        final_proj_bias=True,
-        temporal_mask_span_len=10,
-        max_temporal_mask_prob=0.65,
-        spatial_mask_span_len=10,
-        max_spatial_mask_prob=0.0,
-        quantized_dim=256,
-        num_codebooks=2,
-        num_codebook_entries=320,
-        codebook_sampling_temperature=(2.0, 0.5, 0.999995),
-        num_distractors=100,
-        logit_temp=0.1,
-    )
-
-
-@wav2vec2_encoder_arch("base")
-def _base_encoder() -> Wav2Vec2EncoderConfig:
-    config = _base()
-
-    return config.encoder_config
-
-
-@wav2vec2_encoder_arch("large")
-def _large_encoder() -> Wav2Vec2EncoderConfig:
-    config = _large()
-
-    return config.encoder_config
-
-
-@wav2vec2_encoder_arch("large_lv60k")  # LibriVox 60k
-def _large_lv60k_encoder() -> Wav2Vec2EncoderConfig:
-    config = _large_lv60k()
-
-    return config.encoder_config
+def register_archs() -> None:
+    @wav2vec2_arch("base")
+    def _base() -> Wav2Vec2Config:
+        return Wav2Vec2Config()
+
+    @wav2vec2_arch("large")
+    def _large() -> Wav2Vec2Config:
+        config = _base()
+
+        config.encoder_config.model_dim = 1024
+        config.encoder_config.num_encoder_layers = 24
+        config.encoder_config.num_encoder_attn_heads = 16
+        config.encoder_config.ffn_inner_dim = 4096
+        config.encoder_config.dropout_p = 0.0
+        config.encoder_config.layer_drop_p = 0.2
+        config.quantized_dim = 768
+        config.final_dim = 768
+
+        return config
+
+    @wav2vec2_arch("large_lv60k")  # LibriVox 60k
+    def _large_lv60k() -> Wav2Vec2Config:
+        config = _large()
+
+        config.encoder_config.layer_norm_features = False
+        config.encoder_config.feature_extractor_bias = True
+        config.encoder_config.feature_extractor_layer_norm_convs = True
+        config.encoder_config.layer_drop_p = 0.0
+        config.encoder_config.norm_order = TransformerNormOrder.PRE
+        config.codebook_sampling_temperature = (2.0, 0.1, 0.999995)
+
+        return config
+
+    @wav2vec2_arch("pseudo_dinosr_base")
+    def _pseudo_dinosr_base() -> Wav2Vec2Config:
+        layer_descs = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 3
+
+        encoder_config = Wav2Vec2EncoderConfig(
+            model_dim=768,
+            max_seq_len=100000,
+            feature_dim=512,
+            use_fbank=False,
+            first_pass_dropout_p=0.0,
+            layer_norm_features=True,
+            feature_extractor_layer_descs=layer_descs,
+            feature_extractor_bias=False,
+            feature_extractor_layer_norm_convs=True,
+            feature_gradient_scale=0.1,
+            num_fbank_channels=0,
+            fbank_stride=0,
+            sample_fbank_every_k=0,
+            pos_encoder_type="conv",
+            pos_encoder_depth=5,
+            pos_conv_kernel_size=95,
+            num_pos_conv_groups=16,
+            use_conformer=False,
+            num_encoder_layers=12,
+            num_encoder_attn_heads=12,
+            ffn_inner_dim=3072,
+            dropout_p=0.1,
+            attn_dropout_p=0.1,
+            layer_drop_p=0.0,
+            norm_order=TransformerNormOrder.POST,
+            depthwise_conv_kernel_size=31,
+        )
+
+        return Wav2Vec2Config(
+            encoder_config=encoder_config,
+            final_dim=256,
+            final_proj_bias=True,
+            temporal_mask_span_len=10,
+            max_temporal_mask_prob=0.65,
+            spatial_mask_span_len=10,
+            max_spatial_mask_prob=0.0,
+            quantized_dim=256,
+            num_codebooks=2,
+            num_codebook_entries=320,
+            codebook_sampling_temperature=(2.0, 0.5, 0.999995),
+            num_distractors=100,
+            logit_temp=0.1,
+        )
+
+    @wav2vec2_encoder_arch("base")
+    def _base_encoder() -> Wav2Vec2EncoderConfig:
+        config = _base()
+
+        return config.encoder_config
+
+    @wav2vec2_encoder_arch("large")
+    def _large_encoder() -> Wav2Vec2EncoderConfig:
+        config = _large()
+
+        return config.encoder_config
+
+    @wav2vec2_encoder_arch("large_lv60k")  # LibriVox 60k
+    def _large_lv60k_encoder() -> Wav2Vec2EncoderConfig:
+        config = _large_lv60k()
+
+        return config.encoder_config
diff --git a/src/fairseq2/models/wav2vec2/asr/__init__.py b/src/fairseq2/models/wav2vec2/asr/__init__.py
index 7d89fc960..03b1b7aad 100644
--- a/src/fairseq2/models/wav2vec2/asr/__init__.py
+++ b/src/fairseq2/models/wav2vec2/asr/__init__.py
@@ -31,4 +31,9 @@
 
 # isort: split
 
-import fairseq2.models.wav2vec2.asr.archs  # Register architectures.
+from fairseq2.dependency import DependencyContainer
+from fairseq2.models.wav2vec2.asr.archs import register_archs
+
+
+def register_objects(container: DependencyContainer) -> None:
+    register_archs()
diff --git a/src/fairseq2/models/wav2vec2/asr/archs.py b/src/fairseq2/models/wav2vec2/asr/archs.py
index ed8e81800..8cb65ccdc 100644
--- a/src/fairseq2/models/wav2vec2/asr/archs.py
+++ b/src/fairseq2/models/wav2vec2/asr/archs.py
@@ -10,69 +10,65 @@
 from fairseq2.models.wav2vec2.factory import wav2vec2_encoder_archs
 
-@wav2vec2_asr_arch("base_10h")
-def _base_10h() -> Wav2Vec2AsrConfig:
-    return Wav2Vec2AsrConfig()
+def register_archs() -> None:
+    @wav2vec2_asr_arch("base_10h")
+    def _base_10h() -> Wav2Vec2AsrConfig:
+        return Wav2Vec2AsrConfig()
+    @wav2vec2_asr_arch("base_100h")
+    def _base_100h() -> Wav2Vec2AsrConfig:
+        config = _base_10h()
-@wav2vec2_asr_arch("base_100h")
-def _base_100h() -> Wav2Vec2AsrConfig:
-    config = _base_10h()
+        config.encoder_config.layer_drop_p = 0.1
-    config.encoder_config.layer_drop_p = 0.1
+        return config
-    return config
+    @wav2vec2_asr_arch("large_10h")
+    def _large_10h() -> Wav2Vec2AsrConfig:
+        config = _base_10h()
+        config.encoder_config = wav2vec2_encoder_archs.get("large")
+        config.encoder_config.feature_gradient_scale = 1.0
+        config.encoder_config.dropout_p = 0.0
+        config.encoder_config.attn_dropout_p = 0.0
+        config.encoder_config.ffn_inner_dropout_p = 0.1
+        config.encoder_config.layer_drop_p = 0.1
-@wav2vec2_asr_arch("large_10h")
-def _large_10h() -> Wav2Vec2AsrConfig:
-    config = _base_10h()
+        config.max_temporal_mask_prob = 0.80
+        config.max_spatial_mask_prob = 0.30
-    config.encoder_config = wav2vec2_encoder_archs.get("large")
-    config.encoder_config.feature_gradient_scale = 1.0
-    config.encoder_config.dropout_p = 0.0
-    config.encoder_config.attn_dropout_p = 0.0
-    config.encoder_config.ffn_inner_dropout_p = 0.1
-    config.encoder_config.layer_drop_p = 0.1
+        return config
-    config.max_temporal_mask_prob = 0.80
-    config.max_spatial_mask_prob = 0.30
+    @wav2vec2_asr_arch("large_100h")
+    def _large_100h() -> Wav2Vec2AsrConfig:
+        config = _large_10h()
-    return config
+        config.max_temporal_mask_prob = 0.53
+        config.max_spatial_mask_prob = 0.55
+        return config
-@wav2vec2_asr_arch("large_100h")
-def _large_100h() -> Wav2Vec2AsrConfig:
-    config = _large_10h()
+    @wav2vec2_asr_arch("large_lv60k_10h")
+    def _large_lv60k_10h() -> Wav2Vec2AsrConfig:
+        config = _base_10h()
-    config.max_temporal_mask_prob = 0.53
-    config.max_spatial_mask_prob = 0.55
+        config.encoder_config = wav2vec2_encoder_archs.get("large_lv60k")
+        config.encoder_config.feature_gradient_scale = 1.0
+        config.encoder_config.dropout_p = 0.0
+        config.encoder_config.attn_dropout_p = 0.0
+        config.encoder_config.ffn_inner_dropout_p = 0.1
+        config.encoder_config.layer_drop_p = 0.1
-    return config
+        config.max_temporal_mask_prob = 0.80
+        config.max_spatial_mask_prob = 0.30
+        return config
-@wav2vec2_asr_arch("large_lv60k_10h")
-def _large_lv60k_10h() -> Wav2Vec2AsrConfig:
-    config = _base_10h()
+    @wav2vec2_asr_arch("large_lv60k_100h")
+    def _large_lv60k_100h() -> Wav2Vec2AsrConfig:
+        config = _large_lv60k_10h()
-    config.encoder_config = wav2vec2_encoder_archs.get("large_lv60k")
-    config.encoder_config.feature_gradient_scale = 1.0
-    config.encoder_config.dropout_p = 0.0
-    config.encoder_config.attn_dropout_p = 0.0
-    config.encoder_config.ffn_inner_dropout_p = 0.1
-    config.encoder_config.layer_drop_p = 0.1
+        config.max_temporal_mask_prob = 0.53
+        config.max_spatial_mask_prob = 0.55
-    config.max_temporal_mask_prob = 0.80
-    config.max_spatial_mask_prob = 0.30
-
-    return config
-
-
-@wav2vec2_asr_arch("large_lv60k_100h")
-def _large_lv60k_100h() -> Wav2Vec2AsrConfig:
-    config = _large_lv60k_10h()
-
-    config.max_temporal_mask_prob = 0.53
-    config.max_spatial_mask_prob = 0.55
-
-    return config
+        return config
diff --git a/src/fairseq2/setup.py b/src/fairseq2/setup.py
index 304b2365f..7d36d8171 100644
--- a/src/fairseq2/setup.py
+++ b/src/fairseq2/setup.py
@@ -65,6 +65,14 @@ def _setup_library(container: DependencyContainer) -> None:
         "fairseq2.assets.metadata_provider",
         "fairseq2.assets.store",
         "fairseq2.device",
+        "fairseq2.models.llama",
+        "fairseq2.models.mistral",
+        "fairseq2.models.nllb",
+        "fairseq2.models.s2t_transformer",
+        "fairseq2.models.transformer",
+        "fairseq2.models.w2vbert",
+        "fairseq2.models.wav2vec2",
+        "fairseq2.models.wav2vec2.asr",
         "fairseq2.recipes.utils.environment",
         "fairseq2.utils.structured",
     ]
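The setup.py hunk above only adds the model packages to the module list passed to _setup_library, so the way the new per-package register_objects hooks get invoked is not shown in this diff. The snippet below is a minimal sketch of that wiring under the assumption that _setup_library imports each listed module and calls its register_objects(container) entry point; the helper name _register_modules and the two-module list are illustrative only, not taken from the diff.

    from importlib import import_module

    from fairseq2.dependency import DependencyContainer


    def _register_modules(container: DependencyContainer) -> None:
        # Hypothetical helper: import each package and invoke its
        # `register_objects` hook so architectures are registered at
        # setup time instead of as an import side effect.
        for name in ["fairseq2.models.llama", "fairseq2.models.mistral"]:
            module = import_module(name)

            module.register_objects(container)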