Introduce model_config in finetuning recipes #815

Merged 1 commit on Sep 26, 2024
50 changes: 33 additions & 17 deletions src/fairseq2/models/config_loader.py
@@ -18,6 +18,7 @@
 )
 from fairseq2.config_registry import ConfigRegistry
 from fairseq2.typing import DataClass
+from fairseq2.utils.dataclass import fill_empty_fields
 from fairseq2.utils.structured import (
     StructuredError,
     ValueConverter,
@@ -33,7 +34,9 @@
 class ModelConfigLoader(Protocol[ModelConfigT_co]):
     """Loads model configurations of type ``ModelConfigT``."""

-    def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT_co:
+    def __call__(
+        self, model_name_or_card: str | AssetCard, unstructured_config: object = None
+    ) -> ModelConfigT_co:
         """
         :param model_name_or_card:
             The name or the asset card of the model whose configuration to load.
@@ -78,7 +81,9 @@ def __init__(
         self._arch_configs = arch_configs
         self._value_converter = value_converter

-    def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT:
+    def __call__(
+        self, model_name_or_card: str | AssetCard, unstructured_config: object = None
+    ) -> ModelConfigT:
         if isinstance(model_name_or_card, AssetCard):
             card = model_name_or_card
         else:
@@ -103,7 +108,7 @@ def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT:
         # Load the configuration.
         if arch is None:
             try:
-                config = config_kls()
+                base_config = config_kls()
             except TypeError as ex:
                 raise AssetError(
                     f"The '{self._family}' model family has no default configuration."
@@ -115,50 +120,61 @@ def __call__(self, model_name_or_card: str | AssetCard) -> ModelConfigT:
                 )

             try:
-                config = self._arch_configs.get(arch)
+                base_config = self._arch_configs.get(arch)
             except ValueError:
                 raise AssetError(
                     f"The '{self._family}' model family has no architecture named '{arch}'."
                 ) from None

         # Override the default architecture configuration if needed.
-        config_overrides_list = []
+        if self._value_converter is None:
+            self._value_converter = get_value_converter()
+
+        model_config_fields = []

         card_: AssetCard | None = card

         while card_ is not None:
             if "model_config" in card_.metadata:
-                config_overrides = card_.field("model_config").as_unstructured()
+                model_config_field = card_.field("model_config").as_unstructured()

-                config_overrides_list.append(config_overrides)
+                model_config_fields.append(model_config_field)

             card_ = card_.base

-        if config_overrides_list:
-            if self._value_converter is None:
-                self._value_converter = get_value_converter()
-
+        if model_config_fields:
             try:
-                unstructured_config = self._value_converter.unstructure(config)
+                unstructured_base_config = self._value_converter.unstructure(
+                    base_config
+                )
             except StructuredError as ex:
                 raise AssetError(
                     f"The model configuration class of the '{self._family}' cannot be used. Please file a bug report to the model author."
                 ) from ex

             try:
-                for config_overrides in reversed(config_overrides_list):
-                    unstructured_config = merge_unstructured(
-                        unstructured_config, config_overrides
+                for model_config_field in reversed(model_config_fields):
+                    unstructured_base_config = merge_unstructured(
+                        unstructured_base_config, model_config_field
                     )

-                config = self._value_converter.structure(
-                    unstructured_config, type_expr=config_kls
+                base_config = self._value_converter.structure(
+                    unstructured_base_config, type_expr=config_kls
                 )
             except StructuredError as ex:
                 raise AssetError(
                     f"The value of the field 'model_config' of the asset card '{card.name}' cannot be parsed as a valid model configuration. Please file a bug report to the asset author."
                 ) from ex

+        if unstructured_config is None:
+            config = base_config
+        else:
+            config = self._value_converter.structure(
+                unstructured_config, config_kls, allow_empty=True
+            )
+
+        fill_empty_fields(config, base_config)
+
         return config


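The net effect of the loader change is a three-layer merge: architecture defaults first, then any `model_config` overrides from the asset card chain, then the caller-supplied `unstructured_config`, with `fill_empty_fields` backfilling whatever the caller left unset from the merged base. Below is a minimal, self-contained sketch of that backfill step; it is not fairseq2's actual `fill_empty_fields`, and `EMPTY` and `ToyConfig` are illustrative stand-ins:

```python
from dataclasses import dataclass, fields

EMPTY = object()  # stand-in sentinel for a field the caller left unset

@dataclass
class ToyConfig:
    model_dim: int = 4096
    num_layers: int = 32
    dropout_p: float = 0.1

def fill_empty_fields_toy(config: ToyConfig, base: ToyConfig) -> None:
    # Backfill every still-unset field of `config` from `base`, so caller
    # overrides win and everything else falls back to the card's config.
    for f in fields(config):
        if getattr(config, f.name) is EMPTY:
            setattr(config, f.name, getattr(base, f.name))

base = ToyConfig()  # the config after card-level merging
overrides = ToyConfig(model_dim=EMPTY, num_layers=EMPTY, dropout_p=0.0)
fill_empty_fields_toy(overrides, base)
assert (overrides.model_dim, overrides.dropout_p) == (4096, 0.0)
```
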
2 changes: 1 addition & 1 deletion src/fairseq2/models/llama/factory.py
@@ -88,7 +88,7 @@ class LLaMAConfig:
     use_scaled_rope: bool = False
     """If ``True``, scales Rotary encoding frequencies to LLaMA 3.1 context length."""

-    dropout_p: float = 0.0  # TODO: Revert back to 0.1
+    dropout_p: float = 0.1
     """The dropout probability on outputs of Transformer layers."""


6 changes: 5 additions & 1 deletion src/fairseq2/models/loader.py
@@ -77,6 +77,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -195,6 +196,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -246,7 +248,7 @@ def __call__(

         model = None

-        config = self._config_loader(card)
+        config = self._config_loader(card, unstructured_config)

         if device.type == "meta":
             try:
@@ -392,6 +394,7 @@ def __call__(
         model_name_or_card: str | AssetCard,
         *,
         gangs: Mapping[str, Gang] | None = None,
+        unstructured_config: object = None,
         device: Device | None = None,
         dtype: DataType | None = None,
         force: bool = False,
@@ -417,6 +420,7 @@ def __call__(
         return loader(
             model_name_or_card,
             gangs=gangs,
+            unstructured_config=unstructured_config,
             device=device,
             dtype=dtype,
             force=force,
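
With `unstructured_config` threaded through `load_model`, a caller can override individual configuration fields at load time. A hedged example of a call site; the dict payload is illustrative, and any unstructured value the family's value converter accepts should work the same way:

```python
import torch

from fairseq2.models import load_model

# Hypothetical usage: load from the asset card, but zero out dropout.
# The dict is merged over the card's configuration by the config loader.
model = load_model(
    "llama3_1_8b_instruct",
    unstructured_config={"dropout_p": 0.0},
    dtype=torch.bfloat16,
)
```
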
28 changes: 25 additions & 3 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -29,6 +29,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
 from fairseq2.models.decoder import DecoderModel
+from fairseq2.models.llama import llama_archs
 from fairseq2.models.sequence import (
     SequenceBatch,
     SequenceModelOutput,
@@ -53,6 +54,7 @@
     to_data_parallel,
 )
 from fairseq2.typing import CPU, META, DataType
+from fairseq2.utils.dataclass import empty_
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import manual_seed

@@ -83,9 +85,17 @@ class InstructionFinetuneConfig:
     """The number of batches to prefetch in background."""

     # Model
-    model: AssetReference = "llama3_8b_instruct"
+    model: AssetReference = "llama3_1_8b_instruct"
     """The name or path to the asset card of the language model to finetune."""

+    model_config: Any = field(
+        default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
+    )
+    """
+    The model configuration overrides. The provided values must be compatible
+    with the checkpoint; otherwise, the model will fail to load.
+    """
+
     dtype: DataType = torch.bfloat16
     """The data type of the model."""

@@ -303,7 +313,13 @@ def load_instruction_finetuner(

     if has_checkpoint:
         try:
-            model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+            model = load_model(
+                model_card,
+                gangs=gangs,
+                unstructured_config=config.model_config,
+                device=init_device,
+                dtype=dtype,
+            )
         except ValueError as ex:
             raise ValueError(
                 "The model cannot be initialized. See nested exception for details."
@@ -317,7 +333,13 @@
             init_device = root_gang.device

         try:
-            model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+            model = load_model(
+                model_card,
+                gangs=gangs,
+                unstructured_config=config.model_config,
+                device=init_device,
+                dtype=dtype,
+            )
         except ValueError as ex:
             raise ValueError(
                 "The model cannot be initialized. See nested exception for details."
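
Because `model_config` defaults to an emptied `llama3_1_8b` configuration, a recipe only needs to set the fields it wants to change; everything left empty is backfilled from the checkpoint's card when the model loads. A sketch of the intended usage (attribute assignment on the emptied dataclass is an assumption here, and the exact recipe wiring may differ); the same `model_config` field works identically in `PreferenceOptimizationConfig` below:

```python
from fairseq2.recipes.lm.instruction_finetune import InstructionFinetuneConfig

config = InstructionFinetuneConfig()

# Override a single field; all other model_config fields stay "empty" and
# are filled from the llama3_1_8b_instruct card when the model is loaded.
config.model_config.dropout_p = 0.0
```
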
28 changes: 25 additions & 3 deletions src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -29,6 +29,7 @@
 from fairseq2.logging import get_log_writer
 from fairseq2.models import load_model
 from fairseq2.models.decoder import DecoderModel
+from fairseq2.models.llama import llama_archs
 from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
 from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
 from fairseq2.optim import AdamWConfig, create_optimizer
@@ -42,6 +43,7 @@
 from fairseq2.recipes.utils.log import log_model
 from fairseq2.recipes.utils.setup import compile_model, setup_gangs, to_data_parallel
 from fairseq2.typing import CPU, META, DataType
+from fairseq2.utils.dataclass import empty_
 from fairseq2.utils.profiler import Stopwatch
 from fairseq2.utils.rng import manual_seed

@@ -79,9 +81,17 @@ class PreferenceOptimizationConfig:
     """If ``False``, calculates loss on the `src` tokens as well as the `tgt` tokens."""

     # Model
-    model: AssetReference = "llama3_8b_instruct"
+    model: AssetReference = "llama3_1_8b_instruct"
     """The name or path to the asset card of the language model to finetune."""

+    model_config: Any = field(
+        default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
+    )
+    """
+    The model configuration overrides. The provided values must be compatible
+    with the checkpoint; otherwise, the model will fail to load.
+    """
+
     dtype: DataType = torch.bfloat16
     """The data type of the model."""

@@ -289,7 +299,13 @@ def load_preference_finetuner(

     if has_checkpoint:
         try:
-            model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+            model = load_model(
+                model_card,
+                gangs=gangs,
+                unstructured_config=config.model_config,
+                device=init_device,
+                dtype=dtype,
+            )
         except ValueError as ex:
             raise ValueError(
                 "The model cannot be initialized. See nested exception for details."
@@ -302,7 +318,13 @@
         if dp_gang.rank == 0:
             init_device = root_gang.device

-        model = load_model(model_card, gangs=gangs, device=init_device, dtype=dtype)
+        model = load_model(
+            model_card,
+            gangs=gangs,
+            unstructured_config=config.model_config,
+            device=init_device,
+            dtype=dtype,
+        )

         root_gang.barrier()
