fix all pre-commit lints
andrewkho committed Dec 12, 2024
1 parent 41be6b0 commit b9f7303
Showing 2 changed files with 30 additions and 19 deletions.
45 changes: 28 additions & 17 deletions torchtune/data/_utils.py
@@ -152,12 +152,12 @@ def format_content_with_images(
return final_content_list


def chain(*funcs: Callable) -> Callable:
def chain(*funcs: List[Callable]) -> Callable:
"""
Chain a list of functions together into a single function.
Args:
funcs (List[Callable]): list of functions to chain together
*funcs (List[Callable]): list of functions to chain together
Returns:
Callable: chained function
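
For reference, a minimal sketch of what a variadic function-chaining helper like this typically looks like (left-to-right application order is an assumption here, not a claim about torchtune's implementation):

from typing import Any, Callable

def chain(*funcs: Callable) -> Callable:
    """Compose the given functions into one, applied left to right."""
    def chained(x: Any) -> Any:
        for fn in funcs:
            x = fn(x)  # feed each function's output into the next
        return x
    return chained

# Example: strip whitespace, then lowercase.
normalize = chain(str.strip, str.lower)
assert normalize("  Hello ") == "hello"
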
@@ -187,18 +187,23 @@ def load_hf_dataset(
Load a HuggingFace dataset (Map or Streaming) and apply a Transform to it.
Args:
source (str): HuggingFace dataset source
transform (Transform): Transform to apply to the samples of the dataset
filter_fn (Optional[Callable]): Filter function to pass to HuggingFace dataset
source (str): HuggingFace dataset source.
transform (Transform): Transform to apply to the samples of the dataset.
filter_fn (Optional[Callable]): Filter function to pass to HuggingFace dataset.
shuffle (bool): Whether to shuffle the dataset. Default is True. For streaming datasets, this is passed to
HuggingFace dataset as .shuffle(). For map datasets, a DistributedSampler is used.
seed (int): Seed for the random number generator in the case of Map style dataset shuffling. Default is 0.
num_workers (int): Number of workers to use for loading the dataset. Default is 0 (no parallelism). Setting this
greater than 0 will create `parallel_method` workers to perform transforms to the dataset
greater than 0 will create `parallel_method` workers to perform transforms to the dataset.
parallel_method (Literal["process", "thread"]): Method to use for parallelism. Default is "thread". No effect if
num_workers is 0
load_dataset_kwargs (Dict[str, Any]): Additional Keyword arguments to pass to HuggingFace dataset. See Hugging Face's
num_workers is 0.
streaming (bool): whether to load a streaming vs map-style dataset. Default False.
**load_dataset_kwargs (Dict[str, Any]): Additional Keyword arguments to pass to HuggingFace dataset. See Hugging Face's
documentation.
Returns:
A ``torchdata.nodes`` iterator that can be passed directly to a Loader, or combined with other datasets in a multi-dataset
sampler.
"""
from torchdata.nodes import IterableWrapper, ParallelMapper, SamplerWrapper
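
To illustrate the streaming path the docstring describes, here is a hedged sketch built directly on torchdata.nodes; the dataset name, the stand-in transform, and the exact ParallelMapper keyword names are assumptions rather than torchtune's code:

from datasets import load_dataset  # Hugging Face datasets
from torchdata.nodes import IterableWrapper, Loader, ParallelMapper

# Streaming datasets are shuffled via HF's .shuffle(), then wrapped as a node
# and mapped with `num_workers` worker threads, as the docstring outlines.
hf_ds = load_dataset("tatsu-lab/alpaca", split="train", streaming=True)  # illustrative source
hf_ds = hf_ds.shuffle(seed=0)

node = IterableWrapper(hf_ds)
node = ParallelMapper(node, map_fn=lambda s: s, num_workers=4, method="thread")  # identity stand-in transform

for sample in Loader(node):
    break  # consume as needed
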

@@ -243,21 +248,24 @@ def load_hf_dataset(

@requires_torchdata
def get_multi_dataset(
datasets: dict[str, DatasetType],
weights: dict[str, float],
stop_criteria: str = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED",
datasets: Dict[str, DatasetType],
weights: Dict[str, float],
stop_criteria: str = "CYCLE_UNTIL_ALL_DATASETS_EXHASTED",
seed: int = 0,
) -> DatasetType:
"""
Given a dictionary of datasets and their corresponding weights, return a dataset that
samples from the given datasets according to the specified weights.
Args:
datasets (Dict[str, Any]): dictionary of datasets
weights (Optional[Dict[str, float]]): dictionary of weights for each dataset. If not
datasets (Dict[str, DatasetType]): dictionary of datasets
weights (Dict[str, float]): dictionary of weights for each dataset. If not
stop_criteria (str): stop criteria for the sampler. Default "CYCLE_UNTIL_ALL_DATASETS_EXHASTED".
see also: torchdata.nodes.StopCriteria
seed: (int): seed for the random number generator. Default 0.
See also: torchdata.nodes.StopCriteria
seed (int): seed for the random number generator. Default 0.
Returns:
A ``torchdata.nodes`` iterator which can be passed to Loader, or further composed with other Nodes.
"""
from torchdata.nodes import MultiNodeWeightedSampler
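
As a rough usage sketch of the weighted mixing the docstring describes (the positional argument order and the string form of the stop criteria are assumptions based on torchdata.nodes, not copied from torchtune):

from torchdata.nodes import IterableWrapper, Loader, MultiNodeWeightedSampler

# Toy stand-ins for real dataset nodes.
nodes = {
    "ds_a": IterableWrapper(range(10)),
    "ds_b": IterableWrapper(range(100, 110)),
}
weights = {"ds_a": 0.8, "ds_b": 0.2}  # relative sampling probabilities

sampler = MultiNodeWeightedSampler(
    nodes,
    weights,
    stop_criteria="CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED",
    seed=0,
)

for sample in Loader(sampler):
    pass  # roughly 80% of draws come from ds_a, 20% from ds_b
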

@@ -290,13 +298,16 @@ def get_dataloader(
dataset (DatasetType): dataset to load. May be a MultiNodeWeightedSampler
model_transform (Transform): model transform to apply to the samples of the dataset
batch_size (int): batch size
collate_fn (Callable[[Any], Any]): collate function to apply to the samples of the dataset
packed (bool): whether to pack the dataset. Default is False. Not supported yet.
collate_fn (Optional[Callable[[Any], Any]]): collate function to apply to the samples of the dataset. If None, use
torch.utils.data.default_collate. Default None.
drop_last (bool): whether to drop the last batch. Default is True.
num_workers (int): number of workers to use for loading the dataset. Default is 0 (no parallelism).
parallel_method (Literal["process", "thread"]): method to use for parallelism. Default is "thread".
prefetch_factor (Optional[int]): number of batches to prefetch. Default is 4.
pin_memory (bool): whether to pin memory. Default is False.
Returns:
A ``torchdata.nodes`` Loader, an Iterable that returns batches.
"""

from torchdata.nodes import Batcher, ParallelMapper, PinMemory, Prefetcher
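The get_dataloader docstring above corresponds roughly to the node pipeline below; this is a hedged sketch with toy data, and the node signatures are assumed from torchdata.nodes rather than taken from torchtune:

import torch
from torch.utils.data import default_collate
from torchdata.nodes import Batcher, IterableWrapper, Loader, ParallelMapper, PinMemory, Prefetcher

samples = [{"x": torch.randn(4)} for _ in range(64)]  # toy dataset

node = IterableWrapper(samples)
node = Batcher(node, batch_size=8, drop_last=True)                                    # lists of 8 samples
node = ParallelMapper(node, map_fn=default_collate, num_workers=2, method="thread")   # collate into tensors
node = Prefetcher(node, prefetch_factor=4)
if torch.cuda.is_available():
    node = PinMemory(node)  # pinning only helps when transferring to CUDA

for batch in Loader(node):
    assert batch["x"].shape == (8, 4)
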
4 changes: 2 additions & 2 deletions torchtune/datasets/multimodal/_the_cauldron.py
@@ -258,10 +258,10 @@ def the_cauldron_transform(
image and process it in accordance to the model's requirements.
Args:
model_transform (Transform): model-specific transform class that takes in a sample dict and applies custom
model_transform (Optional[Transform]): model-specific transform class that takes in a sample dict and applies custom
transforms on the keys. It should consist of at minimum two components: text tokenization (called
on the "messages" field) and image transform (called on the "images" field). The keys returned by
the model transform should be aligned with the expected inputs into the model.
the model transform should be aligned with the expected inputs into the model. Default is None.
texts_col (str): name of the column containing the text data. Default is "texts".
images_col (str): name of the column containing the image data. Default is "images".
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
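To make the model_transform contract concrete, here is a toy sketch of a callable that handles the "messages" and "images" fields the docstring mentions; the output keys and the fake tokenization are purely illustrative, not torchtune's actual transform:

from typing import Any, Dict, List

class ToyModelTransform:
    """Minimal stand-in for a model-specific transform."""

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        messages: List[str] = sample["messages"]
        images: List[Any] = sample["images"]
        return {
            "tokens": [len(m.split()) for m in messages],  # stand-in for real tokenization
            "encoder_input": images,                       # stand-in for real image preprocessing
        }

sample = {"messages": ["What is in this image?"], "images": ["<PIL.Image placeholder>"]}
print(ToyModelTransform()(sample))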
