From 9dae7f16429f7b591b8e6ec91c902bf0e488eb1a Mon Sep 17 00:00:00 2001 From: Andrew Ho Date: Mon, 16 Dec 2024 14:01:59 -0500 Subject: [PATCH] torchdata integration - multi-dataset and streaming support (#1929) --- .../11B_lora_multi_dataset.yaml | 122 +++ ...lora_finetune_distributed_multi_dataset.py | 963 ++++++++++++++++++ torchtune/_recipe_registry.py | 11 + torchtune/data/_torchdata.py | 51 + torchtune/data/_utils.py | 185 +++- torchtune/datasets/_sft.py | 70 +- torchtune/datasets/multimodal/__init__.py | 3 +- .../datasets/multimodal/_the_cauldron.py | 46 +- torchtune/utils/_import_guard.py | 11 + 9 files changed, 1437 insertions(+), 25 deletions(-) create mode 100644 recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml create mode 100644 recipes/lora_finetune_distributed_multi_dataset.py create mode 100644 torchtune/data/_torchdata.py diff --git a/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml b/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml new file mode 100644 index 0000000000..87afa718b2 --- /dev/null +++ b/recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml @@ -0,0 +1,122 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py +# using a Llama3.2 11B Vision Instruct model +# +# This config assumes that you've run the following command before launching: +# tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct --ignore-patterns "original/consolidated*" +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training: +# tune run --nproc_per_node 2 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# For single device LoRA finetuning please use 11B_lora_single_device.yaml +# or 11B_qlora_single_device.yaml + +# Model arguments +model: + _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b + decoder_trainable: "frozen" + encoder_trainable: "lora" + fusion_trainable: "lora" + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + image_size: 560 # Make sure this matches the image_size in tokenizer + +# Transform +tokenizer: + _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform + path: /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model + image_size: 560 + max_seq_len: 8192 + +# Checkpointer +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00005" + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/ + model_type: LLAMA3_VISION +resume_from_checkpoint: False +save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. + +# TorchData setup +dataloader: + shuffle: True + collate_fn: torchtune.data.padded_collate_tiled_images_and_mask + parallel_method: thread + num_workers: 4 # Per dataset + pin_memory: true + packed: False # Set to true for great speed ups + prefetch_factor: 2 +seed: null + +datasets: + - source: HuggingFaceM4/the_cauldron + subset: ocrvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: dvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: docvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: tabmwp + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + +# Fine-tuning arguments +epochs: 1 +# max_steps_per_epoch is required for progress bar +max_steps_per_epoch: 50 +batch_size: 4 +gradient_accumulation_steps: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 1e-4 + +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +clip_grad_norm: 1.0 +compile: True # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +dtype: bf16 + +# Logging +output_dir: /tmp/lora-llama3.2-vision-finetune +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/lora_finetune_distributed_multi_dataset.py b/recipes/lora_finetune_distributed_multi_dataset.py new file mode 100644 index 0000000000..7cf5ee62f2 --- /dev/null +++ b/recipes/lora_finetune_distributed_multi_dataset.py @@ -0,0 +1,963 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, List, Optional, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer + +from torchdata.nodes import Loader, StopCriteria +from torchtune import config, modules, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed +from torchtune.data._utils import get_dataloader, get_multi_dataset, load_hf_dataset +from torchtune.datasets._sft import SFTTransform +from torchtune.modules.peft import ( + DoRALinear, + get_adapter_params, + get_adapter_state_dict, + get_lora_module_names, + get_merged_lora_ckpt, + LoRALinear, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class LoRAFinetuneRecipeDistributed(FTRecipeInterface): + """ + Distributed LoRA finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - TorchData. Map and Streaming HuggingFace datasets, and multi-dataset mixing. + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5.0 or later and will be + enabled by default if an acceptable torch version is found. Activation offloading can be used in + conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + ValueError: If world_size is 1 + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + _, rank = training.get_world_size_and_rank() + + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + should_load_recipe_state=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + # When resuming from checkpoint for LoRA, the recipe expects the adapter weights + # and recipe state to be present. The keys should match up with what ``save_checkpoint`` + # used to create these intermediate checkpoints + if self._resume_from_checkpoint: + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._compile = cfg.get("compile", False) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[training.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + utils.log_rank_zero(log, "Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._dataloader = self._setup_data( + cfg_dataloader=cfg.dataloader, + cfg_datasets=cfg.datasets, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + utils.log_rank_zero( + log, f" Profiler config after instantiation: {profiler_cfg}" + ) + if self._is_rank_zero: + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + base_model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + utils.log_rank_zero( + log, + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...", + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + set_trainable_params(model, get_adapter_params(model)) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + if lora_weights_state_dict: + lora_missing, lora_unexpected = training.load_from_full_model_state_dict( + model, + lora_weights_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + else: + lora_missing, lora_unexpected = None, None + + # Initialize LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if ( + isinstance(m, LoRALinear) or isinstance(m, DoRALinear) + ) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + base_missing, base_unexpected = training.load_from_full_model_state_dict( + model, + base_model_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + for m in model.modules(): + if hasattr(m, "initialize_dora_magnitude"): + m.initialize_dora_magnitude() + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # log + utils.log_rank_zero( + log, + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs", + ) + if self._is_rank_zero: + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + utils.log_rank_zero(log, "Optimizer is initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + utils.log_rank_zero(log, "Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataloader: DictConfig, + cfg_datasets: ListConfig, + batch_size: int, + ) -> Loader: + """ + Torchdata related setup happens here. Currently this recipe supports + both Map and Streaming datasets (from HuggingFace datasets), and mixing multiple + datasets (can be mix of Map and Streaming). + """ + # Get global settings + shuffle = cfg_dataloader.shuffle + parallel_method = cfg_dataloader.get("parallel_method", "thread") + packed = cfg_dataloader.get("packed", False) + streaming = cfg_dataloader.get("streaming", False) + num_workers = cfg_dataloader.get("num_workers", 0) + pin_memory = cfg_dataloader.get("pin_memory", True) + collate_fn = cfg_dataloader.collate_fn + prefetch_factor = cfg_dataloader.get("prefetch_factor", 6) + + if packed: + raise ValueError("Packing not yet supported") + + # Multi-Dataset Stop Criteria + stop_criteria = cfg_dataloader.get( + "stop_criteria", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED + ) + weights, datasets = {}, {} + for idx, cfg_dataset in enumerate(cfg_datasets): + dataset_name = cfg_dataset.pop("name", None) + if dataset_name is None: + dataset_name = cfg_dataset.get("subset", None) + key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "") + + utils.log_rank_zero(log, f"Instantiating dataset {cfg_dataset}") + # Handle dataset-specific overrides, fallback to cfg_dataloader settings + ds_streaming = cfg_dataset.pop("streaming", streaming) + ds_shuffle = cfg_dataset.pop("shuffle", shuffle) + ds_parallel_method = cfg_dataset.pop("parallel_method", parallel_method) + ds_num_workers = cfg_dataset.pop("num_workers", num_workers) + + # Instantiate dataset transform + assert "transform" in cfg_dataset, "transform must be specified in dataset" + transform = config.instantiate(cfg_dataset.pop("transform")) + + weights[key] = float(cfg_dataset.pop("weight")) + datasets[key] = load_hf_dataset( + **cfg_dataset, + transform=transform, + streaming=ds_streaming, + shuffle=ds_shuffle, + parallel_method=ds_parallel_method, + num_workers=ds_num_workers, + ) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + + collate_fn = ( + partial( + _get_component_from_path(collate_fn), + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ) + if len(datasets) > 1: + dataset = get_multi_dataset( + datasets=datasets, + weights=weights, + stop_criteria=stop_criteria, + ) + else: + dataset = next(iter(datasets.values())) + + loader = get_dataloader( + dataset=dataset, + model_transform=SFTTransform(model_transform=self._tokenizer), + batch_size=batch_size, + collate_fn=collate_fn, + drop_last=True, + num_workers=num_workers, + parallel_method=parallel_method, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, + ) + + utils.log_rank_zero(log, "TorchData nodes are initialized") + + return loader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + Checkpointer will save the merged weights, adapter weights and recipe state in + different checkpoint files. To correctly resume from training, the adapter weights + and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + utils.log_rank_zero( + log, + "Saving checkpoint. This may take some time. Retrieving full model state dict...", + ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, + self._is_rank_zero, + device=self._device, + ) + utils.log_rank_zero( + log, + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs", + ) + + if intermediate_checkpoint: + utils.log_rank_zero(log, "Retrieving optimizer state dict...") + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + utils.log_rank_zero( + log, + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs", + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + start = time.perf_counter() + + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_fn(logits, labels) * current_num_tokens + + # free logits otherwise it peaks backward memory + del logits + + running_loss += current_loss + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + if self._is_rank_zero: + log.info(f"End of epoch {self.epochs_run}!") + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + init_process_group("cuda:nccl,cpu:gloo") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + + config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) + + recipe = LoRAFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index eb3d24add3..faf1ec7124 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -399,6 +399,17 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="lora_finetune_distributed_multi_dataset", + file_path="lora_finetune_distributed_multi_dataset.py", + configs=[ + Config( + name="llama3_2_vision/11B_lora_multi_dataset", + file_path="llama3_2_vision/11B_lora_multi_dataset.yaml", + ), + ], + supports_distributed=True, + ), Recipe( name="generate", file_path="generate.py", diff --git a/torchtune/data/_torchdata.py b/torchtune/data/_torchdata.py new file mode 100644 index 0000000000..d39b0824c9 --- /dev/null +++ b/torchtune/data/_torchdata.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar + +from torchtune.utils._import_guard import _TORCHDATA_INSTALLED, _TORCHDATA_MIN_VERSION + +from typing_extensions import TypeAlias + + +if _TORCHDATA_INSTALLED: + from torchdata.nodes import BaseNode, Loader # noqa +else: + # If we fail to import torchdata, define stubs to make typechecker happy + T = TypeVar("T") + + class BaseNode(Iterator[T]): + def __init__(self, *args, **kwargs): + pass + + class Loader(Iterable): + def __init__(self, *args, **kwargs): + assert_torchdata_installed() + + +DatasetType: TypeAlias = BaseNode[Mapping[str, Any]] # type: ignore + + +def assert_torchdata_installed(): + if not _TORCHDATA_INSTALLED: + raise ImportError( + f"torchdata is not installed, or the current version is too old. " + f"Please (re-)install it with `pip install torchdata>={_TORCHDATA_MIN_VERSION}`. " + ) + + +def requires_torchdata(func: Callable) -> Callable: + """ + Decorator to check if torchdata is installed and raise an ImportError if not. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + assert_torchdata_installed() + return func(*args, **kwargs) + + return wrapper diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 832e1babca..812d1617a1 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -5,9 +5,18 @@ # LICENSE file in the root directory of this source tree. from pathlib import Path -from typing import Any, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union from urllib import request +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node +from torch.utils.data import default_collate, DistributedSampler + +from torchtune.data._torchdata import DatasetType, Loader, requires_torchdata +from torchtune.modules.transforms import Transform + +from torchtune.utils import get_world_size_and_rank + T = TypeVar("T", bound=type) @@ -142,3 +151,177 @@ def format_content_with_images( final_content_list.append({"type": "image", "content": images.pop(0)}) return final_content_list + + +def chain(*funcs: List[Callable]) -> Callable: + """ + Chain a list of functions together into a single function. + + Args: + *funcs (List[Callable]): list of functions to chain together + + Returns: + Callable: chained function + """ + + def chained_fn(x): + for fn in funcs: + x = fn(x) + return x + + return chained_fn + + +@requires_torchdata +def load_hf_dataset( + source: str, + transform: Transform, + filter_fn: Optional[Callable] = None, + shuffle: bool = True, + seed: int = 0, + num_workers: int = 0, + parallel_method: Literal["process", "thread"] = "thread", + streaming: bool = False, + **load_dataset_kwargs: Dict[str, Any], +) -> DatasetType: + """ + Load a HuggingFace dataset (Map or Streaming) and apply a Transform to it. + + Args: + source (str): HuggingFace dataset source. + transform (Transform): Transform to apply to the samples of the dataset. + filter_fn (Optional[Callable]): Filter function to pass to HuggingFace dataset. + shuffle (bool): Whether to shuffle the dataset. Default is True. For streaming datasets, this is passed to + HuggingFace dataset as .shuffle(). For map datasets, a DistributedSampler is used. + seed (int): Seed for the random number generator in the case of Map style dataset shuffling. Default is 0. + num_workers (int): Number of workers to use for loading the dataset. Default is 0 (no parallelism). Setting this + greater than 0 will create `parallel_method` workers to perform transforms to the dataset. + parallel_method (Literal["process", "thread"]): Method to use for parallelism. Default is "thread". No effect if + num_workers is 0. + streaming (bool): whether to load a streaming vs map-style dataset. Default False. + **load_dataset_kwargs (Dict[str, Any]): Additional Keyword arguments to pass to HuggingFace dataset. See Hugging Face's + documentation. + + Returns: + A ``torchdata.nodes`` iterator that can be passed directly to a Loader, or combined with other-datasets in a multi-dataset + sampler. + """ + from torchdata.nodes import IterableWrapper, ParallelMapper, SamplerWrapper + + if "subset" in load_dataset_kwargs: + assert ( + "name" not in load_dataset_kwargs + ), f"found both 'subset' and 'name' found, you may only specify one, {load_dataset_kwargs=}" + load_dataset_kwargs["name"] = load_dataset_kwargs.pop("subset") + dataset = load_dataset(source, **load_dataset_kwargs) + if filter_fn is not None: + dataset = dataset.filter(filter_fn) + + world_size, rank = get_world_size_and_rank() + if streaming: + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + if shuffle: + dataset = dataset.shuffle(seed=seed) + node = IterableWrapper(dataset) + else: + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=seed, + ) + # Note: SamplerWrapper will call set_epoch on the sampler (if defined), + # and auto-increment the epoch each time the node is reset. + node = SamplerWrapper(sampler) + transform = chain(dataset.__getitem__, transform) # type: ignore + + node = ParallelMapper( + node, map_fn=transform, num_workers=num_workers, method=parallel_method + ) + + return node + + +@requires_torchdata +def get_multi_dataset( + datasets: Dict[str, DatasetType], + weights: Dict[str, float], + stop_criteria: str = "CYCLE_UNTIL_ALL_DATASETS_EXHASTED", + seed: int = 0, +) -> DatasetType: + """ + Given a dictionary of datasets and their corresponding weights, return a dataset that + samples from the given datasets according to the specified weights. + + Args: + datasets (Dict[str, DatasetType]): dictionary of datasets + weights (Dict[str, float]): dictionary of weights for each dataset. If not + stop_criteria (str): stop criteria for the sampler. Default "CYCLE_UNTIL_ALL_DATASETS_EXHASTED". + See also: torchdata.nodes.StopCriteria + seed (int): seed for the random number generator. Default 0. + + Returns: + A ``torchdata.nodes`` iterator which can be passed to Loader, or further composed with other Nodes. + """ + from torchdata.nodes import MultiNodeWeightedSampler + + return MultiNodeWeightedSampler( + source_nodes=datasets, + weights=weights, + stop_criteria=stop_criteria, + seed=seed, + ) + + +@requires_torchdata +def get_dataloader( + dataset: DatasetType, + model_transform: Transform, + batch_size: int, + collate_fn: Optional[Callable[[Any], Any]] = None, + drop_last: bool = True, + num_workers: int = 0, + parallel_method: Literal["process", "thread"] = "thread", + prefetch_factor: Optional[int] = 4, + pin_memory: bool = False, +) -> Loader: + """ + This will configure TorchData Nodes to approximate torch.utils.data.DataLoader. + Given a dataset, apply model_transform (eg tokenization), batching, collation, + memory pinning, and pre-fetching. + + Args: + dataset (DatasetType): dataset to load. May be a MultiNodeWeightedSampler + model_transform (Transform): model transform to apply to the samples of the dataset + batch_size (int): batch size + collate_fn (Optional[Callable[[Any], Any]]): collate function to apply to the samples of the dataset. If None, use + torch.utils.data.default_collate. Default None. + drop_last (bool): whether to drop the last batch. Default is True. + num_workers (int): number of workers to use for loading the dataset. Default is 0 (no parallelism + parallel_method (Literal["process", "thread"]): method to use for parallelism. Default is "thread". + prefetch_factor (Optional[int]): number of batches to prefetch. Default is 4. + pin_memory (bool): whether to pin memory. Default is False. + + Returns: + A ``torchdata.nodes`` Loader, an Iterable that returns batches. + """ + + from torchdata.nodes import Batcher, ParallelMapper, PinMemory, Prefetcher + + if collate_fn is None: + collate_fn = default_collate + + node = ParallelMapper( + dataset, map_fn=model_transform, num_workers=num_workers, method=parallel_method + ) + node = Batcher(node, batch_size, drop_last=drop_last) + node = ParallelMapper( + node, map_fn=collate_fn, num_workers=num_workers, method=parallel_method + ) + if pin_memory: + node = PinMemory(node) + if prefetch_factor is not None: + node = Prefetcher(node, prefetch_factor) + + return Loader(node) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index e169cf70cd..f186974af9 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -7,11 +7,12 @@ from typing import Any, Callable, Dict, Mapping, Optional import numpy as np - from datasets import load_dataset from torch.utils.data import Dataset + from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages + from torchtune.modules.transforms import Transform @@ -110,6 +111,11 @@ def __init__( if filter_fn is not None: self._data = self._data.filter(filter_fn) + self._prepare_sample = SFTTransform( + message_transform=self._message_transform, + model_transform=self._model_transform, + ) + def __len__(self): return len(self._data) @@ -117,29 +123,49 @@ def __getitem__(self, index: int) -> Dict[str, Any]: sample = self._data[index] return self._prepare_sample(sample) - def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]: - transformed_sample = self._message_transform(sample) - if "messages" in transformed_sample: - validate_messages(transformed_sample["messages"]) - - tokenized_dict = self._model_transform(transformed_sample) - if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): - keys_str = ", ".join(tokenized_dict.keys()) - error_message = ( - "model_transform returned the following keys: " - f"{keys_str}. Must return 'tokens' and 'mask' as keys." +class SFTTransform(Transform): + def __init__( + self, + message_transform: Optional[Transform] = None, + model_transform: Optional[Transform] = None, + ): + if message_transform is None and model_transform is None: + raise ValueError( + "At least one of message_transform or model_transform must be provided." ) - raise ValueError(error_message) - - # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens - tokenized_dict["labels"] = list( - np.where( - tokenized_dict["mask"], - CROSS_ENTROPY_IGNORE_IDX, - tokenized_dict["tokens"], + self._message_transform = message_transform + self._model_transform = model_transform + + def __call__(self, sample: Mapping[str, Any]) -> Dict[str, Any]: + if self._message_transform is not None: + transformed_sample = self._message_transform(sample) + if "messages" in transformed_sample: + validate_messages(transformed_sample["messages"]) + else: + transformed_sample = sample + + if self._model_transform is not None: + tokenized_dict = self._model_transform(transformed_sample) + + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): + keys_str = ", ".join(tokenized_dict.keys()) + error_message = ( + "model_transform returned the following keys: " + f"{keys_str}. Must return 'tokens' and 'mask' as keys." + ) + raise ValueError(error_message) + + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"], + ) ) - ) - assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + else: + tokenized_dict = transformed_sample return tokenized_dict diff --git a/torchtune/datasets/multimodal/__init__.py b/torchtune/datasets/multimodal/__init__.py index 85572d3c3a..9efad1e730 100644 --- a/torchtune/datasets/multimodal/__init__.py +++ b/torchtune/datasets/multimodal/__init__.py @@ -6,11 +6,12 @@ from ._llava_instruct import llava_instruct_dataset from ._multimodal import multimodal_chat_dataset -from ._the_cauldron import the_cauldron_dataset +from ._the_cauldron import the_cauldron_dataset, the_cauldron_transform from ._vqa import vqa_dataset __all__ = [ "the_cauldron_dataset", + "the_cauldron_transform", "llava_instruct_dataset", "multimodal_chat_dataset", "vqa_dataset", diff --git a/torchtune/datasets/multimodal/_the_cauldron.py b/torchtune/datasets/multimodal/_the_cauldron.py index 8887edf827..c5712bbbc8 100644 --- a/torchtune/datasets/multimodal/_the_cauldron.py +++ b/torchtune/datasets/multimodal/_the_cauldron.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Mapping, Optional from torchtune.data._messages import Message -from torchtune.datasets._sft import SFTDataset +from torchtune.datasets._sft import SFTDataset, SFTTransform from torchtune.modules.transforms import Transform @@ -235,3 +235,47 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: ) return ds + + +def the_cauldron_transform( + model_transform: Optional[Transform] = None, + texts_col: str = "texts", + images_col: str = "images", + new_system_prompt: Optional[str] = None, +) -> SFTTransform: + """ + Support for family of image + text datasets similar to + `The Cauldron `_ + from Hugging Face Datasets. + + This function instantiates a :class:`~torchtune.datasets.SFTTransform` only (not the dataset). + See :func:`~torchtune.datasets.the_cauldron_dataset` for more details. + + The model transform is expected to be a callable that applies pre-processing steps specific + to a model. For multimodal datasets, this is expected to be at minimum a tokenizer and + an image transform. The tokenizer will convert text sequences into token IDs after the dataset + is converted to a list of :class:`~torchtune.data.Message`. The image transform will load the + image and process it in accordance to the model's requirements. + + Args: + model_transform (Optional[Transform]): model-specific transform class that takes in a sample dict and applies custom + transforms on the keys. It should consist of at minimum two components: text tokenization (called + on the "messages" field) and image transform (called on the "images" field). The keys returned by + the model transform should be aligned with the expected inputs into the model. Default is None. + texts_col (str): name of the column containing the text data. Default is "texts". + images_col (str): name of the column containing the image data. Default is "images". + new_system_prompt (Optional[str]): if specified, prepend a system message. This can + serve as instructions to guide the model response. Setting this will OVERRIDE any system + messages already present in the dataset. Default is None. + + Returns: + :class:`~torchtune.datasets.SFTTransform` - Callable that transforms samples into The Cauldron format. + """ + column_map = {"texts": texts_col, "images": images_col} + return SFTTransform( + message_transform=TheCauldronToMessages( + column_map=column_map, + new_system_prompt=new_system_prompt, + ), + model_transform=model_transform, + ) diff --git a/torchtune/utils/_import_guard.py b/torchtune/utils/_import_guard.py index 02625945d4..582c790225 100644 --- a/torchtune/utils/_import_guard.py +++ b/torchtune/utils/_import_guard.py @@ -4,9 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import importlib + import torch # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above _SUPPORTS_FLEX_ATTENTION = ( torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5) ) + +_TORCHDATA_MIN_VERSION = "0.10.0" +if ( + importlib.util.find_spec("torchdata") is not None + and importlib.util.find_spec("torchdata.nodes") is not None +): + _TORCHDATA_INSTALLED = True +else: + _TORCHDATA_INSTALLED = False