Revise empty value handling in CLI (#816)
cbalioglu authored Sep 27, 2024
1 parent 0d46d9d commit 7089eb4
Showing 14 changed files with 400 additions and 465 deletions.
8 changes: 4 additions & 4 deletions src/fairseq2/factory_registry.py
@@ -23,7 +23,7 @@

from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataClass
from fairseq2.utils.dataclass import fill_empty_fields
from fairseq2.utils.dataclass import merge_dataclass
from fairseq2.utils.structured import ValueConverter, get_value_converter

ConfigT = TypeVar("ConfigT", bound=DataClass)
@@ -73,7 +73,7 @@ def get(
unstructured_config: object = None,
base_config_name: str | None = None,
*,
allow_empty: bool = False,
set_empty: bool = False,
) -> ConfigBoundFactory[P, R]:
"""Return the factory with ``name``.
@@ -96,7 +96,7 @@ def get(
self._value_converter = get_value_converter()

config = self._value_converter.structure(
unstructured_config, config_kls, allow_empty=allow_empty
unstructured_config, config_kls, set_empty=set_empty
)

if base_config_name is None:
@@ -123,7 +123,7 @@ def get(
if config is None:
config = base_config
else:
fill_empty_fields(config, base_config)
config = merge_dataclass(base_config, config)

f = partial(factory, config)

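The switch from fill_empty_fields to merge_dataclass reverses the merge direction: rather than mutating the override config in place to fill its unset fields from the base, the new code builds a fresh config from the base and applies only the fields the override explicitly sets. Below is a minimal sketch of that merge semantics, treating None as "unset" for illustration; the dataclass and the helper are assumptions, not the actual fairseq2 implementation.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class OptimizerConfig:  # hypothetical config, for illustration only
    lr: float | None = None
    weight_decay: float | None = None


def merge_sketch(base: OptimizerConfig, overrides: OptimizerConfig) -> OptimizerConfig:
    """Return a copy of ``base`` with every non-None field of ``overrides`` applied."""
    changes = {
        f.name: getattr(overrides, f.name)
        for f in fields(overrides)
        if getattr(overrides, f.name) is not None
    }
    return replace(base, **changes)


base = OptimizerConfig(lr=1e-3, weight_decay=0.1)
merged = merge_sketch(base, OptimizerConfig(lr=5e-4))
assert merged == OptimizerConfig(lr=5e-4, weight_decay=0.1)
```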
6 changes: 3 additions & 3 deletions src/fairseq2/models/config_loader.py
@@ -18,7 +18,7 @@
)
from fairseq2.config_registry import ConfigRegistry
from fairseq2.typing import DataClass
from fairseq2.utils.dataclass import fill_empty_fields
from fairseq2.utils.dataclass import merge_dataclass
from fairseq2.utils.structured import (
StructuredError,
ValueConverter,
@@ -170,10 +170,10 @@ def __call__(
config = base_config
else:
config = self._value_converter.structure(
unstructured_config, config_kls, allow_empty=True
unstructured_config, config_kls, set_empty=True
)

fill_empty_fields(config, base_config)
config = merge_dataclass(base_config, config)

return config

2 changes: 1 addition & 1 deletion src/fairseq2/models/factory.py
@@ -40,7 +40,7 @@ def create_model(
- The model.
- The effective configuration of the model.
"""
factory = model_factories.get(family, unstructured_config, arch, allow_empty=True)
factory = model_factories.get(family, unstructured_config, arch, set_empty=True)

model = factory(device=device or CPU, dtype=dtype or torch.float32)

6 changes: 1 addition & 5 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -29,7 +29,6 @@
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
from fairseq2.models.decoder import DecoderModel
from fairseq2.models.llama import llama_archs
from fairseq2.models.sequence import (
SequenceBatch,
SequenceModelOutput,
@@ -54,7 +53,6 @@
to_data_parallel,
)
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -88,9 +86,7 @@ class InstructionFinetuneConfig:
model: AssetReference = "llama3_1_8b_instruct"
"""The name or path to the asset card of the language model to finetune."""

model_config: Any = field(
default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
)
model_config: Any = None
"""
The model configuration overrides. The provided values must be compatible
with the checkpoint; otherwise, the model will fail to load.
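With the default_factory gone, model_config starts out as None, so the recipe no longer materializes the llama3_1_8b architecture defaults up front; only explicitly supplied overrides are merged onto the configuration that ships with the checkpoint. A hypothetical override is sketched below; the field name and the dict form are illustrative, not taken from this commit.

```python
config = InstructionFinetuneConfig()

# Leave model_config as None to keep the checkpoint's configuration unchanged,
# or supply only the fields that must differ from it.
config.model_config = {"dropout_p": 0.0}  # hypothetical override key
```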
6 changes: 1 addition & 5 deletions src/fairseq2/recipes/lm/preference_finetune/recipe.py
@@ -29,7 +29,6 @@
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
from fairseq2.models.decoder import DecoderModel
from fairseq2.models.llama import llama_archs
from fairseq2.nn.checkpointing import use_layerwise_activation_checkpointing
from fairseq2.nn.transformer import enable_memory_efficient_torch_sdpa
from fairseq2.optim import AdamWConfig, create_optimizer
@@ -43,7 +42,6 @@
from fairseq2.recipes.utils.log import log_model
from fairseq2.recipes.utils.setup import compile_model, setup_gangs, to_data_parallel
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -84,9 +82,7 @@ class PreferenceOptimizationConfig:
model: AssetReference = "llama3_1_8b_instruct"
"""The name or path to the asset card of the language model to finetune."""

model_config: Any = field(
default_factory=lambda: empty_(llama_archs.get("llama3_1_8b"))
)
model_config: Any = None
"""
The model configuration overrides. The provided values must be compatible
with the checkpoint; otherwise, the model will fail to load.
11 changes: 1 addition & 10 deletions src/fairseq2/recipes/mt/train.py
@@ -29,7 +29,6 @@
from fairseq2.models import create_model
from fairseq2.models.encoder_decoder import EncoderDecoderModel
from fairseq2.models.seq2seq import Seq2SeqBatch
from fairseq2.models.transformer import transformer_archs
from fairseq2.optim import AdamWConfig, create_optimizer
from fairseq2.optim.lr_scheduler import MyleLRConfig, create_lr_scheduler
from fairseq2.recipes.common_metrics import Seq2SeqMetricBag
@@ -45,7 +44,6 @@
from fairseq2.recipes.utils.log import log_model, log_model_config
from fairseq2.recipes.utils.setup import setup_root_gang, to_data_parallel
from fairseq2.typing import CPU, META, DataType
from fairseq2.utils.dataclass import empty_
from fairseq2.utils.profiler import Stopwatch
from fairseq2.utils.rng import manual_seed

@@ -95,9 +93,7 @@ class MTTrainConfig:
model_arch: str | None = "nllb_dense_600m"
"""The architecture of the model."""

model_config: Any = field(
default_factory=lambda: empty_(transformer_archs.get("nllb_dense_600m"))
)
model_config: Any = None
"""The configuration of the model."""

dtype: DataType = torch.float16
@@ -200,16 +196,11 @@ class MTTrainConfig:

@mt_train_preset("nllb_dense_300m")
def _nllb_dense_300m() -> MTTrainConfig:
model_config = transformer_archs.get("nllb_dense_300m")

empty_(model_config)

config = _nllb_dense_600m()

assert isinstance(config.lr_scheduler_config, MyleLRConfig)

config.model_arch = "nllb_dense_300m"
config.model_config = model_config
config.lr_scheduler_config.num_warmup_steps = 400
config.gradient_accumulation = 4
config.max_num_steps = 10_000
110 changes: 80 additions & 30 deletions src/fairseq2/recipes/utils/argparse.py
@@ -25,13 +25,12 @@

@final
class ConfigAction(Action):
"""Adds support for reading key-value pairs in format ``<key>=<yaml_value>``."""
"""
Adds support for reading configuration key-value pairs in format ``<key>=<yaml_value>``.
"""

def __init__(
self,
option_strings: list[str],
dest: str,
help: str | None = None,
self, option_strings: list[str], dest: str, help: str | None = None
) -> None:
super().__init__(
option_strings,
@@ -50,48 +49,99 @@ def __call__(
) -> None:
data: dict[str, Any] = {}

for item in values:
key_value = item.split("=", maxsplit=1)
if len(key_value) != 2:
raise ArgumentError(self, f"invalid key-value pair: {item}")
def get_parent_node(path: str) -> tuple[dict[str, Any], str]:
keys = path.split(".")

key, value = [kv.strip() for kv in key_value]
node = data

try:
parsed_value = yaml.safe_load(value)
except ParserError:
raise ArgumentError(
self, f"invalid key-value pair: {item} (value must be yaml)"
)
for key in keys[:-1]:
try:
child_node = node[key]
except KeyError:
child_node = None

fields = key.split(".")
if not isinstance(child_node, dict):
child_node = {}

tmp = data
node[key] = child_node

node = child_node

return node, keys[-1]

for item in values:
item = item.strip()

if item.startswith("del:"):
path = item[4:]

if "=" in path:
raise ArgumentError(self, f"key should not contain '=': {item}")

parent_node, key = get_parent_node(path)

for field in fields[:-1]:
try:
d = tmp[field]
del_keys = parent_node["_del_"]
except KeyError:
d = None
del_keys = None

if not isinstance(del_keys, list):
del_keys = []

if not isinstance(d, dict):
d = {}
parent_node["_del_"] = del_keys

tmp[field] = d
del_keys.append(key)
else:
path_value = item.split("=", maxsplit=1)
if len(path_value) != 2:
raise ArgumentError(self, f"invalid key-value pair: {item}")

tmp = d
path, value = path_value

tmp[fields[-1]] = parsed_value
try:
parsed_value = yaml.safe_load(value.lstrip())
except ParserError:
raise ArgumentError(
self, f"invalid key-value pair: {item} (value must be yaml)"
)

path = path.rstrip()

if path.startswith("add:"):
path = path[4:]

directive = "_add_"
elif path.startswith("set:"):
path = path[4:]

directive = "_set_"
else:
directive = "_set_"

parent_node, key = get_parent_node(path)

try:
directive_keys = parent_node[directive]
except KeyError:
directive_keys = None

if not isinstance(directive_keys, dict):
directive_keys = {}

parent_node[directive] = directive_keys

directive_keys[key] = parsed_value

setattr(namespace, self.dest, data)
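
The rewritten ``__call__`` no longer assigns parsed values directly into the nested dictionary; it records them under directive keys (``_set_``, ``_add_``, ``_del_``) at the parent node, leaving it to the downstream structuring code to interpret them. The standalone helper below mirrors that logic to show the resulting layout for a few sample arguments; it is an illustration, not the fairseq2 class itself, and omits the error handling above.

```python
import yaml


def parse_overrides(items: list[str]) -> dict:
    """Build the same directive layout as ConfigAction above (illustration only)."""
    data: dict = {}

    def get_parent_node(path: str) -> tuple[dict, str]:
        keys = path.split(".")
        node = data
        for key in keys[:-1]:
            child = node.get(key)
            if not isinstance(child, dict):
                child = {}
                node[key] = child
            node = child
        return node, keys[-1]

    for item in items:
        item = item.strip()
        if item.startswith("del:"):
            # Deletions are collected as a list under the "_del_" directive.
            node, key = get_parent_node(item[4:])
            node.setdefault("_del_", []).append(key)
            continue
        path, value = item.split("=", maxsplit=1)
        if path.startswith("add:"):
            directive, path = "_add_", path[4:]
        else:
            directive, path = "_set_", path.removeprefix("set:")
        node, key = get_parent_node(path.rstrip())
        node.setdefault(directive, {})[key] = yaml.safe_load(value.lstrip())

    return data


overrides = parse_overrides(["optimizer.lr=0.0005", "del:model.dropout", "add:dataset.files=[a,b]"])
# -> {"optimizer": {"_set_": {"lr": 0.0005}},
#     "model": {"_del_": ["dropout"]},
#     "dataset": {"_add_": {"files": ["a", "b"]}}}
```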


def parse_dtype(value: str) -> DataType:
"""Parse ``value`` as a ``torch.dtype``."""
if value.startswith("torch."):
value = value[6:]

if isinstance(dtype := getattr(torch, value, None), DataType):
return dtype
dtype = getattr(torch, value, None)

if not isinstance(dtype, DataType):
raise ArgumentTypeError("must be a `torch.dtype` identifier")

raise ArgumentTypeError("must be a `torch.dtype` identifier")
return dtype
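
parse_dtype is only restructured: it still strips an optional ``torch.`` prefix, looks the name up on the torch module, and rejects anything that is not a dtype, but the happy path now comes last. A small usage sketch as an argparse type converter follows; the option name is illustrative.

```python
from argparse import ArgumentParser

import torch

from fairseq2.recipes.utils.argparse import parse_dtype

parser = ArgumentParser()
parser.add_argument("--dtype", type=parse_dtype, default=torch.float32)

args = parser.parse_args(["--dtype", "torch.bfloat16"])
assert args.dtype is torch.bfloat16

# Non-dtype names such as "torch.nn" make parse_dtype raise ArgumentTypeError,
# which argparse reports as a usage error for --dtype.
```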