Skip to content

Commit

Permalink
allow null collate_fn in get_dataloader
Browse files (browse the repository at this point in the history)
  • Loading branch information
andrewkho committed Nov 26, 2024
1 parent cb67f43 commit cacf2b8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions torchtune/data/_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

-from torch.utils.data import DistributedSampler
+from torch.utils.data import default_collate, DistributedSampler

from torchtune.data._torchdata import DatasetType, Loader, requires_torchdata
from torchtune.modules.transforms import Transform
Expand Down Expand Up @@ -272,7 +272,7 @@ def get_dataloader(
dataset: DatasetType,
model_transform: Transform,
batch_size: int,
-collate_fn: Callable[[Any], Any],
+collate_fn: Optional[Callable[[Any], Any]] = None,
packed: bool = False,
drop_last: bool = True,
num_workers: int = 0,
Expand Down Expand Up @@ -302,6 +302,9 @@ def get_dataloader(

from torchdata.nodes import Batcher, ParallelMapper, PinMemory, Prefetcher

+if collate_fn is None:
+    collate_fn = default_collate

node = ParallelMapper(
dataset, map_fn=model_transform, num_workers=num_workers, method=parallel_method
)
Expand Down

0 comments on commit cacf2b8

Please sign in to comment.