Commit 41be6b0
address pr comments
andrewkho committed Dec 12, 2024
1 parent 5ed8382 commit 41be6b0
Showing 5 changed files with 48 additions and 18 deletions.
9 changes: 3 additions & 6 deletions recipes/configs/llama3_2_vision/11B_lora_multi_dataset.yaml
@@ -88,6 +88,7 @@ datasets:

# Fine-tuning arguments
epochs: 1
# max_steps_per_epoch is required for progress bar
max_steps_per_epoch: 50
batch_size: 4
gradient_accumulation_steps: 1
@@ -114,12 +115,8 @@ dtype: bf16

# Logging
output_dir: /tmp/lora-llama3.2-vision-finetune
# metric_logger:
#   _component_: torchtune.training.metric_logging.DiskLogger
#   log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
metric_logger:
  _component_: torchtune.training.metric_logging.StdoutLogger
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: True

profile_mode: null # dataloader_only | model_only
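
Note on the logging change above: torchtune instantiates the `_component_` path in `metric_logger` and passes the remaining keys as keyword arguments, so switching from `StdoutLogger` to `DiskLogger` writes metrics under `log_dir` instead of printing them. A minimal sketch of what the new config block is assumed to resolve to at runtime (the metric values are placeholders, and the call shape follows torchtune's metric-logger interface):

```python
# Hedged sketch: what the updated metric_logger config is assumed to resolve to.
from torchtune.training.metric_logging import DiskLogger

logger = DiskLogger(log_dir="/tmp/Llama-3.2-11B-Vision-Instruct/logs")
logger.log_dict({"loss": 1.23, "lr": 2e-4}, step=1)  # appended to a log file under log_dir
logger.close()
```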
10 changes: 5 additions & 5 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -8,7 +8,7 @@
import time

from functools import partial
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
from warnings import warn

import torch
@@ -19,11 +19,10 @@

from torch.optim import Optimizer

from torchdata.nodes import Loader, StopCriteria, T
from torchdata.nodes import Loader, StopCriteria
from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.data._torchdata import DatasetType
from torchtune.data._utils import get_dataloader, get_multi_dataset, load_hf_dataset
from torchtune.datasets._sft import SFTTransform
from torchtune.modules.peft import (
@@ -592,6 +591,9 @@ def _setup_data(
collate_fn = cfg_dataloader.collate_fn
prefetch_factor = cfg_dataloader.get("prefetch_factor", 6)

if packed:
    raise ValueError("Packing not yet supported")

# Multi-Dataset Stop Criteria
stop_criteria = cfg_dataloader.get(
    "stop_criteria", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
@@ -602,7 +604,6 @@
if dataset_name is None:
    dataset_name = cfg_dataset.get("subset", None)
key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "")
assert key not in weights, f"Duplicate dataset name {key}"

utils.log_rank_zero(log, f"Instantiating dataset {cfg_dataset}")
# Handle dataset-specific overrides, fallback to cfg_dataloader settings
@@ -652,7 +653,6 @@
model_transform=SFTTransform(model_transform=self._tokenizer),
batch_size=batch_size,
collate_fn=collate_fn,
packed=packed,
drop_last=True,
num_workers=num_workers,
parallel_method=parallel_method,
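
The `_setup_data` changes above build one node per dataset, key it as `{idx}_{dataset_name}`, and hand the keyed nodes plus their weights to a weighted multi-dataset sampler governed by `stop_criteria`. A rough, self-contained sketch of that pattern using torchdata.nodes directly; the dataset names, weights, and sample dicts are illustrative assumptions, not the recipe's actual values:

```python
# Hedged sketch of weighted multi-dataset sampling with torchdata.nodes.
from torchdata.nodes import IterableWrapper, Loader, MultiNodeWeightedSampler, StopCriteria

source_nodes = {
    "0_ai2d": IterableWrapper([{"ds": "ai2d", "i": i} for i in range(4)]),      # hypothetical dataset node
    "1_ocrvqa": IterableWrapper([{"ds": "ocrvqa", "i": i} for i in range(8)]),  # hypothetical dataset node
}
weights = {"0_ai2d": 0.25, "1_ocrvqa": 0.75}  # hypothetical sampling weights

sampler = MultiNodeWeightedSampler(
    source_nodes=source_nodes,
    weights=weights,
    stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
)
for sample in Loader(sampler):
    print(sample["ds"], sample["i"])
```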
6 changes: 6 additions & 0 deletions torchtune/data/_torchdata.py
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Callable, Iterable, Iterator, Mapping, TypeVar

13 changes: 6 additions & 7 deletions torchtune/data/_utils.py
@@ -152,7 +152,7 @@ def format_content_with_images(
return final_content_list


def chain(*funcs: Callable):
def chain(*funcs: Callable) -> Callable:
"""
Chain a list of functions together into a single function.
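
Based on its use later in this file (`transform = chain(dataset.__getitem__, transform)`), `chain` is assumed to apply the given functions left to right. A minimal sketch of that behavior, not the torchtune implementation:

```python
# Hedged sketch of left-to-right function chaining.
from typing import Any, Callable


def chain(*funcs: Callable) -> Callable:
    def chained(x: Any) -> Any:
        for fn in funcs:
            x = fn(x)  # output of each function feeds the next
        return x

    return chained


double_then_str = chain(lambda x: x * 2, str)
assert double_then_str(3) == "6"
```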
@@ -180,6 +180,7 @@ def load_hf_dataset(
seed: int = 0,
num_workers: int = 0,
parallel_method: Literal["process", "thread"] = "thread",
streaming: bool = False,
**load_dataset_kwargs: Dict[str, Any],
) -> DatasetType:
"""
@@ -201,10 +202,10 @@
"""
from torchdata.nodes import IterableWrapper, ParallelMapper, SamplerWrapper

# Need to lazy import to avoid circular dependency
# TODO: Remove lazy import when we can
# see: https://github.com/pytorch/torchtune/issues/2151
from torchtune.training._distributed import get_world_size_and_rank

streaming = load_dataset_kwargs.get("streaming", False)
if "subset" in load_dataset_kwargs:
assert (
"name" not in load_dataset_kwargs
@@ -228,6 +229,8 @@
shuffle=shuffle,
seed=seed,
)
# Note: SamplerWrapper will call set_epoch on the sampler (if defined),
# and auto-increment the epoch each time the node is reset.
node = SamplerWrapper(sampler)
transform = chain(dataset.__getitem__, transform) # type: ignore

@@ -255,7 +258,6 @@ def get_multi_dataset(
stop_criteria (str): stop criteria for the sampler. Default "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED".
    See also: torchdata.nodes.StopCriteria
seed (int): seed for the random number generator. Default 0.
"""
from torchdata.nodes import MultiNodeWeightedSampler

@@ -273,7 +275,6 @@ def get_dataloader(
model_transform: Transform,
batch_size: int,
collate_fn: Optional[Callable[[Any], Any]] = None,
packed: bool = False,
drop_last: bool = True,
num_workers: int = 0,
parallel_method: Literal["process", "thread"] = "thread",
@@ -297,8 +298,6 @@
prefetch_factor (Optional[int]): number of batches to prefetch. Default is 4.
pin_memory (bool): whether to pin memory. Default is False.
"""
if packed:
    raise ValueError("Multimodal datasets don't support packing yet.")

from torchdata.nodes import Batcher, ParallelMapper, PinMemory, Prefetcher

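
`get_dataloader` is assumed to compose these torchdata nodes into map → batch → prefetch stages around the model transform and collate function. A self-contained sketch of that style of pipeline; the map function, batch size, and worker counts are placeholders rather than torchtune's values:

```python
# Hedged sketch of a torchdata.nodes batching pipeline in the style get_dataloader uses.
from torchdata.nodes import Batcher, IterableWrapper, Loader, ParallelMapper, Prefetcher

source = IterableWrapper(range(16))  # stand-in for a dataset node
node = ParallelMapper(source, map_fn=lambda x: x * 2, num_workers=2, method="thread")
node = Batcher(node, batch_size=4, drop_last=True)  # groups samples into lists of 4
node = Prefetcher(node, prefetch_factor=4)          # prefetches batches in the background
for batch in Loader(node):
    print(batch)  # first batch: [0, 2, 4, 6]
```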
28 changes: 28 additions & 0 deletions torchtune/datasets/multimodal/_the_cauldron.py
@@ -243,6 +243,34 @@ def the_cauldron_transform(
images_col: str = "images",
new_system_prompt: Optional[str] = None,
) -> SFTTransform:
"""
Support for family of image + text datasets similar to
`The Cauldron <https://huggingface.co/datasets/HuggingFaceM4/the_cauldron>`_
from Hugging Face Datasets.
This function instantiates a :class:`~torchtune.datasets.SFTTransform` only (not the dataset).
See :func:`~torchtune.datasets.the_cauldron_dataset` for more details.
The model transform is expected to be a callable that applies pre-processing steps specific
to a model. For multimodal datasets, this is expected to be at minimum a tokenizer and
an image transform. The tokenizer will convert text sequences into token IDs after the dataset
is converted to a list of :class:`~torchtune.data.Message`. The image transform will load the
image and process it in accordance to the model's requirements.
Args:
model_transform (Transform): model-specific transform class that takes in a sample dict and applies custom
transforms on the keys. It should consist of at minimum two components: text tokenization (called
on the "messages" field) and image transform (called on the "images" field). The keys returned by
the model transform should be aligned with the expected inputs into the model.
texts_col (str): name of the column containing the text data. Default is "texts".
images_col (str): name of the column containing the image data. Default is "images".
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
Returns:
:class:`~torchtune.datasets.SFTTransform` - Callable that transforms samples into The Cauldron format.
"""
column_map = {"texts": texts_col, "images": images_col}
return SFTTransform(
    message_transform=TheCauldronToMessages(
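
Finally, a hedged usage sketch of the transform builder documented above. The import path, the sample layout (The Cauldron stores a "texts" list of user/assistant dicts plus an "images" list), and the "messages" output key are assumptions drawn from the docstring and dataset card, not verified behavior:

```python
# Hedged usage sketch; everything marked "assumed" is not confirmed by this diff.
from torchtune.datasets.multimodal import the_cauldron_transform  # assumed export path

transform = the_cauldron_transform(
    texts_col="texts",
    images_col="images",
    new_system_prompt="You are a helpful assistant.",  # overrides any existing system message
)

# Assumed shape of a Cauldron-style sample; real samples carry PIL images in "images".
sample = {
    "texts": [{"user": "What is shown in the image?", "assistant": "A bar chart."}],
    "images": [],
}
out = transform(sample)
print(out["messages"])  # assumed: system/user/assistant torchtune Messages
```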
