From 51d4327778c7f87f7b10a1e6e631284882bdbd77 Mon Sep 17 00:00:00 2001 From: andrewkho Date: Mon, 25 Nov 2024 17:37:07 -0800 Subject: [PATCH] refactor --- .../configs/llama3_2_vision/11B_lora_td.yaml | 52 +++--- recipes/lora_finetune_distributed_td.py | 160 +++++++++--------- torchtune/data/_torchdata.py | 43 +++++ torchtune/data/_utils.py | 131 +++++++++++++- torchtune/datasets/_sft.py | 149 ++++++---------- torchtune/datasets/multimodal/__init__.py | 4 +- .../datasets/multimodal/_the_cauldron.py | 129 +++++++++----- 7 files changed, 420 insertions(+), 248 deletions(-) create mode 100644 torchtune/data/_torchdata.py diff --git a/recipes/configs/llama3_2_vision/11B_lora_td.yaml b/recipes/configs/llama3_2_vision/11B_lora_td.yaml index 8b28bc2dbf..985840a21f 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_td.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_td.yaml @@ -49,7 +49,7 @@ checkpointer: resume_from_checkpoint: False save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only. -# TorchData Dataset setup +# TorchData setup dataloader: shuffle: True collate_fn: torchtune.data.padded_collate_tiled_images_and_mask @@ -60,29 +60,31 @@ dataloader: prefetch_factor: 2 seed: null -multi_datasets: - stop_criterion: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED - datasets: - ocrvqa: - weight: 1.0 - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata - subset: ocrvqa - dvqa: - weight: 1.0 - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata - subset: dvqa - docvqa: - weight: 1.0 - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata - subset: docvqa - tabmwp: - weight: 1.0 - dataset: - _component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata - subset: tabmwp +datasets: + - source: HuggingFaceM4/the_cauldron + subset: ocrvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: dvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: docvqa + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 + - source: HuggingFaceM4/the_cauldron + subset: tabmwp + split: train + transform: + _component_: torchtune.datasets.multimodal.the_cauldron_transform + weight: 1.0 # torch.utils.data.DataLoader Dataset setup, single dataset only classic_dataloader: @@ -98,7 +100,7 @@ use_torchdata: true # Fine-tuning arguments epochs: 1 -max_steps_per_epoch: null +max_steps_per_epoch: 50 batch_size: 4 gradient_accumulation_steps: 1 optimizer: diff --git a/recipes/lora_finetune_distributed_td.py b/recipes/lora_finetune_distributed_td.py index c57755285c..33d5a71592 100644 --- a/recipes/lora_finetune_distributed_td.py +++ b/recipes/lora_finetune_distributed_td.py @@ -21,22 +21,15 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchdata.nodes import ( - BaseNode, - Batcher, - Loader, - Mapper, - MultiDatasetWeightedSampler, - ParallelMapper, - PinMemory, - Prefetcher, - T, -) -from torchdata.nodes.samplers.multi_dataset_weighted_sampler import StopCriteria +from torchdata.nodes import Loader, T +from torchdata.nodes.samplers.stop_criteria import StopCriteria from torchtune import config, modules, training, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import padded_collate_packed +from torchtune.data._torchdata import DatasetType +from torchtune.data._utils import get_dataloader, get_multi_dataset, load_hf_dataset from torchtune.datasets import ConcatDataset +from torchtune.datasets._sft import SFTTransform from torchtune.modules.peft import ( DoRALinear, get_adapter_params, @@ -320,7 +313,7 @@ def setup(self, cfg: DictConfig) -> None: if cfg.get("use_torchdata", True): self._dataloader = self._setup_data_td( cfg_dataloader=cfg.dataloader, - cfg_multi_datasets=cfg.multi_datasets, + cfg_datasets=cfg.datasets, batch_size=cfg.batch_size, ) else: @@ -618,31 +611,34 @@ def _setup_lr_scheduler( def _setup_one_dataset( self, cfg_dataset: DictConfig, + global_streaming: bool, global_shuffle: bool, global_parallel_method: str, - global_streaming: bool, global_num_workers: int, - ) -> BaseNode: + ) -> DatasetType: streaming = cfg_dataset.pop("streaming", global_streaming) - parallel_method = cfg_dataset.pop("parallel_method", global_parallel_method) shuffle = cfg_dataset.pop("shuffle", global_shuffle) + parallel_method = cfg_dataset.pop("parallel_method", global_parallel_method) num_workers = cfg_dataset.pop("num_workers", global_num_workers) - return config.instantiate( - cfg_dataset, - self._tokenizer, + # Instantiate dataset transform + assert "transform" in cfg_dataset, "transform must be specified in dataset" + transform = config.instantiate(cfg_dataset.pop("transform")) + + log.info(f"Instantiating dataset {cfg_dataset}") + return load_hf_dataset( + **cfg_dataset, + transform=transform, streaming=streaming, shuffle=shuffle, parallel_method=parallel_method, num_workers=num_workers, ) - return ds - def _setup_data_td( self, cfg_dataloader: DictConfig, - cfg_multi_datasets: DictConfig, + cfg_datasets: ListConfig, batch_size: int, ) -> Loader: """ @@ -650,8 +646,6 @@ def _setup_data_td( both Map and Streaming datasets (from HuggingFace datasets), and mixing multiple datasets (can be mix of Map and Streaming). """ - world_size, rank = training.get_world_size_and_rank() - # Get global settings shuffle = cfg_dataloader.shuffle parallel_method = cfg_dataloader.get("parallel_method", "thread") @@ -660,69 +654,66 @@ def _setup_data_td( num_workers = cfg_dataloader.get("num_workers", 0) pin_memory = cfg_dataloader.get("pin_memory", True) collate_fn = cfg_dataloader.collate_fn - prefetch_factor = cfg_dataloader.get("prefetch_factor", 2) + prefetch_factor = cfg_dataloader.get("prefetch_factor", 6) - stop_criterion = cfg_multi_datasets.get( - "stop_criterion", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED + # Multi-Dataset Stop Criteria + stop_criteria = cfg_dataloader.get( + "stop_criteria", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED ) weights, datasets = {}, {} - cfg_datasets = cfg_multi_datasets.datasets - for k, cfg_and_weight in cfg_datasets.items(): - weights[k] = float(cfg_and_weight.weight) - datasets[k] = Prefetcher( - self._setup_one_dataset( - cfg_dataset=cfg_and_weight.dataset, - global_shuffle=shuffle, - global_parallel_method=parallel_method, - global_streaming=streaming, - global_num_workers=num_workers, - ), - prefetch_factor=prefetch_factor, + for idx, cfg_dataset in enumerate(cfg_datasets): + dataset_name = cfg_dataset.pop("name", None) + if dataset_name is None: + dataset_name = cfg_dataset.get("subset", None) + key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "") + assert key not in weights, f"Duplicate dataset name {key}" + weights[key] = float(cfg_dataset.pop("weight")) + datasets[key] = self._setup_one_dataset( + cfg_dataset=cfg_dataset, + global_shuffle=shuffle, + global_parallel_method=parallel_method, + global_streaming=streaming, + global_num_workers=num_workers, ) # Instantiate collate_fn if "left_pad_sequence" in collate_fn: raise RuntimeError("left_pad_sequence collator is only for inference.") - collate_fn = _get_component_from_path(collate_fn) - # TODO: add multi-dataset mixer - if num_workers == 0: - _Mapper = Mapper # noqa[N806] - else: - _Mapper = partial( # noqa[N806] - ParallelMapper, - num_workers=num_workers, - method=parallel_method, + collate_fn = ( + partial( + _get_component_from_path(collate_fn), + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, ) - if len(cfg_datasets) == 1: - node = next(iter(datasets.values())) - else: - node = MultiDatasetWeightedSampler( - source_nodes=datasets, + if not packed + else padded_collate_packed + ) + if len(datasets) > 1: + dataset = get_multi_dataset( + datasets=datasets, weights=weights, - stop_criterion=stop_criterion, + stop_criteria=stop_criteria, ) - node = Batcher(node, batch_size, drop_last=True) - node = _Mapper( - node, - map_fn=( - partial( - collate_fn, # noqa - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else padded_collate_packed - ), + else: + dataset = next(iter(datasets.values())) + + loader = get_dataloader( + dataset=dataset, + model_transform=SFTTransform(model_transform=self._tokenizer), + batch_size=batch_size, + collate_fn=collate_fn, + packed=packed, + drop_last=True, + num_workers=num_workers, + parallel_method=parallel_method, + prefetch_factor=prefetch_factor, + pin_memory=pin_memory, ) - if pin_memory: - node = PinMemory(node) - if num_workers > 0: - node = Prefetcher(node, prefetch_factor) log.info("TorchData nodes are initialized") - return Loader(node) + return loader def _setup_data( self, @@ -801,24 +792,27 @@ def save_checkpoint( intermediate_checkpoint = epoch + 1 < self.total_epochs - if self._is_rank_zero: - log.info( - "Saving checkpoint. This may take some time. Retrieving full model state dict..." - ) - start = time.perf_counter() + utils.log_rank_zero( + log, + "Saving checkpoint. This may take some time. Retrieving full model state dict...", + ) + start = time.perf_counter() # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, self._is_rank_zero, device=self._device, - trainable_only=self._save_adapter_weights_only, ) - if self._is_rank_zero: - log.info( - f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" - ) + utils.log_rank_zero( + log, + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs", + ) if intermediate_checkpoint: if self._is_rank_zero: diff --git a/torchtune/data/_torchdata.py b/torchtune/data/_torchdata.py new file mode 100644 index 0000000000..a59a96311c --- /dev/null +++ b/torchtune/data/_torchdata.py @@ -0,0 +1,43 @@ +import functools +from typing import Any, Callable, Iterable, Iterator, Mapping + +from typing_extensions import TypeAlias # typing.TypeAlias is only in Python 3.10+ + + +try: + from torchdata.nodes import BaseNode, Loader # noqa + + _TORCHDATA_INSTALLED = True + DatasetType: TypeAlias = BaseNode[Mapping[str, Any]] # type: ignore +except ImportError as e: + # If we fail to import torchdata, define some stubs to make typechecker happy + _TORCHDATA_INSTALLED = False + DatasetType: TypeAlias = Iterator[Mapping[str, Any]] # type: ignore + + class Loader(Iterable): + def __init__(self, *args, **kwargs): + assert_torchdata_installed() + + +MIN_VERSION = "0.10.0" + + +def assert_torchdata_installed(): + if not _TORCHDATA_INSTALLED: + raise ImportError( + f"torchdata is not installed, or the current version is too old. " + f"Please (re-)install it with `pip install torchdata>={MIN_VERSION}`. " + ) + + +def requires_torchdata(func: Callable) -> Callable: + """ + Decorator to check if torchdata is installed and raise an ImportError if not. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + assert_torchdata_installed() + return func(*args, **kwargs) + + return wrapper diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 832e1babca..8946f38348 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -4,10 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools from pathlib import Path -from typing import Any, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union from urllib import request +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node + +from torch.utils.data import DistributedSampler + +from torchtune.data._torchdata import DatasetType, Loader, requires_torchdata +from torchtune.modules.transforms import Transform + T = TypeVar("T", bound=type) @@ -142,3 +151,123 @@ def format_content_with_images( final_content_list.append({"type": "image", "content": images.pop(0)}) return final_content_list + + +@requires_torchdata +def load_hf_dataset( + source: str, + transform: Transform, + filter_fn: Optional[Callable] = None, + shuffle: bool = True, + seed: int = 0, + num_workers: int = 1, + parallel_method: Literal["process", "thread"] = "thread", + **load_dataset_kwargs: Dict[str, Any], +) -> DatasetType: + from torchdata.nodes import IterableWrapper, Mapper, ParallelMapper, SamplerWrapper + + # Need to lazy import to avoid circular dependency + from torchtune.training._distributed import get_world_size_and_rank + + streaming = load_dataset_kwargs.get("streaming", False) + if "subset" in load_dataset_kwargs: + assert ( + "name" not in load_dataset_kwargs + ), f"found both 'subset' and 'name' found, you may only specify one, {load_dataset_kwargs=}" + load_dataset_kwargs["name"] = load_dataset_kwargs.pop("subset") + dataset = load_dataset(source, **load_dataset_kwargs) + if filter_fn is not None: + dataset = dataset.filter(filter_fn) + + if num_workers == 0: + _Mapper = Mapper # type: ignore + else: + _Mapper = functools.partial( + ParallelMapper, # type: ignore + num_workers=num_workers, + method=parallel_method, + ) + world_size, rank = get_world_size_and_rank() + if streaming: + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + if shuffle: + dataset = dataset.shuffle(seed=seed) + node = IterableWrapper(dataset) # type: ignore + else: + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=seed, + ) + node = SamplerWrapper(sampler) # type: ignore + node = _Mapper(node, map_fn=dataset.__getitem__) + + node = _Mapper(node, transform) + + return node + + +@requires_torchdata +def get_multi_dataset( + datasets: dict[str, DatasetType], + weights: dict[str, float], + stop_criteria: str = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED", + seed: int = 0, +) -> DatasetType: + """ + Given a dictionary of datasets and their corresponding weights, return a dataset that + samples from the given datasets according to the specified weights. + + Args: + datasets (Dict[str, Any]): dictionary of datasets + weights (Optional[Dict[str, float]]): dictionary of weights for each dataset. If not + + """ + from torchdata.nodes import MultiNodeWeightedSampler + + return MultiNodeWeightedSampler( + source_nodes=datasets, + weights=weights, + stop_criteria=stop_criteria, + seed=seed, + ) + + +@requires_torchdata +def get_dataloader( + dataset: DatasetType, + model_transform: Transform, + batch_size: int, + collate_fn: Callable[[Any], Any], + packed: bool = False, + drop_last: bool = True, + num_workers: int = 0, + parallel_method: Literal["process", "thread"] = "thread", + prefetch_factor: Optional[int] = 4, + pin_memory: bool = False, +) -> Loader: + if packed: + raise ValueError("Multimodal datasets don't support packing yet.") + + from torchdata.nodes import Batcher, Mapper, ParallelMapper, PinMemory, Prefetcher + + if num_workers == 0: + _Mapper = Mapper # noqa[N806] + else: + _Mapper = functools.partial( # noqa[N806] + ParallelMapper, + num_workers=num_workers, + method=parallel_method, + ) + + node = _Mapper(dataset, map_fn=model_transform) + node = Batcher(node, batch_size, drop_last=drop_last) + node = _Mapper(node, map_fn=collate_fn) + if pin_memory: + node = PinMemory(node) + if prefetch_factor is not None: + node = Prefetcher(node, prefetch_factor) + + return Loader(node) diff --git a/torchtune/datasets/_sft.py b/torchtune/datasets/_sft.py index 3b0923cb91..117ca89990 100644 --- a/torchtune/datasets/_sft.py +++ b/torchtune/datasets/_sft.py @@ -4,46 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import functools from typing import Any, Callable, Dict, Literal, Mapping, Optional import numpy as np from datasets import load_dataset -from datasets.distributed import split_dataset_by_node -from torch.utils.data import Dataset, DistributedSampler +from torch.utils.data import Dataset from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages -from torchtune.modules.transforms import Transform -from torchtune.training._distributed import get_world_size_and_rank - -try: - import torchdata.nodes # noqa - - _TORCHDATA_INSTALLED = True -except ImportError as e: - _TORCHDATA_INSTALLED = False - - -def assert_torchdata_installed(): - if not _TORCHDATA_INSTALLED: - raise ImportError( - "torchdata is not installed, or the current version is too old. " - "Please (re-)install it with `pip install torchdata>=0.10.0`" - ) - -def requires_torchdata(func: Callable) -> Callable: - """ - Decorator to check if torchdata is installed and raise an ImportError if not. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): - assert_torchdata_installed() - return func(*args, **kwargs) - - return wrapper +from torchtune.data._torchdata import DatasetType +from torchtune.data._utils import load_hf_dataset +from torchtune.modules.transforms import Transform class SFTDataset(Dataset): @@ -141,7 +113,7 @@ def __init__( if filter_fn is not None: self._data = self._data.filter(filter_fn) - self._prepare_sample = PrepareSample( + self._prepare_sample = SFTTransform( message_transform=self._message_transform, model_transform=self._model_transform, ) @@ -154,44 +126,53 @@ def __getitem__(self, index: int) -> Dict[str, Any]: return self._prepare_sample(sample) -class PrepareSample: +class SFTTransform(Transform): def __init__( self, - message_transform: Transform, - model_transform: Transform, + message_transform: Optional[Transform] = None, + model_transform: Optional[Transform] = None, ): + if message_transform is None and model_transform is None: + raise ValueError( + "At least one of message_transform or model_transform must be provided." + ) self._message_transform = message_transform self._model_transform = model_transform def __call__(self, sample: Mapping[str, Any]) -> Dict[str, Any]: - transformed_sample = self._message_transform(sample) - if "messages" in transformed_sample: - validate_messages(transformed_sample["messages"]) - - tokenized_dict = self._model_transform(transformed_sample) - - if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): - keys_str = ", ".join(tokenized_dict.keys()) - error_message = ( - "model_transform returned the following keys: " - f"{keys_str}. Must return 'tokens' and 'mask' as keys." - ) - raise ValueError(error_message) - - # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens - tokenized_dict["labels"] = list( - np.where( - tokenized_dict["mask"], - CROSS_ENTROPY_IGNORE_IDX, - tokenized_dict["tokens"], + if self._message_transform is not None: + transformed_sample = self._message_transform(sample) + if "messages" in transformed_sample: + validate_messages(transformed_sample["messages"]) + else: + transformed_sample = sample + + if self._model_transform is not None: + tokenized_dict = self._model_transform(transformed_sample) + + if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): + keys_str = ", ".join(tokenized_dict.keys()) + error_message = ( + "model_transform returned the following keys: " + f"{keys_str}. Must return 'tokens' and 'mask' as keys." + ) + raise ValueError(error_message) + + # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens + tokenized_dict["labels"] = list( + np.where( + tokenized_dict["mask"], + CROSS_ENTROPY_IGNORE_IDX, + tokenized_dict["tokens"], + ) ) - ) - assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"]) + else: + tokenized_dict = transformed_sample return tokenized_dict -@requires_torchdata def SFTDatasetNode( # noqa[N802] source: str, message_transform: Transform, @@ -202,47 +183,17 @@ def SFTDatasetNode( # noqa[N802] shuffle: bool = True, seed: int = 0, **load_dataset_kwargs: Dict[str, Any], -) -> "BaseNode[Mapping[str, Any]]": - - # Importing here to avoid "Possibly unbound" mypy errors - from torchdata.nodes import IterableWrapper, Mapper, ParallelMapper, SamplerWrapper - - streaming = load_dataset_kwargs.get("streaming", False) - dataset = load_dataset(source, **load_dataset_kwargs) - if filter_fn is not None: - dataset = dataset.filter(filter_fn) - - if num_workers == 0: - _Mapper = Mapper # noqa[N806] - else: - _Mapper = functools.partial( # noqa[N806] - ParallelMapper, - num_workers=num_workers, - method=parallel_method, - ) - world_size, rank = get_world_size_and_rank() - if streaming: - dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) - if shuffle: - dataset = dataset.shuffle(seed=seed) - node = IterableWrapper(dataset) - else: - sampler = DistributedSampler( - dataset, - num_replicas=world_size, - rank=rank, - shuffle=shuffle, - seed=seed, - ) - node = SamplerWrapper(sampler) - node = _Mapper(node, map_fn=dataset.__getitem__) - - node = _Mapper( - node, - PrepareSample( +) -> DatasetType: + return load_hf_dataset( + source=source, + transform=SFTTransform( message_transform=message_transform, model_transform=model_transform, ), + filter_fn=filter_fn, + num_workers=num_workers, + parallel_method=parallel_method, + shuffle=shuffle, + seed=seed, + **load_dataset_kwargs, ) - - return node diff --git a/torchtune/datasets/multimodal/__init__.py b/torchtune/datasets/multimodal/__init__.py index 395cdc71fb..9efad1e730 100644 --- a/torchtune/datasets/multimodal/__init__.py +++ b/torchtune/datasets/multimodal/__init__.py @@ -6,12 +6,12 @@ from ._llava_instruct import llava_instruct_dataset from ._multimodal import multimodal_chat_dataset -from ._the_cauldron import the_cauldron_dataset, the_cauldron_dataset_torchdata +from ._the_cauldron import the_cauldron_dataset, the_cauldron_transform from ._vqa import vqa_dataset __all__ = [ "the_cauldron_dataset", - "the_cauldron_dataset_torchdata", + "the_cauldron_transform", "llava_instruct_dataset", "multimodal_chat_dataset", "vqa_dataset", diff --git a/torchtune/datasets/multimodal/_the_cauldron.py b/torchtune/datasets/multimodal/_the_cauldron.py index 10bbadd06f..22da3c6374 100644 --- a/torchtune/datasets/multimodal/_the_cauldron.py +++ b/torchtune/datasets/multimodal/_the_cauldron.py @@ -7,7 +7,8 @@ from typing import Any, Callable, Dict, Literal, Mapping, Optional from torchtune.data._messages import Message -from torchtune.datasets._sft import requires_torchdata, SFTDataset, SFTDatasetNode +from torchtune.data._utils import load_hf_dataset +from torchtune.datasets._sft import SFTDataset, SFTTransform from torchtune.modules.transforms import Transform @@ -237,44 +238,96 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: return ds -@requires_torchdata -def the_cauldron_dataset_torchdata( - model_transform: Transform, - *, - subset: str, - source: str = "HuggingFaceM4/the_cauldron", +def the_cauldron_transform( + model_transform: Optional[Transform] = None, column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, - packed: bool = False, - filter_fn: Optional[Callable] = None, - split: str = "train", - streaming: bool = False, - shuffle: bool = False, - seed: int = 0, - num_workers: int = 0, - parallel_method: Literal["process", "thread"] = "thread", - **load_dataset_kwargs: Dict[str, Any], -): - if packed: - raise ValueError("Multimodal datasets don't support packing yet.") - - message_transform = TheCauldronToMessages( - column_map=column_map, - new_system_prompt=new_system_prompt, - ) - - return SFTDatasetNode( - source=source, - message_transform=message_transform, +) -> SFTTransform: + return SFTTransform( + message_transform=TheCauldronToMessages( + column_map=column_map, + new_system_prompt=new_system_prompt, + ), model_transform=model_transform, - filter_fn=filter_fn, - num_workers=num_workers, - parallel_method=parallel_method, - shuffle=shuffle, - seed=seed, - # dataset kwargs - name=subset, - split=split, - streaming=streaming, - **load_dataset_kwargs, ) + + +# def the_cauldron_dataset_torchdata( +# model_transform: Optional[Transform] = None, +# *, +# subset: str, +# source: str = "HuggingFaceM4/the_cauldron", +# column_map: Optional[Dict[str, str]] = None, +# new_system_prompt: Optional[str] = None, +# packed: bool = False, +# filter_fn: Optional[Callable] = None, +# split: str = "train", +# streaming: bool = False, +# shuffle: bool = False, +# seed: int = 0, +# num_workers: int = 0, +# parallel_method: Literal["process", "thread"] = "thread", +# **load_dataset_kwargs: Dict[str, Any], +# ): +# if packed: +# raise ValueError("Multimodal datasets don't support packing yet.") +# return load_hf_dataset( +# source=source, +# transform=the_cauldron_transform( +# model_transform=model_transform, +# column_map=column_map, +# new_system_prompt=new_system_prompt, +# ), +# filter_fn=filter_fn, +# shuffle=shuffle, +# seed=seed, +# num_workers=num_workers, +# parallel_method=parallel_method, +# # Additional load_dataset kwargs +# subset=subset, +# split=split, +# streaming=streaming, +# **load_dataset_kwargs, +# ) + + +# def the_cauldron_dataset_torchdata( +# model_transform: Transform, +# *, +# subset: str, +# source: str = "HuggingFaceM4/the_cauldron", +# column_map: Optional[Dict[str, str]] = None, +# new_system_prompt: Optional[str] = None, +# packed: bool = False, +# filter_fn: Optional[Callable] = None, +# split: str = "train", +# streaming: bool = False, +# shuffle: bool = False, +# seed: int = 0, +# num_workers: int = 0, +# parallel_method: Literal["process", "thread"] = "thread", +# **load_dataset_kwargs: Dict[str, Any], +# ): +# if packed: +# raise ValueError("Multimodal datasets don't support packing yet.") + +# message_transform = TheCauldronToMessages( +# column_map=column_map, +# new_system_prompt=new_system_prompt, +# ) + +# return SFTDatasetNode( +# source=source, +# message_transform=message_transform, +# model_transform=model_transform, +# filter_fn=filter_fn, +# num_workers=num_workers, +# parallel_method=parallel_method, +# shuffle=shuffle, +# seed=seed, +# # dataset kwargs +# name=subset, +# split=split, +# streaming=streaming, +# **load_dataset_kwargs, +# )