llms recipe config updates (#818)
uralik authored Sep 30, 2024
1 parent e1fe1f7 commit 8ce7b48
Showing 11 changed files with 89 additions and 93 deletions.
8 changes: 4 additions & 4 deletions src/fairseq2/data/data_pipeline.py
@@ -231,10 +231,10 @@ def dynamic_bucket(
self,
threshold: float,
cost_fn: Callable[[Any], float],
bucket_creation_fn: Callable[
[Sequence[Any]], tuple[Sequence[Sequence[Any]], Sequence[Any]]
]
| None = None,
bucket_creation_fn: (
Callable[[Sequence[Any]], tuple[Sequence[Sequence[Any]], Sequence[Any]]]
| None
) = None,
min_num_examples: int | None = None,
max_num_examples: int | None = None,
drop_remainder: bool = False,
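The hunk above only reflows the bucket_creation_fn annotation; behavior is unchanged. As a rough illustration of the signature, here is a minimal sketch of calling dynamic_bucket, assuming fairseq2's read_sequence()/and_return() builder API; the data and the cost function are illustrative:

from fairseq2.data import read_sequence

# Illustrative sketch, not part of this commit.
# Cost of an example; here simply its value, in practice e.g. its token count.
def cost_fn(example: int) -> float:
    return float(example)

# Accumulate examples into a bucket until the summed cost reaches the threshold.
pipeline = (
    read_sequence([3, 1, 4, 1, 5, 9, 2, 6])
    .dynamic_bucket(threshold=8.0, cost_fn=cost_fn, drop_remainder=False)
    .and_return()
)

for bucket in pipeline:
    print(bucket)  # each emitted bucket has a summed cost of at least the threshold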
8 changes: 5 additions & 3 deletions src/fairseq2/data/parquet/dataloader.py
@@ -164,9 +164,11 @@ def inner_iterator(wrap_table: _TableWrapper) -> DataPipeline:
columns=config.columns,
split_to_row_groups=config.split_to_row_groups,
filesystem=config.filesystem,
shuffle_window=2 * config.nb_prefetch * config.nb_parallel_fragments
if config.shuffle
else None,
shuffle_window=(
2 * config.nb_prefetch * config.nb_parallel_fragments
if config.shuffle
else None
),
seed=config.seed,
)
.shard(shard_idx=config.rank, num_shards=config.world_size)
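The parenthesized conditional added above sizes the shuffle window from the prefetch and fragment settings. Written out as a stand-alone helper for clarity (a sketch; config is assumed to expose the shuffle, nb_prefetch, and nb_parallel_fragments fields referenced in the hunk):

# Illustrative sketch, not part of this commit.
def compute_shuffle_window(config) -> int | None:
    if not config.shuffle:
        return None  # shuffling disabled, so no window is needed
    # Keep the window proportional to the number of fragments in flight.
    return 2 * config.nb_prefetch * config.nb_parallel_fragments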
6 changes: 3 additions & 3 deletions src/fairseq2/recipes/lm/__init__.py
@@ -33,7 +33,7 @@ def _setup_lm_cli(cli: Cli) -> None:
instruction_finetune_handler = RecipeCommandHandler(
loader=load_instruction_finetuner,
preset_configs=instruction_finetune_presets,
default_preset="llama3_8b_instruct",
default_preset="llama3_1_instruct",
)

group.add_command(
@@ -46,7 +46,7 @@ def _setup_lm_cli(cli: Cli) -> None:
preference_finetune_handler = RecipeCommandHandler(
loader=load_preference_finetuner,
preset_configs=preference_finetune_presets,
default_preset="llama3_8b_instruct",
default_preset="llama3_1_instruct",
)

group.add_command(
@@ -59,7 +59,7 @@ def _setup_lm_cli(cli: Cli) -> None:
text_generate_handler = RecipeCommandHandler(
loader=load_text_generator,
preset_configs=text_generate_presets,
default_preset="llama3_8b_instruct",
default_preset="llama3_1_8b_instruct",
)

group.add_command(
59 changes: 28 additions & 31 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -162,7 +162,7 @@ class InstructionFinetuneConfig:

# Regime
max_num_steps: int = 5000
"""The maximum number of steps to train for."""
"""The maximum number of steps to train for. Note that max_num_steps is used as CosineLRScheduler argument!"""

max_num_data_epochs: int | None = None
"""The maximum number of data epochs to train for."""
@@ -214,57 +214,54 @@ class InstructionFinetuneConfig:
instruction_finetune_preset = instruction_finetune_presets.decorator


@instruction_finetune_preset("llama2_7b_chat")
def _llama2_7b_chat() -> InstructionFinetuneConfig:
config = _llama3_8b_instruct()
@dataclass(kw_only=True)
class DropoutConfig:
dropout_p: float = 0.0

config.max_seq_len = 4096
config.max_num_tokens = 4096 * 2
config.max_num_valid_tokens = 4096 * 2
config.model = "llama2_7b_chat"

@instruction_finetune_preset("llama3_1_instruct")
def _llama3_1_instruct() -> InstructionFinetuneConfig:
config = InstructionFinetuneConfig()
config.model_config = DropoutConfig()
return config


@instruction_finetune_preset("llama2_70b_chat")
def _llama2_70b_chat() -> InstructionFinetuneConfig:
config = _llama2_7b_chat()

config.model = "llama2_70b_chat"
config.tensor_parallel_size = 8

@instruction_finetune_preset("llama3_1_instruct_constant_lr")
def _llama3_1_instruct_constant_lr() -> InstructionFinetuneConfig:
config = _llama3_1_instruct()
# setting the final lr to the optimizer base lr; lr_mul is 1.0 by default
config.lr_scheduler_config.final_lr = config.optimizer_config.lr
return config


@instruction_finetune_preset("llama3_8b_instruct")
def _llama3_8b_instruct() -> InstructionFinetuneConfig:
return InstructionFinetuneConfig()


@instruction_finetune_preset("llama3_70b_instruct")
@instruction_finetune_preset("llama3_1_70b_instruct")
def _llama3_70b_instruct() -> InstructionFinetuneConfig:
config = _llama3_8b_instruct()
config = _llama3_1_instruct()

config.model = "llama3_70b_instruct"
config.model = "llama3_1_70b_instruct"
config.tensor_parallel_size = 8

return config


@instruction_finetune_preset("llama3_1_8b_instruct")
def _llama3_1_8b_instruct() -> InstructionFinetuneConfig:
config = _llama3_8b_instruct()
@instruction_finetune_preset("llama2_7b_chat")
def _llama2_7b_chat() -> InstructionFinetuneConfig:
config = _llama3_1_instruct()

config.model = "llama3_1_8b_instruct"
config.max_seq_len = 4096
config.max_num_tokens = 4096 * 2
config.max_num_valid_tokens = 4096 * 2
config.model = "llama2_7b_chat"

return config


@instruction_finetune_preset("llama3_1_70b_instruct")
def _llama3_1_70b_instruct() -> InstructionFinetuneConfig:
config = _llama3_70b_instruct()
@instruction_finetune_preset("llama2_70b_chat")
def _llama2_70b_chat() -> InstructionFinetuneConfig:
config = _llama2_7b_chat()

config.model = "llama3_1_70b_instruct"
config.model = "llama2_70b_chat"
config.tensor_parallel_size = 8

return config

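With the presets above now derived from _llama3_1_instruct(), a user-defined preset follows the same pattern. A minimal sketch (the preset name and overridden values are illustrative; the fields themselves appear in the llama2_7b_chat preset above):

# Illustrative sketch, not part of this commit: registering an extra preset on
# top of the llama3_1_instruct base defined above.
@instruction_finetune_preset("my_llama3_1_instruct_4k")
def _my_llama3_1_instruct_4k() -> InstructionFinetuneConfig:
    config = _llama3_1_instruct()

    config.max_seq_len = 4096
    config.max_num_tokens = 4096 * 2

    return config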
7 changes: 2 additions & 5 deletions src/fairseq2/recipes/lm/preference_finetune/__init__.py
@@ -15,13 +15,10 @@
from fairseq2.recipes.lm.preference_finetune.recipe import (
preference_finetune_presets as preference_finetune_presets,
)
from fairseq2.recipes.lm.preference_finetune.recipe import (
preference_unit_factories as preference_unit_factories,
)
from fairseq2.recipes.lm.preference_finetune.recipe import (
from fairseq2.recipes.lm.preference_finetune.simpo import SimPOConfig as SimPOConfig
from fairseq2.recipes.lm.preference_finetune.utils import (
preference_unit_factory as preference_unit_factory,
)
from fairseq2.recipes.lm.preference_finetune.simpo import SimPOConfig as SimPOConfig

# isort: split

2 changes: 1 addition & 1 deletion src/fairseq2/recipes/lm/preference_finetune/cpo.py
@@ -21,10 +21,10 @@
from fairseq2.logging import get_log_writer
from fairseq2.metrics.recorder import format_as_float, register_metric_formatter
from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input
from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory
from fairseq2.recipes.lm.preference_finetune.utils import (
PreferenceFinetuneMetricBag,
_gather_lprobs,
preference_unit_factory,
)
from fairseq2.recipes.trainer import AbstractTrainUnit

6 changes: 4 additions & 2 deletions src/fairseq2/recipes/lm/preference_finetune/dpo.py
@@ -21,11 +21,13 @@
from fairseq2.logging import get_log_writer
from fairseq2.metrics.recorder import format_as_float, register_metric_formatter
from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input
from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory

# from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory
from fairseq2.recipes.lm.preference_finetune.utils import (
PreferenceFinetuneMetricBag,
_gather_lprobs,
_load_reference_model,
preference_unit_factory,
)
from fairseq2.recipes.trainer import AbstractTrainUnit
from fairseq2.recipes.utils.asset import AssetReference
@@ -170,7 +172,7 @@ class DpoConfig:
"""Holds the DPO configuration of a language model preference-finetuning task."""

# Reference Model
reference_model: AssetReference = "llama3_8b_instruct"
reference_model: AssetReference = "llama3_1_8b_instruct"
"""The name, path, or path to the asset card of the reference model."""

reference_dtype: DataType = torch.bfloat16
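With the default reference model now llama3_1_8b_instruct, pointing DPO at a larger reference is a one-field change. A sketch (the 70B asset name mirrors the preset in recipe.py below):

# Illustrative sketch, not part of this commit: overriding the new default
# reference model shown in the hunk above.
config = DpoConfig()
config.reference_model = "llama3_1_70b_instruct"
config.reference_dtype = torch.bfloat16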
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/lm/preference_finetune/orpo.py
@@ -21,10 +21,10 @@
from fairseq2.logging import get_log_writer
from fairseq2.metrics.recorder import format_as_float, register_metric_formatter
from fairseq2.models.sequence import SequenceModelOutput, as_auto_regressive_input
from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory
from fairseq2.recipes.lm.preference_finetune.utils import (
PreferenceFinetuneMetricBag,
_gather_lprobs,
preference_unit_factory,
)
from fairseq2.recipes.trainer import AbstractTrainUnit

69 changes: 28 additions & 41 deletions src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -8,11 +8,10 @@

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Mapping
from typing import Any, Literal

import torch
import torch.distributed
from torch.nn import Module

from fairseq2.assets import AssetNotFoundError, default_asset_store
from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager
@@ -24,16 +23,16 @@
PreferenceOptimizationBatch,
load_preference_optimization_dataset,
)
from fairseq2.factory_registry import ConfigBoundFactoryRegistry
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
from fairseq2.models.decoder import DecoderModel
from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
from fairseq2.optim import AdamWConfig, create_optimizer
from fairseq2.optim.lr_scheduler import CosineAnnealingLRConfig, create_lr_scheduler
from fairseq2.recipes.trainer import Trainer, TrainUnit
from fairseq2.recipes.lm.preference_finetune.dpo import DpoConfig
from fairseq2.recipes.lm.preference_finetune.utils import preference_unit_factories
from fairseq2.recipes.trainer import Trainer
from fairseq2.recipes.utils.asset import (
AssetReference,
asset_as_path,
@@ -49,11 +48,11 @@


@dataclass(kw_only=True)
class PreferenceOptimizationConfig:
class PreferenceFinetuningConfig:
"""Holds the configuration of a language model preference-finetuning task."""

# Data
dataset: AssetReference = "openeft" # TODO: change!
dataset: AssetReference = "gsm8k_dpo_data" # TODO: change!
"""The name, path, or path to the asset card of the preference optimization dataset."""

max_seq_len: int = 8192
@@ -122,7 +121,7 @@ class PreferenceOptimizationConfig:
criterion: str = "dpo"
"""The preference optimization criterion."""

criterion_config: Any = None
criterion_config: Any = field(default_factory=lambda: DpoConfig())
"""The configuration of the preference optimization criterion."""

# Optimizer, LR, and Loss
@@ -196,50 +195,45 @@ class PreferenceOptimizationConfig:
"""If ``True``, turns on anomaly detection feature in ``torch.autograd``."""


preference_finetune_presets = ConfigRegistry[PreferenceOptimizationConfig]()
preference_finetune_presets = ConfigRegistry[PreferenceFinetuningConfig]()

preference_finetune_preset = preference_finetune_presets.decorator


# @preference_finetune_preset("simpo")
# def _simpo() -> PreferenceOptimizationConfig:
# cfg = PreferenceOptimizationConfig()
# cfg.max_num_tokens = 1200
# cfg.max_seq_len = 600
# cfg.model = "llama3_8b"
# cfg.simpo = SimpoFinetuneConfig()
# return cfg
@dataclass(kw_only=True)
class DropoutConfig:
dropout_p: float = 0.0


@preference_finetune_preset("llama3_8b_instruct")
def _llama3_8b_instruct() -> PreferenceOptimizationConfig:
config = PreferenceOptimizationConfig()
@preference_finetune_preset("llama3_1_instruct")
def _llama3_1_instruct() -> PreferenceFinetuningConfig:
config = PreferenceFinetuningConfig()
config.model_config = DropoutConfig()
return config

config.max_seq_len = 1000
config.max_num_tokens = 1000
config.max_gradient_norm = 1.0

@preference_finetune_preset("llama3_1_instruct_constant_lr")
def _llama3_1_instruct_constant_lr() -> PreferenceFinetuningConfig:
config = _llama3_1_instruct()
# setting the final lr to the optimizer base lr; lr_mul is 1.0 by default
config.lr_scheduler_config.final_lr = config.optimizer_config.lr
return config


# batch size and min lengths are tuned for OA2 in this preset!
@preference_finetune_preset("llama3_70b_instruct_openassistant2")
def _llama3_70b_instruct_openassistant2() -> PreferenceOptimizationConfig:
config = PreferenceOptimizationConfig()
@preference_finetune_preset("llama3_1_70b_instruct")
def _llama3_70b_instruct() -> PreferenceFinetuningConfig:
config = _llama3_1_instruct()

# 70B DPO training might run out of memory; tune the effective batch size if needed.
config.max_seq_len = 200
config.max_num_tokens = 200
config.model = "llama3_70b_instruct"
config.model = "llama3_1_70b_instruct"
config.tensor_parallel_size = 8
config.max_gradient_norm = 1.0
config.gradient_accumulation = 8 # to address small batch size
config.criterion_config.reference_model = "llama3_1_70b_instruct"
config.criterion_config.reference_tensor_parallel_size = 8

return config


def load_preference_finetuner(
config: PreferenceOptimizationConfig, output_dir: Path
config: PreferenceFinetuningConfig, output_dir: Path
) -> Trainer[PreferenceOptimizationBatch]:
"""Load a :class:`Trainer` for language model preference optimization-finetuning."""
wall_watch = Stopwatch(start=True)
@@ -464,10 +458,3 @@ def load_preference_finetuner(
seed=config.seed,
wall_watch=wall_watch,
)


preference_unit_factories = ConfigBoundFactoryRegistry[
[Module, Gang, Mapping[str, Gang]], TrainUnit[PreferenceOptimizationBatch]
]()

preference_unit_factory = preference_unit_factories.decorator
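Since criterion_config now defaults to a DpoConfig, switching criteria in a preset means swapping that config object. A sketch using the registry and decorator defined in this file (SimPOConfig is re-exported in __init__.py above; that it is default-constructible is an assumption):

# Illustrative sketch, not part of this commit: a preset that switches the
# preference-optimization criterion from DPO to SimPO.
@preference_finetune_preset("llama3_1_instruct_simpo")
def _llama3_1_instruct_simpo() -> PreferenceFinetuningConfig:
    config = _llama3_1_instruct()

    config.criterion = "simpo"
    config.criterion_config = SimPOConfig()  # assumed default-constructible

    return config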
6 changes: 4 additions & 2 deletions src/fairseq2/recipes/lm/preference_finetune/simpo.py
@@ -25,8 +25,10 @@
SequenceModelOutput,
as_auto_regressive_input,
)
from fairseq2.recipes.lm.preference_finetune.recipe import preference_unit_factory
from fairseq2.recipes.lm.preference_finetune.utils import PreferenceFinetuneMetricBag
from fairseq2.recipes.lm.preference_finetune.utils import (
PreferenceFinetuneMetricBag,
preference_unit_factory,
)
from fairseq2.recipes.trainer import AbstractTrainUnit

log = get_log_writer(__name__)
9 changes: 9 additions & 0 deletions src/fairseq2/recipes/lm/preference_finetune/utils.py
@@ -14,13 +14,15 @@
from torcheval.metrics import Mean

from fairseq2.datasets.preference import PreferenceOptimizationBatch
from fairseq2.factory_registry import ConfigBoundFactoryRegistry
from fairseq2.gang import Gang
from fairseq2.logging import LogWriter
from fairseq2.metrics.recorder import format_as_float, register_metric_formatter
from fairseq2.models import load_model
from fairseq2.models.sequence import SequenceBatch, SequenceModelOutput
from fairseq2.nn.utils.module import freeze_parameters
from fairseq2.recipes.common_metrics import SequenceMetricBag
from fairseq2.recipes.trainer import TrainUnit
from fairseq2.recipes.utils.asset import AssetReference, retrieve_asset_card
from fairseq2.recipes.utils.setup import broadcast_model
from fairseq2.typing import META, DataType
@@ -149,3 +151,10 @@ def update_sequence_lengths(
Tensor([batch.rejected.num_target_elements() / batch.rejected.batch_size]),
weight=batch.rejected.batch_size,
)


preference_unit_factories = ConfigBoundFactoryRegistry[
[Module, Gang, Mapping[str, Gang]], TrainUnit[PreferenceOptimizationBatch]
]()

preference_unit_factory = preference_unit_factories.decorator
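The relocated registry binds a criterion config to a factory that builds its train unit. A rough sketch of how a criterion module is expected to register itself (the decorator comes from this file; the factory's parameter order and the string key are assumptions inferred from the registry's type parameters):

# Illustrative sketch, not part of this commit. The exact decorator/factory
# signature is an assumption based on ConfigBoundFactoryRegistry's type
# parameters above.
@preference_unit_factory("my_criterion")
def create_my_criterion_unit(
    config: object,  # the criterion's config dataclass, e.g. DpoConfig
    model: Module,
    gang: Gang,
    gangs: Mapping[str, Gang],
) -> TrainUnit[PreferenceOptimizationBatch]:
    raise NotImplementedError  # build and return the criterion's train unit here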
