diff --git a/src/fairseq2/models/config_loader.py b/src/fairseq2/models/config_loader.py
index 2e4e6ff85..bcefbbbd2 100644
--- a/src/fairseq2/models/config_loader.py
+++ b/src/fairseq2/models/config_loader.py
@@ -18,6 +18,7 @@
 )
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.typing import DataClass
+from fairseq2.utils.dataclass import fill_empty_fields
 from fairseq2.utils.structured import (
     StructuredError,
     ValueConverter,
@@ -33,7 +34,9 @@ class ModelConfigLoader(Protocol[ModelConfigT_co]):
     """Loads model configurations of type ``ModelConfigT``."""
 
-    def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT_co:
+    def __call__(
+        self, model_name_or_card: str | AssetCard, unstructured_config: object = None
+    ) -> ModelConfigT_co:
         """
         :param model_name_or_card:
             The name or the asset card of the model whose configuration to load.
         """
@@ -78,7 +81,9 @@ def __init__(
         self._arch_configs = arch_configs
         self._value_converter = value_converter
 
-    def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT:
+    def __call__(
+        self, model_name_or_card: str | AssetCard, unstructured_config: object = None
+    ) -> ModelConfigT:
         if isinstance(model_name_or_card, AssetCard):
             card = model_name_or_card
         else:
@@ -103,7 +108,7 @@ def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT:
         # Load the configuration.
         if arch is None:
             try:
-                config = config_kls()
+                base_config = config_kls()
             except TypeError as ex:
                 raise AssetError(
                     f"The '{self._family}' model family has no default configuration."
@@ -115,50 +120,61 @@
                 )
 
             try:
-                config = self._arch_configs.get(arch)
+                base_config = self._arch_configs.get(arch)
             except ValueError:
                 raise AssetError(
                     f"The '{self._family}' model family has no architecture named '{arch}'."
                 ) from None
 
         # Override the default architecture configuration if needed.
-        config_overrides_list = []
+        if self._value_converter is None:
+            self._value_converter = get_value_converter()
+
+        model_config_fields = []
 
         card_: AssetCard | None = card
 
         while card_ is not None:
             if "model_config" in card_.metadata:
-                config_overrides = card_.field("model_config").as_unstructured()
+                model_config_field = card_.field("model_config").as_unstructured()
 
-                config_overrides_list.append(config_overrides)
+                model_config_fields.append(model_config_field)
 
             card_ = card_.base
 
-        if config_overrides_list:
-            if self._value_converter is None:
-                self._value_converter = get_value_converter()
-
+        if model_config_fields:
             try:
-                unstructured_config = self._value_converter.unstructure(config)
+                unstructured_base_config = self._value_converter.unstructure(
+                    base_config
+                )
             except StructuredError as ex:
                 raise AssetError(
                     f"The model configuration class of the '{self._family}' cannot be used. Please file a bug report to the model author."
                 ) from ex
 
             try:
-                for config_overrides in reversed(config_overrides_list):
-                    unstructured_config = merge_unstructured(
-                        unstructured_config, config_overrides
+                for model_config_field in reversed(model_config_fields):
+                    unstructured_base_config = merge_unstructured(
+                        unstructured_base_config, model_config_field
                     )
 
-                config = self._value_converter.structure(
-                    unstructured_config, type_expr=config_kls
+                base_config = self._value_converter.structure(
+                    unstructured_base_config, type_expr=config_kls
                 )
             except StructuredError as ex:
                 raise AssetError(
                     f"The value of the field 'model_config' of the asset card '{card.name}' cannot be parsed as a valid model configuration. Please file a bug report to the asset author."
                 ) from ex
 
+        if unstructured_config is None:
+            config = base_config
+        else:
+            config = self._value_converter.structure(
+                unstructured_config, config_kls, allow_empty=True
+            )
+
+            fill_empty_fields(config, base_config)
+
         return config
 
 
diff --git a/src/fairseq2/models/llama/factory.py b/src/fairseq2/models/llama/factory.py
index 4020f8267..5b81ee538 100644
--- a/src/fairseq2/models/llama/factory.py
+++ b/src/fairseq2/models/llama/factory.py
@@ -88,7 +88,7 @@ class LLaMAConfig:
     use_scaled_rope: bool = False
     """If ``True``, scales Rotary encoding frequencies to LLaMA 3.1 context length."""
 
-    dropout_p: float = 0.0  # TODO: Revert back to 0.1
+    dropout_p: float = 0.1
     """The dropout probability on outputs of Transformer layers."""
 
 
diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py
index 3debb5315..2e56f1d64 100644
--- a/src/fairseq2/models/loader.py
+++ b/src/fairseq2/models/loader.py
@@ -77,6 +77,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -195,6 +196,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -246,7 +248,7 @@ def __call__(
 
         model = None
 
-        config = self._config_loader(card)
+        config = self._config_loader(card, unstructured_config)
 
         if device.type == "meta":
             try:
@@ -392,6 +394,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -417,6 +420,7 @@
         return loader(
             model_name_or_card,
             gangs=gangs,
+            unstructured_config=unstructured_config,
             device=device,
             dtype=dtype,
             force=force,
diff --git a/src/fairseq2/recipes/lm/instruction_finetune.py b/src/fairseq2/recipes/lm/instruction_finetune.py
index f85c93d5a..cdd6afecf 100644
--- a/src/fairseq2/recipes/lm/instruction_finetune.py
+++ b/src/fairseq2/recipes/lm/instruction_finetune.py
@@ -29,6 +29,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
 from fairseq2.models.decoder import DecoderModel
+from fairseq2.models.llama import llama_archs
 from fairseq2.models.sequence import (
     SequenceBatch,
     SequenceModelOutput,
@@ -53,6 +54,7 @@
     to_data_parallel,
 )
 from fairseq2.typing import CPU, META, DataType
+from fairseq2.utils.dataclass import empty_
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import manual_seed
 
@@ -83,9 +85,17 @@ class InstructionFinetuneConfig:
     """The number of batches to prefetch in background."""
 
     # Model
-    model: AssetReference = "llama3_8b_instruct"
+    model: AssetReference = "llama3_1_8b_instruct"
     """The name or path to the asset card of the language model to finetune."""
 
+    model_config: Any = field(
+        default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
+    )
+    """
+    The model configuration overrides. The provided values must be compatible
+    with the checkpoint; otherwise, the model will fail to load.
+    """
+
     dtype: DataType = torch.bfloat16
     """The data type of the model."""
 
@@ -303,7 +313,13 @@ def load_instruction_finetuner(
 
     if has_checkpoint:
         try:
-            model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+            model = load_model(
+                model_card,
+                gangs=gangs,
+                unstructured_config=config.model_config,
+                device=init_device,
+                dtype=dtype,
+            )
         except ValueError as ex:
             raise ValueError(
                 "The model cannot be initialized. See nested exception for details."
@@ -317,7 +333,13 @@ def load_instruction_finetuner(
             init_device = root_gang.device
 
     try:
-        model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+        model = load_model(
+            model_card,
+            gangs=gangs,
+            unstructured_config=config.model_config,
+            device=init_device,
+            dtype=dtype,
+        )
     except ValueError as ex:
         raise ValueError(
             "The model cannot be initialized. See nested exception for details."
diff --git a/src/fairseq2/recipes/lm/preference_finetune/recipe.py b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
index 02be68e20..f842a6765 100644
--- a/src/fairseq2/recipes/lm/preference_finetune/recipe.py
+++ b/src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -29,6 +29,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
 from fairseq2.models.decoder import DecoderModel
+from fairseq2.models.llama import llama_archs
 from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
 from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
 from fairseq2.optim import AdamWConfig, create_optimizer
@@ -42,6 +43,7 @@
 from fairseq2.recipes.utils.log import log_model
 from fairseq2.recipes.utils.setup import compile_model, setup_gangs, to_data_parallel
 from fairseq2.typing import CPU, META, DataType
+from fairseq2.utils.dataclass import empty_
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import manual_seed
 
@@ -79,9 +81,17 @@ class PreferenceOptimizationConfig:
     """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens."""
 
     # Model
-    model: AssetReference = "llama3_8b_instruct"
+    model: AssetReference = "llama3_1_8b_instruct"
     """The name or path to the asset card of the language model to finetune."""
 
+    model_config: Any = field(
+        default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
+    )
+    """
+    The model configuration overrides. The provided values must be compatible
+    with the checkpoint; otherwise, the model will fail to load.
+    """
+
     dtype: DataType = torch.bfloat16
     """The data type of the model."""
 
@@ -289,7 +299,13 @@ def load_preference_finetuner(
 
     if has_checkpoint:
         try:
-            model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+            model = load_model(
+                model_card,
+                gangs=gangs,
+                unstructured_config=config.model_config,
+                device=init_device,
+                dtype=dtype,
+            )
         except ValueError as ex:
             raise ValueError(
                 "The model cannot be initialized. See nested exception for details."
@@ -302,7 +318,13 @@ def load_preference_finetuner(
         if dp_gang.rank == 0:
            init_device = root_gang.device
 
-        model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+        model = load_model(
+            model_card,
+            gangs=gangs,
+            unstructured_config=config.model_config,
+            device=init_device,
+            dtype=dtype,
+        )
 
         root_gang.barrier()
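Note on the merge order implemented by the new `StandardModelConfigLoader.__call__` path above: the loader first resolves a `base_config` from the family or architecture defaults, then merges any `model_config` overrides found on the asset card (and its base cards) via `merge_unstructured`, and only then structures a caller-supplied `unstructured_config`, back-filling its unset fields from that base via `fill_empty_fields`. Below is a minimal, self-contained sketch of that "empty fields are back-filled" idea; `ToyModelConfig`, the `EMPTY` sentinel, and `fill_empty_fields_toy` are hypothetical stand-ins, not fairseq2's actual `empty_`/`fill_empty_fields` implementation.

```python
# Toy illustration of the back-fill semantics used by the new loader path.
# Everything here is a stand-in sketch, not fairseq2's real implementation.
from dataclasses import dataclass, fields


class _Empty:
    """Sentinel marking a field the caller did not set."""

    def __repr__(self) -> str:
        return "EMPTY"


EMPTY = _Empty()


@dataclass
class ToyModelConfig:
    model_dim: int = EMPTY  # type: ignore[assignment]
    num_layers: int = EMPTY  # type: ignore[assignment]
    dropout_p: float = EMPTY  # type: ignore[assignment]


def fill_empty_fields_toy(target: ToyModelConfig, base: ToyModelConfig) -> None:
    """Copy values from ``base`` into every field of ``target`` left as EMPTY."""
    for f in fields(target):
        if getattr(target, f.name) is EMPTY:
            setattr(target, f.name, getattr(base, f.name))


# `base` plays the role of `base_config` in the loader: architecture defaults
# already merged with the asset card's `model_config` overrides.
base = ToyModelConfig(model_dim=4096, num_layers=32, dropout_p=0.1)

# The caller overrides a single field; everything else stays EMPTY.
overrides = ToyModelConfig(dropout_p=0.0)

fill_empty_fields_toy(overrides, base)

assert overrides == ToyModelConfig(model_dim=4096, num_layers=32, dropout_p=0.0)
```

Only the fields the caller actually set survive; everything else falls back to whatever the card and architecture defaults resolved to, which is what keeps the overrides compatible with the checkpoint.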
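On the recipe side, `model_config` now defaults to an all-empty copy of the `llama3_1_8b` architecture configuration, and `load_instruction_finetuner` / `load_preference_finetuner` forward it to `load_model` as `unstructured_config`. A hypothetical stand-alone call under the same assumptions (the `dropout_p` override, the CUDA device, and assigning directly to the emptied config are illustrative choices, not part of the patch):

```python
# Hypothetical usage of the new `unstructured_config` argument. The override
# value and the device below are illustrative assumptions, not part of the patch.
import torch

from fairseq2.models import load_model
from fairseq2.models.llama import llama_archs
from fairseq2.utils.dataclass import empty_

# Start from an "all empty" llama3_1_8b configuration, as the recipe defaults do.
model_config = empty_(llama_archs.get("llama3_1_8b"))

# Only this field deviates from the checkpoint's base configuration; every other
# field is left empty and back-filled by the config loader.
model_config.dropout_p = 0.0

model = load_model(
    "llama3_1_8b_instruct",
    unstructured_config=model_config,
    device=torch.device("cuda:0"),
    dtype=torch.bfloat16,
)
```

This mirrors the `load_model(model_card, gangs=gangs, unstructured_config=config.model_config, ...)` calls added to both recipes.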