Value #816

Merged
merged 2 commits on Sep 27, 2024

8 changes: 4 additions & 4 deletions src/fairseq2/factory_registry.py
@@ -23,7 +23,7 @@

from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataClass
from fairseq2.utils.dataclass import fill_empty_fields
from fairseq2.utils.dataclass import merge_dataclass
from fairseq2.utils.structured import ValueConverter, get_value_converter

ConfigT = TypeVar("ConfigT", bound=DataClass)
@@ -73,7 +73,7 @@ def get(
unstructured_config: object = None,
base_config_name: str | None = None,
*,
allow_empty: bool = False,
set_empty: bool = False,
) -> ConfigBoundFactory[P, R]:
"""Return the factory with ``name``.

@@ -96,7 +96,7 @@ def get(
self._value_converter = get_value_converter()

config = self._value_converter.structure(
unstructured_config, config_kls, allow_empty=allow_empty
unstructured_config, config_kls, set_empty=set_empty
)

if base_config_name is None:
@@ -123,7 +123,7 @@ def get(
if config is None:
config = base_config
else:
fill_empty_fields(config, base_config)
config = merge_dataclass(base_config, config)

f = partial(factory, config)

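The call `config = merge_dataclass(base_config, config)` replaces the in-place `fill_empty_fields(config, base_config)`: instead of filling only the unset fields of the override config, the base config is taken as the starting point, the explicitly provided override fields are layered on top, and a merged instance is assigned back. A minimal sketch of that merge behavior, assuming a hypothetical `EMPTY` sentinel marks fields the user did not set (names below are illustrative, not the fairseq2 implementation):

from dataclasses import dataclass, fields, replace
from typing import Any

EMPTY: Any = object()  # hypothetical sentinel for "field not provided"

@dataclass
class ModelConfig:
    model_dim: int = 512
    num_layers: int = 6

def merge_dataclass_sketch(base: ModelConfig, overrides: ModelConfig) -> ModelConfig:
    # Copy `base`, then apply every field of `overrides` that was actually provided.
    provided = {
        f.name: getattr(overrides, f.name)
        for f in fields(overrides)
        if getattr(overrides, f.name) is not EMPTY
    }
    return replace(base, **provided)

base = ModelConfig(model_dim=1024, num_layers=12)
overrides = ModelConfig(model_dim=EMPTY, num_layers=24)  # only num_layers is set

print(merge_dataclass_sketch(base, overrides))  # ModelConfig(model_dim=1024, num_layers=24)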
6 changes: 3 additions & 3 deletions src/fairseq2/models/config_loader.py
@@ -18,7 +18,7 @@
)
from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataClass
from fairseq2.utils.dataclass import fill_empty_fields
from fairseq2.utils.dataclass import merge_dataclass
from fairseq2.utils.structured import (
StructuredError,
ValueConverter,
@@ -170,10 +170,10 @@ def __call__(
config = base_config
else:
config = self._value_converter.structure(
unstructured_config, config_kls, allow_empty=True
unstructured_config, config_kls, set_empty=True
)

fill_empty_fields(config, base_config)
config = merge_dataclass(base_config, config)

return config

2 changes: 1 addition & 1 deletion src/fairseq2/models/factory.py
@@ -40,7 +40,7 @@ def create_model(
- The model.
- The effective configuration of the model.
"""
factory = model_factories.get(family, unstructured_config, arch, allow_empty=True)
factory = model_factories.get(family, unstructured_config, arch, set_empty=True)

model = factory(device=device or CPU, dtype=dtype or torch.float32)

6 changes: 1 addition & 5 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -29,7 +29,6 @@
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
from fairseq2.models.decoder import DecoderModel
from fairseq2.models.llama import llama_archs
from fairseq2.models.sequence import (
SequenceBatch,
SequenceModelOutput,
@@ -54,7 +53,6 @@
to_data_parallel,
)
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -88,9 +86,7 @@ class InstructionFinetuneConfig:
model: AssetReference = "llama3_1_8b_instruct"
"""The name or path to the asset card of the language model to finetune."""

model_config: Any = field(
default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
)
model_config: Any = None
"""
The model configuration overrides. The provided values must be compatible
with the checkpoint; otherwise, the model will fail to load.
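With `model_config` now defaulting to `None`, the recipe no longer pre-populates an emptied `llama3_1_8b` architecture config; the base architecture config is resolved by the model config loader (see the config_loader.py change above) and only explicit overrides are merged in. A hypothetical override, assuming overrides can be supplied as a plain mapping of field names to values; the field name is for illustration only:

# Hypothetical recipe override; "dropout_p" is only an example field.
recipe_overrides = {
    "model_config": {
        "dropout_p": 0.0,  # override one field; all others come from the base arch
    },
}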
6 changes: 1 addition & 5 deletions src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -29,7 +29,6 @@
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
from fairseq2.models.decoder import DecoderModel
from fairseq2.models.llama import llama_archs
from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
from fairseq2.optim import AdamWConfig, create_optimizer
@@ -43,7 +42,6 @@
from fairseq2.recipes.utils.log import log_model
from fairseq2.recipes.utils.setup import compile_model, setup_gangs, to_data_parallel
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -84,9 +82,7 @@ class PreferenceOptimizationConfig:
model: AssetReference = "llama3_1_8b_instruct"
"""The name or path to the asset card of the language model to finetune."""

model_config: Any = field(
default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
)
model_config: Any = None
"""
The model configuration overrides. The provided values must be compatible
with the checkpoint; otherwise, the model will fail to load.
11 changes: 1 addition & 10 deletions src/fairseq2/recipes/mt/train.py
@@ -29,7 +29,6 @@
from fairseq2.models import create_model
from fairseq2.models.encoder_decoder import EncoderDecoderModel
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.transformer import transformer_archs
from fairseq2.optim import AdamWConfig, create_optimizer
from fairseq2.optim.lr_scheduler import MyleLRConfig, create_lr_scheduler
from fairseq2.recipes.common_metrics import Seq2SeqMetricBag
@@ -45,7 +44,6 @@
from fairseq2.recipes.utils.log import log_model, log_model_config
from fairseq2.recipes.utils.setup import setup_root_gang, to_data_parallel
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -95,9 +93,7 @@ class MTTrainConfig:
model_arch: str | None = "nllb_dense_600m"
"""The architecture of the model."""

model_config: Any = field(
default_factory=lambda: empty_(transformer_archs.get("nllb_dense_600m"))
)
model_config: Any = None
"""The configuration of the model."""

dtype: DataType = torch.float16
@@ -200,16 +196,11 @@ class MTTrainConfig:

@mt_train_preset("nllb_dense_300m")
def _nllb_dense_300m() -> MTTrainConfig:
model_config = transformer_archs.get("nllb_dense_300m")

empty_(model_config)

config = _nllb_dense_600m()

assert isinstance(config.lr_scheduler_config, MyleLRConfig)

config.model_arch = "nllb_dense_300m"
config.model_config = model_config
config.lr_scheduler_config.num_warmup_steps = 400
config.gradient_accumulation = 4
config.max_num_steps = 10_000
110 changes: 80 additions & 30 deletions src/fairseq2/recipes/utils/argparse.py
@@ -25,13 +25,12 @@

@final
class ConfigAction(Action):
"""Adds support for reading key-value pairs in format ``<key>=<yaml_value>``."""
"""
Adds support for reading configuration key-value pairs in format ``<key>=<yaml_value>``.
"""

def __init__(
self,
option_strings: list[str],
dest: str,
help: str | None = None,
self, option_strings: list[str], dest: str, help: str | None = None
) -> None:
super().__init__(
option_strings,
@@ -50,48 +49,99 @@ def __call__(
) -> None:
data: dict[str, Any] = {}

for item in values:
key_value = item.split("=", maxsplit=1)
if len(key_value) != 2:
raise ArgumentError(self, f"invalid key-value pair: {item}")
def get_parent_node(path: str) -> tuple[dict[str, Any], str]:
keys = path.split(".")

key, value = [kv.strip() for kv in key_value]
node = data

try:
parsed_value = yaml.safe_load(value)
except ParserError:
raise ArgumentError(
self, f"invalid key-value pair: {item} (value must be yaml)"
)
for key in keys[:-1]:
try:
child_node = node[key]
except KeyError:
child_node = None

fields = key.split(".")
if not isinstance(child_node, dict):
child_node = {}

tmp = data
node[key] = child_node

node = child_node

return node, keys[-1]

for item in values:
item = item.strip()

if item.startswith("del:"):
path = item[4:]

if "=" in path:
raise ArgumentError(self, f"key should not contain '=': {item}")

parent_node, key = get_parent_node(path)

for field in fields[:-1]:
try:
d = tmp[field]
del_keys = parent_node["_del_"]
except KeyError:
d = None
del_keys = None

if not isinstance(del_keys, list):
del_keys = []

if not isinstance(d, dict):
d = {}
parent_node["_del_"] = del_keys

tmp[field] = d
del_keys.append(key)
else:
path_value = item.split("=", maxsplit=1)
if len(path_value) != 2:
raise ArgumentError(self, f"invalid key-value pair: {item}")

tmp = d
path, value = path_value

tmp[fields[-1]] = parsed_value
try:
parsed_value = yaml.safe_load(value.lstrip())
except ParserError:
raise ArgumentError(
self, f"invalid key-value pair: {item} (value must be yaml)"
)

path = path.rstrip()

if path.startswith("add:"):
path = path[4:]

directive = "_add_"
elif path.startswith("set:"):
path = path[4:]

directive = "_set_"
else:
directive = "_set_"

parent_node, key = get_parent_node(path)

try:
directive_keys = parent_node[directive]
except KeyError:
directive_keys = None

if not isinstance(directive_keys, dict):
directive_keys = {}

parent_node[directive] = directive_keys

directive_keys[key] = parsed_value

setattr(namespace, self.dest, data)
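
Traced from the rewritten ``__call__`` above, a sketch of the dictionary the action would store on the namespace for a mix of plain, ``set:``, ``add:``, and ``del:`` items; the override keys are made up for illustration, and the flag name through which the items arrive is assumed:

# Illustrative input, as the action might receive it from e.g. a `--config` option:
values = [
    "optimizer.lr=0.001",               # no prefix: treated as a set: directive
    "set:model.dropout_p=0.1",
    "add:extras=[a, b]",
    "del:model_config.arch_overrides",
]

# Result produced by the loop above (directives grouped under the parent node):
data = {
    "optimizer": {"_set_": {"lr": 0.001}},
    "model": {"_set_": {"dropout_p": 0.1}},
    "_add_": {"extras": ["a", "b"]},
    "model_config": {"_del_": ["arch_overrides"]},
}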


def parse_dtype(value: str) -> DataType:
"""Parse ``value`` as a ``torch.dtype``."""
if value.startswith("torch."):
value = value[6:]

if isinstance(dtype := getattr(torch, value, None), DataType):
return dtype
dtype = getattr(torch, value, None)

if not isinstance(dtype, DataType):
raise ArgumentTypeError("must be a `torch.dtype` identifier")

raise ArgumentTypeError("must be a `torch.dtype` identifier")
return dtype
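
The restructured ``parse_dtype`` accepts the identifier with or without the ``torch.`` prefix and rejects any ``torch`` attribute that is not a dtype. A quick sketch of the expected behavior, using the module path shown in this diff:

from fairseq2.recipes.utils.argparse import parse_dtype

print(parse_dtype("torch.float16"))  # torch.float16
print(parse_dtype("bfloat16"))       # torch.bfloat16

# A non-dtype attribute such as "nn" raises ArgumentTypeError.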