Revise model sharder API (#813)
cbalioglu authored Sep 26, 2024
1 parent d346a66 commit f2fc9b8
Showing 3 changed files with 58 additions and 40 deletions.
4 changes: 0 additions & 4 deletions src/fairseq2/assets/cards/models/llama.yaml
@@ -97,15 +97,13 @@ name: llama3_70b
base: llama3
model_arch: llama3_70b
num_shards: 8
shard_embed_dim: false

---

name: llama3_70b_instruct
base: llama3_instruct
model_arch: llama3_70b
num_shards: 8
shard_embed_dim: false

---

@@ -125,12 +123,10 @@ name: llama3_1_70b
base: llama3
model_arch: llama3_1_70b
num_shards: 8
shard_embed_dim: false

---

name: llama3_1_70b_instruct
base: llama3_instruct
model_arch: llama3_1_70b
num_shards: 8
shard_embed_dim: false
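
Note: the removed shard_embed_dim entries are not replaced by anything in the cards; with this commit the flag no longer lives in the asset card at all. As the loader.py changes below show, the revised sharder derives the equivalent setting from the model configuration. Both lines of the sketch below are taken from this commit's loader.py diff:

# Before this commit: read from the asset card, defaulting to True.
shard_embed_dim = card.field("shard_embed_dim").get_as_(bool, True)

# After: inferred from the model config; only LLaMA 1/2 have max_seq_len < 8192.
shard_embed_dim = config.max_seq_len < 8192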
30 changes: 13 additions & 17 deletions src/fairseq2/models/llama/loader.py
@@ -7,7 +7,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, final
from typing import Any, Mapping, final

from torch import Tensor
from typing_extensions import override
@@ -38,21 +38,6 @@
load_llama_config = StandardModelConfigLoader(LLAMA_FAMILY, LLaMAConfig, llama_archs)


@final
class LLaMAModelLoader(StandardModelLoader[TransformerDecoderModel, LLaMAConfig]):
"""Loads LLaMA models."""

@override
def _shard(
self, model: TransformerDecoderModel, gangs: dict[str, Gang], card: AssetCard
) -> None:
gang = gangs["tp"] # tensor parallel

shard_embed_dim = card.field("shard_embed_dim").get_as_(bool, True)

shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim)


def convert_llama_checkpoint(
checkpoint: dict[str, Any], config: LLaMAConfig
) -> dict[str, Any]:
@@ -130,10 +115,21 @@ def permute_rotary(w: Tensor, num_heads: int) -> Tensor:
return {"model": checkpoint}


load_llama_model = LLaMAModelLoader(
def shard_llama_model(
model: TransformerDecoderModel, config: LLaMAConfig, gangs: Mapping[str, Gang]
) -> None:
gang = gangs["tp"] # tensor parallel

shard_embed_dim = config.max_seq_len < 8192 # LLaMA 1 or 2

shard_transformer_decoder_model(model, gang, shard_embed_dim=shard_embed_dim)


load_llama_model = StandardModelLoader(
config_loader=load_llama_config,
factory=create_llama_model,
checkpoint_converter=convert_llama_checkpoint,
sharder=shard_llama_model,
)

load_model.register(LLAMA_FAMILY, load_llama_model)
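
The registration above illustrates the revised API: instead of subclassing the loader (the deleted LLaMAModelLoader), sharding is supplied as a plain callable. A minimal sketch of the same wiring for a hypothetical model family follows; my_family, MyModel, MyModelConfig, create_my_model, and load_my_config are illustrative placeholders rather than fairseq2 symbols, and the import paths are assumed from this commit's file layout:

from __future__ import annotations

from typing import Mapping

from fairseq2.gang import Gang
from fairseq2.models.loader import StandardModelLoader, load_model

# Hypothetical placeholders; not real fairseq2 symbols.
from my_package.models import MyModel, MyModelConfig, create_my_model, load_my_config


def shard_my_model(
    model: MyModel, config: MyModelConfig, gangs: Mapping[str, Gang]
) -> None:
    """Shard ``model`` over the tensor-parallel gang, mirroring ``shard_llama_model``."""
    tp_gang = gangs["tp"]  # tensor parallel

    ...  # apply the family-specific sharding of `model` over `tp_gang`


load_my_model = StandardModelLoader(
    config_loader=load_my_config,
    factory=create_my_model,
    sharder=shard_my_model,  # invoked by the loader only when gang.size > 1
)

load_model.register("my_family", load_my_model)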
64 changes: 45 additions & 19 deletions src/fairseq2/models/loader.py
@@ -40,6 +40,8 @@

ModelT_co = TypeVar("ModelT_co", bound=Module, covariant=True)

ModelT_contra = TypeVar("ModelT_contra", bound=Module, contravariant=True)

ModelConfigT = TypeVar("ModelConfigT", bound=DataClass)

ModelConfigT_contra = TypeVar(
@@ -115,15 +117,27 @@ def __call__(
"""


class ModelSharder(Protocol[ModelT_contra, ModelConfigT_contra]):
def __call__(
self,
model: ModelT_contra,
config: ModelConfigT_contra,
gangs: Mapping[str, Gang],
) -> None:
...


@final
class StandardModelLoader(ModelLoader[ModelT], Generic[ModelT, ModelConfigT]):
"""Loads models of type ``ModelT``."""

_asset_store: AssetStore | None
_download_manager: AssetDownloadManager
_tensor_loader: TensorLoader
_checkpoint_converter: CheckpointConverter[ModelConfigT] | None
_config_loader: ModelConfigLoader[ModelConfigT]
_factory: ModelFactory[ModelConfigT, ModelT]
_checkpoint_converter: CheckpointConverter[ModelConfigT] | None
_sharder: ModelSharder[ModelT, ModelConfigT] | None
_restrict_checkpoints: bool
_skip_meta_init: bool

@@ -132,25 +146,19 @@ def __init__(
*,
config_loader: ModelConfigLoader[ModelConfigT],
factory: ModelFactory[ModelConfigT, ModelT],
restrict_checkpoints: bool = True,
skip_meta_init: bool = False,
asset_store: AssetStore | None = None,
download_manager: AssetDownloadManager | None = None,
tensor_loader: TensorLoader | None = None,
checkpoint_converter: CheckpointConverter[ModelConfigT] | None = None,
sharder: ModelSharder[ModelT, ModelConfigT] | None = None,
restrict_checkpoints: bool = True,
skip_meta_init: bool = False,
) -> None:
"""
:param config_loader:
The configuration loader.
:param factory:
The factory to construct models.
:param restrict_checkpoints:
If ``True``, restricts the Python unpickler to load only tensors,
primitive types, and dictionaries.
:param skip_meta_init:
If ``True``, skips meta device initialization and constructs the
model directly on the requested device. Should be used with models
that do not support PyTorch's ``reset_parameters()`` convention.
:param asset_store:
The asset store where to check for available models. If ``None``,
the default asset store will be used.
@@ -162,17 +170,26 @@
:param checkpoint_converter:
The converter to which loaded checkpoints will be passed for further
processing.
:param sharder:
The model sharder for tensor parallelism.
:param restrict_checkpoints:
If ``True``, restricts the Python unpickler to load only tensors,
primitive types, and dictionaries.
:param skip_meta_init:
If ``True``, skips meta device initialization and constructs the
model directly on the requested device. Should be used with models
that do not support PyTorch's ``reset_parameters()`` convention.
"""
self._asset_store = asset_store
self._download_manager = download_manager or default_download_manager
self._tensor_loader = tensor_loader or load_tensors
self._checkpoint_converter = checkpoint_converter
self._config_loader = config_loader
self._factory = factory
self._checkpoint_converter = checkpoint_converter
self._sharder = sharder
self._restrict_checkpoints = restrict_checkpoints
self._skip_meta_init = skip_meta_init

@final
def __call__(
self,
model_name_or_card: str | AssetCard,
@@ -243,7 +260,14 @@ def __call__(
) from ex

if gang is not None and gang.size > 1:
self._shard(model, gangs, card) # type: ignore[arg-type]
if self._sharder is None:
raise RuntimeError(
f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author."
)

assert gangs is not None

self._sharder(model, config, gangs)

return model

@@ -294,7 +318,14 @@ def __call__(
model = self._factory(config, device=init_device, dtype=dtype)

if gang is not None and gang.size > 1:
self._shard(model, gangs, card) # type: ignore[arg-type]
if self._sharder is None:
raise RuntimeError(
f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author."
)

assert gangs is not None

self._sharder(model, config, gangs)

try:
model_device = infer_device(model, name="model")
@@ -338,11 +369,6 @@ def __call__(

return model

def _shard(self, model: ModelT, gangs: dict[str, Gang], card: AssetCard) -> None:
raise RuntimeError(
f"{card.name} has a sharded checkpoint, but has no model sharder. Please file a bug report to the model author."
)


@final
class DelegatingModelLoader(ModelLoader[ModelT]):

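A closing note on the new ModelSharder protocol in src/fairseq2/models/loader.py: its type variables are declared contravariant, so a sharder written against a base model type also satisfies the protocol for any subtype of that model. The self-contained sketch below demonstrates that property using only the standard typing module; BaseModel, DecoderModel, Config, and SharderLike are illustrative stand-ins (and Mapping[str, object] stands in for Mapping[str, Gang]), none of them fairseq2 types:

from typing import Mapping, Protocol, TypeVar


class BaseModel:
    ...


class DecoderModel(BaseModel):
    ...


class Config:
    ...


ModelT_contra = TypeVar("ModelT_contra", bound=BaseModel, contravariant=True)
ConfigT_contra = TypeVar("ConfigT_contra", contravariant=True)


class SharderLike(Protocol[ModelT_contra, ConfigT_contra]):
    def __call__(
        self, model: ModelT_contra, config: ConfigT_contra, gangs: Mapping[str, object]
    ) -> None:
        ...


def shard_any_model(
    model: BaseModel, config: Config, gangs: Mapping[str, object]
) -> None:
    # A sharder written for the base model type.
    pass


# Contravariance in the model type means this assignment type-checks: a sharder
# that accepts any BaseModel is usable wherever a DecoderModel sharder is expected.
decoder_sharder: SharderLike[DecoderModel, Config] = shard_any_model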