diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index bea507991d..3da5e57848 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -11,7 +11,7 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from torch.utils.data import DistributedSampler +from torch.utils.data import default_collate, DistributedSampler from torchtune.data._torchdata import DatasetType, Loader, requires_torchdata from torchtune.modules.transforms import Transform @@ -272,7 +272,7 @@ def get_dataloader( dataset: DatasetType, model_transform: Transform, batch_size: int, - collate_fn: Callable[[Any], Any], + collate_fn: Optional[Callable[[Any], Any]] = None, packed: bool = False, drop_last: bool = True, num_workers: int = 0, @@ -302,6 +302,9 @@ def get_dataloader( from torchdata.nodes import Batcher, ParallelMapper, PinMemory, Prefetcher + if collate_fn is None: + collate_fn = default_collate + node = ParallelMapper( dataset, map_fn=model_transform, num_workers=num_workers, method=parallel_method )