Improvements to mixed precision training #822

Merged 1 commit on Sep 30, 2024
8 changes: 4 additions & 4 deletions src/fairseq2/optim/dynamic_loss_scaler.py
@@ -32,7 +32,7 @@ class DynamicLossScaler:
_optimizer: Optimizer
_scale_window: int
_min_scale: float
_is_enabled: bool
_enabled: bool

# compat: consolidate into `GradScaler` once we cease support for PT2.2
_grad_scaler: GradScaler | ShardedGradScaler
@@ -125,7 +125,7 @@ def __init__(
self._optimizer = optimizer
self._scale_window = scale_window
self._min_scale = min_scale
self._is_enabled = enabled
self._enabled = enabled

def state_dict(self) -> dict[str, Any]:
return {"grad_scaler": self._grad_scaler.state_dict()}
@@ -194,7 +194,7 @@ def _are_close(a: float, b: float) -> bool:

def unscale_gradients_(self) -> None:
"""Unscale the associated optimizer's gradients by the current scale."""
if not supports_manual_gradient_scaling(self._optimizer):
if self._enabled and not supports_manual_gradient_scaling(self._optimizer):
raise RuntimeError(
"`optimizer` must support manual gradient scaling via `torch.cuda.amp.GradScaler`, but supports only implicit scaling in its step function (i.e. `_step_supports_amp_scaling == True`)."
)
@@ -212,7 +212,7 @@ def get_scale(self) -> float:
@property
def is_enabled(self) -> bool:
"""``True`` if the loss scaling is enabled."""
return self._is_enabled
return self._enabled


@final
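The only behavioral change in this file is the added `self._enabled` guard in `unscale_gradients_`: when loss scaling is disabled, the optimizer no longer has to support manual gradient scaling. Below is a minimal standalone sketch of the new guard; the `supports_manual_gradient_scaling` helper here is an assumption modeled on the error message, not the actual fairseq2 implementation.

from torch.optim import Optimizer


def supports_manual_gradient_scaling(optimizer: Optimizer) -> bool:
    # Assumed behavior: optimizers that scale gradients implicitly in their
    # step function advertise it via `_step_supports_amp_scaling`.
    return not getattr(optimizer, "_step_supports_amp_scaling", False)


def unscale_gradients_(optimizer: Optimizer, enabled: bool) -> None:
    # Before this change the check ran even when the scaler was disabled;
    # now a disabled scaler simply skips the requirement.
    if enabled and not supports_manual_gradient_scaling(optimizer):
        raise RuntimeError(
            "`optimizer` must support manual gradient scaling via `torch.cuda.amp.GradScaler`."
        )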
24 changes: 22 additions & 2 deletions src/fairseq2/recipes/evaluator.py
@@ -8,6 +8,7 @@

from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import AbstractContextManager, nullcontext
from itertools import count
from pathlib import Path
from typing import Generic, TypeVar, final
@@ -29,7 +30,7 @@
)
from fairseq2.recipes.common_metrics import extend_batch_metrics
from fairseq2.recipes.utils.cli import create_rich_progress
from fairseq2.typing import CPU
from fairseq2.typing import CPU, DataType
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import RngBag

@@ -104,6 +105,8 @@ class Evaluator(Generic[BatchT]):
_root_gang: Gang
_dp_gang: Gang
_tp_gang: Gang
_dtype: DataType
_amp: bool
_metric_recorders: list[MetricRecorder]
_seed: int
_wall_watch: Stopwatch
@@ -118,6 +121,8 @@ def __init__(
wall_watch: Stopwatch,
dp_gang: Gang | None = None,
tp_gang: Gang | None = None,
dtype: DataType = torch.float32,
amp: bool = False,
tb_dir: Path | None = None,
metrics_dir: Path | None = None,
seed: int = 2,
@@ -135,6 +140,10 @@ def __init__(
The data parallel gang. If ``None``, ``root_gang`` will be used.
:param tp_gang:
The tensor parallel gang. Only required for tensor parallel models.
:param dtype:
The data type of the model.
:param amp:
If ``True``, enables ``torch.amp``.
:param tb_dir:
The TensorBoard log directory to dump metrics.
:param metrics_dir:
@@ -168,6 +177,10 @@ def __init__(
f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead."
)

self._dtype = dtype

self._amp = amp

if root_gang.rank == 0:
self._metric_recorders = [LogMetricRecorder(log)]

@@ -239,12 +252,19 @@ def _evaluate_unit(
break

for batch in batches:
unit(batch)
with self._maybe_autocast():
unit(batch)

num_effective_batches += 1

self._publish_metrics(unit, num_effective_batches, watch.get_elapsed_time())

def _maybe_autocast(self) -> AbstractContextManager[None]:
if self._dtype == torch.float32 or not self._amp:
return nullcontext()

return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)

def _publish_metrics(
self, unit: EvalUnit[BatchT], num_batches: int, elapsed_time: float
) -> None:
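The `_maybe_autocast` helper added here wraps each evaluation step in `torch.autocast` only when AMP is requested and the model is not already in full precision. A standalone sketch of the same pattern follows; the names are illustrative, not fairseq2 API.

from contextlib import AbstractContextManager, nullcontext

import torch


def maybe_autocast(
    device: torch.device, dtype: torch.dtype, amp: bool
) -> AbstractContextManager[None]:
    # Full precision or AMP turned off: do nothing.
    if dtype == torch.float32 or not amp:
        return nullcontext()

    # Otherwise run the wrapped code under torch.amp autocast. In the recipes
    # the device comes from the data parallel gang.
    return torch.autocast(device_type=device.type, dtype=dtype)


# Usage sketch: only the forward pass goes inside the context.
with maybe_autocast(torch.device("cpu"), torch.bfloat16, amp=True):
    pass  # unit(batch) would run here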
24 changes: 22 additions & 2 deletions src/fairseq2/recipes/generator.py
@@ -7,6 +7,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from contextlib import AbstractContextManager, nullcontext
from itertools import count
from pathlib import Path
from typing import Generic, TypeVar, final
@@ -27,7 +28,7 @@
)
from fairseq2.recipes.common_metrics import extend_batch_metrics
from fairseq2.recipes.utils.cli import create_rich_progress
from fairseq2.typing import CPU
from fairseq2.typing import CPU, DataType
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import RngBag

@@ -79,6 +80,8 @@ class Generator(Generic[BatchT]):
_root_gang: Gang
_dp_gang: Gang
_tp_gang: Gang
_dtype: DataType
_amp: bool
_metric_recorders: list[MetricRecorder]
_seed: int
_wall_watch: Stopwatch
@@ -93,6 +96,8 @@ def __init__(
wall_watch: Stopwatch,
dp_gang: Gang | None = None,
tp_gang: Gang | None = None,
dtype: DataType = torch.float32,
amp: bool = False,
metrics_dir: Path | None = None,
seed: int = 2,
) -> None:
@@ -109,6 +114,10 @@
The data parallel gang. If ``None``, ``gang`` will be used.
:param tp_gang:
The tensor parallel gang. Only required for tensor parallel models.
:param dtype:
The data type of the model.
:param amp:
If ``True``, enables ``torch.amp``.
:param metrics_dir:
The directory to dump metrics.
:param seed:
@@ -135,6 +144,10 @@
f"The coordinator process of `root_gang` (i.e. rank 0) must be rank 0 in `dp_gang` and `tp_gang`, but is {self._dp_gang.rank} and {self._tp_gang.rank} instead."
)

self._dtype = dtype

self._amp = amp

if root_gang.rank == 0:
self._metric_recorders = [LogMetricRecorder(log)]

@@ -194,12 +207,19 @@ def _do_run(self) -> None:
break

for batch in batches:
self._unit(batch)
with self._maybe_autocast():
self._unit(batch)

num_effective_batches += 1

self._publish_metrics(num_effective_batches, watch.get_elapsed_time())

def _maybe_autocast(self) -> AbstractContextManager[None]:
if self._dtype == torch.float32 or not self._amp:
return nullcontext()

return torch.autocast(device_type=self._dp_gang.device.type, dtype=self._dtype)

def _publish_metrics(self, num_batches: int, elapsed_time: float) -> None:
log.debug("Syncing metrics.")

37 changes: 25 additions & 12 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -71,7 +71,7 @@ class InstructionFinetuneConfig:
train_split: str = "default"
"""The name of the train data split."""

valid_split: str = "valid"
valid_split: str | None = None
"""The name of the valid data split."""

max_seq_len: int = 8192
@@ -80,7 +80,7 @@ class InstructionFinetuneConfig:
max_num_tokens: int = 8192 * 2
"""The maximum number of tokens per batch."""

max_num_valid_tokens: int = 8192 * 2
max_num_valid_tokens: int | None = None
"""The maximum number of tokens per validation batch."""

example_shuffle_window: int = 10_000
@@ -105,8 +105,14 @@ class InstructionFinetuneConfig:
dtype: DataType = torch.bfloat16
"""The data type of the model."""

mixed_precision: bool = True
"""If ``True``, the model will be trained in mixed precision."""
mixed_precision: Literal["none", "static", "dynamic"] = "static"
"""
    If ``none``, the whole training runs in ``dtype``. If ``static``, the
    forward and backward passes run in ``dtype`` while the optimizer step runs
    in full precision. If ``dynamic``, the forward and backward passes run in
    ``dtype`` under ``torch.amp`` while the optimizer step runs in full
    precision.
"""

data_parallelism: Literal["ddp", "fsdp"] = "fsdp"
"""The data parallelism API to use."""
@@ -161,9 +167,6 @@ class InstructionFinetuneConfig:
max_num_data_epochs: int | None = None
"""The maximum number of data epochs to train for."""

validate: bool = False
"""If ``True``, runs validation."""

validate_after_n_steps: int = 0
"""The number of steps after which to start validating the model."""

@@ -323,7 +326,7 @@ def load_instruction_finetuner(

init_device = META

dtype = torch.float32 if config.mixed_precision else config.dtype
dtype = config.dtype if config.mixed_precision == "none" else torch.float32

has_checkpoint = checkpoint_manager.has_checkpoint()

@@ -372,14 +375,16 @@ def load_instruction_finetuner(

checkpoint_manager.save_model_metadata(base_asset=model_card.name)

mp_dtype = config.dtype if config.mixed_precision == "static" else None

dp_model = to_data_parallel(
model,
dp_gang,
config.data_parallelism,
log,
fsdp_broadcast_state=not has_checkpoint,
fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
fsdp_mixed_precision_dtype=mp_dtype,
fsdp_fp32_reduce=True,
fsdp_wrap_granularity=config.fsdp_wrap_granularity,
)
@@ -447,16 +452,18 @@ def load_instruction_finetuner(
) from ex

# Initialize the validation unit.
if config.validate:
if config.valid_split is not None:
valid_unit = InstructionValidUnit(criterion, dp_gang)

max_num_tokens = config.max_num_valid_tokens or config.max_num_tokens

try:
valid_data_reader = dataset.create_reader(
config.valid_split,
tokenizer,
dp_gang,
config.max_seq_len,
batching=LengthBatching(config.max_num_valid_tokens),
batching=LengthBatching(max_num_tokens),
example_shuffle_window=config.example_shuffle_window,
batch_shuffle_window=config.batch_shuffle_window,
sync_mode="until_last",
@@ -479,6 +486,12 @@ def load_instruction_finetuner(

seed += 1

# TODO: Fix once we support static mixed precision on one device.
if config.mixed_precision == "static":
amp = root_gang.size == 1 or config.data_parallelism != "fsdp"
else:
amp = config.mixed_precision == "dynamic"

# Initialize the trainer.
return Trainer[SequenceBatch](
unit=unit,
@@ -489,9 +502,9 @@ def load_instruction_finetuner(
dtype=config.dtype,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
amp=config.mixed_precision,
fp16_loss_scale=config.fp16_loss_scale,
max_gradient_norm=config.max_gradient_norm,
amp=amp,
max_num_steps=config.max_num_steps,
max_num_data_epochs=config.max_num_data_epochs,
valid_units=valid_units,
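Taken together, the recipe now derives three things from the single `mixed_precision` setting: the dtype the model is initialized in, the dtype FSDP uses for its mixed-precision forward/backward, and whether the Trainer wraps steps in `torch.amp`. The sketch below condenses that logic from the diff into one standalone function; it is a simplification for illustration, not public fairseq2 API.

from typing import Literal

import torch


def resolve_precision(
    mixed_precision: Literal["none", "static", "dynamic"],
    dtype: torch.dtype,
    world_size: int,
    data_parallelism: Literal["ddp", "fsdp"],
) -> tuple[torch.dtype, torch.dtype | None, bool]:
    """Return (model init dtype, FSDP mixed-precision dtype, Trainer ``amp`` flag)."""
    # "none": everything, including the optimizer step, runs in `dtype`.
    init_dtype = dtype if mixed_precision == "none" else torch.float32

    # "static": parameters stay in fp32; FSDP casts them to `dtype` for the
    # forward and backward passes and reduces gradients in fp32.
    fsdp_mp_dtype = dtype if mixed_precision == "static" else None

    # Trainer-level torch.amp. Per the TODO above, "static" falls back to amp
    # when FSDP is not actually used (single process or DDP).
    if mixed_precision == "static":
        amp = world_size == 1 or data_parallelism != "fsdp"
    else:
        amp = mixed_precision == "dynamic"

    return init_dtype, fsdp_mp_dtype, amp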
24 changes: 19 additions & 5 deletions src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -91,8 +91,14 @@ class PreferenceOptimizationConfig:
dtype: DataType = torch.bfloat16
"""The data type of the model."""

mixed_precision: bool = True
"""If ``True``, the model will be trained in mixed precision."""
mixed_precision: Literal["none", "static", "dynamic"] = "static"
"""
    If ``none``, the whole training runs in ``dtype``. If ``static``, the
    forward and backward passes run in ``dtype`` while the optimizer step runs
    in full precision. If ``dynamic``, the forward and backward passes run in
    ``dtype`` under ``torch.amp`` while the optimizer step runs in full
    precision.
"""

data_parallelism: Literal["ddp", "fsdp"] = "fsdp"
"""The data parallelism API to use."""
@@ -289,7 +295,7 @@ def load_preference_finetuner(

init_device = META

dtype = torch.float32 if config.mixed_precision else config.dtype
dtype = config.dtype if config.mixed_precision == "none" else torch.float32

has_checkpoint = checkpoint_manager.has_checkpoint()

@@ -333,14 +339,16 @@ def load_preference_finetuner(

checkpoint_manager.save_model_metadata(base_asset=model_card.name)

mp_dtype = config.dtype if config.mixed_precision == "static" else None

dp_model = to_data_parallel(
model,
dp_gang,
config.data_parallelism,
log,
fsdp_broadcast_state=not has_checkpoint,
fsdp_reshard_after_forward=config.fsdp_reshard_after_forward,
fsdp_mixed_precision_dtype=config.dtype if config.mixed_precision else None,
fsdp_mixed_precision_dtype=mp_dtype,
fsdp_fp32_reduce=True,
fsdp_wrap_granularity=config.fsdp_wrap_granularity,
)
@@ -422,6 +430,12 @@ def load_preference_finetuner(
"The learning rate scheduler cannot be created. See nested exception for details."
) from ex

# TODO: Fix once we support static mixed precision on one device.
if config.mixed_precision == "static":
amp = root_gang.size == 1 or config.data_parallelism != "fsdp"
else:
amp = config.mixed_precision == "dynamic"

# Initialize the trainer.
return Trainer[PreferenceOptimizationBatch](
unit=unit,
@@ -432,9 +446,9 @@ def load_preference_finetuner(
dtype=config.dtype,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
amp=config.mixed_precision,
fp16_loss_scale=config.fp16_loss_scale,
max_gradient_norm=config.max_gradient_norm,
amp=amp,
max_num_steps=config.max_num_steps,
max_num_data_epochs=config.max_num_data_epochs,
checkpoint_manager=checkpoint_manager,
5 changes: 5 additions & 0 deletions src/fairseq2/recipes/lm/text_generate.py
@@ -75,6 +75,9 @@ class TextGenerateConfig:
dtype: DataType = torch.bfloat16
"""The data type of the model."""

amp: bool = False
"""If ``True``, runs evaluation with ``torch.amp``."""

tensor_parallel_size: int = 1
"""The size of tensor parallelism."""

@@ -301,6 +304,8 @@ def load_text_generator(
root_gang=root_gang,
dp_gang=dp_gang,
tp_gang=tp_gang,
dtype=config.dtype,
amp=config.amp,
metrics_dir=output_dir.joinpath("metrics"),
seed=seed,
wall_watch=wall_watch,
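For completeness, enabling the new path from user code would look roughly like the following. Only `dtype` and `amp` come from this diff; how the config object is created and handed to the recipe is assumed here, and the import path simply mirrors the file location above.

import torch

from fairseq2.recipes.lm.text_generate import TextGenerateConfig


def enable_amp_generation(config: TextGenerateConfig) -> TextGenerateConfig:
    # `dtype` already defaults to bf16 in the config above; `amp` additionally
    # wraps generation in torch.amp autocast inside the Generator.
    config.dtype = torch.bfloat16
    config.amp = True
    return config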