torchdata integration - multi-dataset and streaming support #1929

Merged
14 commits merged into pytorch:main from andrewkh/torchdata-integration on Dec 16, 2024

Conversation

andrewkho
Contributor

@andrewkho andrewkho commented Oct 30, 2024

Note! This requires torchdata nightly to be installed to work correctly.

Test multi-dataset training command:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 compile=True dataloader.num_workers=4

Test multi-dataset command with dataloader_only mode:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 profile_mode=dataloader_only compile=True dataloader.num_workers=4

Benchmarking on 8xA100

For 200 steps of full model training, batch_size: 4, gradient_accumulation_steps: 1.
TWFB: sum of Time Waiting For Batch (max across all ranks) divided by the sum of step times (e.g. for the baseline, 14.596 / 232.454 ≈ 0.063).
Single-dataset trials were run with OCRVQA.
The multi-dataset trials were run with ocrvqa, docvqa, dvqa, and tabmwp, with equal weighting for all datasets.

Trial                                 | TWFB (ratio) | Sum DL times | Sum step times
0 Single dataset (baseline)           | 0.06279      | 14.596       | 232.454
1 Single dataset (TorchData threads)  | 0.00726      | 1.64746      | 226.780
2 Multi dataset (threads)             | 0.00514      | 2.4945       | 485.388
3 Multi dataset (streaming)           | 0.01712      | 8.5445       | 499.0716

The multi-dataset runs are much slower; my guess is that one of the datasets (dvqa?) requires more padding than ocrvqa.

Launch commands:
0: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=false max_steps_per_epoch=200 compile=True
1: (same as 2 but with 3 of the datasets removed in the config)
2: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True
3: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True dataloader.streaming=true

Please have a look at how the code is set up, and at how this composability can help with streaming datasets and multi-dataset mixing. You can think of these as approaches to replacing torch.utils.data.DataLoader while introducing more flexible parallelism schemes: e.g. instead of just one multiprocessing worker setting, you could use multi-threading, pipeline parallelism, etc. It also enables more powerful composability IMO.
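For a concrete picture, here is a rough sketch of the kind of pipeline this PR composes, assembled from the snippets discussed in the review below. Treat it as illustrative only: the import path and exact signatures follow the torchdata nightly at the time, and ds stands in for an existing map-style SFT dataset.

from functools import partial

from torch.utils.data import DistributedSampler
from torchdata.nodes import IterableWrapper, ParallelMapper, PinMemory, Prefetcher

# ds: assumed to be a map-style SFT dataset exposing _data and _prepare_sample
_Mapper = partial(ParallelMapper, num_workers=4, method="thread", in_order=True)

sampler = DistributedSampler(ds, shuffle=True)
node = IterableWrapper(sampler)                     # indices become a node
node = _Mapper(node, map_fn=ds._data.__getitem__)   # fetch raw samples in parallel
node = _Mapper(node, map_fn=ds._prepare_sample)     # tokenize/transform in parallel
node = PinMemory(node)
node = Prefetcher(node, 8)                          # the recipe then iterates this node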

I have done some single-device benchmarking on a machine with an A100 40GB, both with and without the model; performance is on par with or better than the standard DataLoader.

TODO: fill in below

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1929

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 139d7a7 with merge base c2c6f4a (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 30, 2024
Contributor

@RdoubleA RdoubleA left a comment

Thanks for this prototype, this is helpful to see! The way I'd imagine we'd expose the torchdata dataloader is via a builder function with a few knobs and reasonable defaults:

def build_dataloader(ds, num_workers, pin_memory, prefetch, in_memory, parallel_method, ...):
	# Put together all the nodes here

# In config
dataloader:
  num_workers:
  ...

For a power user, what might they want to tune to optimize performance for their hardware and model setup?

It's also not clear to me how some media transforms/decoding might get optimized, is that just handled by the torchdata nodes automatically?

node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
node = _Mapper(node, map_fn=ds._prepare_sample)
Contributor

is this where the transform would get parallelized?

Contributor Author

Yes, see

_Mapper = partial(
    ParallelMapper,
    num_workers=num_workers,
    method=parallel_method,
    in_order=True,
)

)
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
Contributor

So does this mean we can keep our own Dataset abstractions? What if we went with entirely Iterable datasets?

Contributor Author

@andrewkho andrewkho Oct 30, 2024

Technically yes, but this was just a way to get it working quickly; we can also refactor this into a separate function. For an IterableDataset, you could wrap it in IterableWrapper and then call everything underneath here, i.e.

node = IterableWrapper(my_iterable_dataset)
node = _Mapper(node, map_fn=ds._prepare_sample)
...

batch = next(dl_iter)
dl_dt = time.perf_counter() - dl_t0
idx += 1
except StopIteration:
Contributor

Why not create a context manager that handles this?

Contributor Author

No good reason, hacky code is hacky :) Just wanted to see some rough numbers
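For illustration, a minimal sketch of such a timing context manager, using only the standard library (the names batch_timer and stats are hypothetical, not from the PR):

import time
from contextlib import contextmanager

@contextmanager
def batch_timer(stats: dict):
    """Accumulate time spent waiting for the next batch into stats["dl_time"]."""
    t0 = time.perf_counter()
    try:
        yield
    finally:
        stats["dl_time"] = stats.get("dl_time", 0.0) + (time.perf_counter() - t0)

# Usage in the training loop:
# stats = {}
# with batch_timer(stats):
#     batch = next(dl_iter)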

# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
Contributor

Have we tried this out on an HF dataset with streaming = True yet? (I assume it won't work out of the box yet?)

Contributor Author

see update :D almost out-of-the-box

Contributor

why do we need to wrap with IterableWrapper if the underlying dataset is IterableDataset? Also in our case, the underlying dataset would be a HF dataset class

Contributor Author

The wrapper makes it conform to BaseNode's API so we can hold a pointer to the iterator, as well as unify the state management. Subclasses need to define .iterator() instead of .__iter__().

Contributor

@ebsmothers ebsmothers left a comment

Thanks for creating this PR! I like the IterableWrapper/Batcher/etc APIs, they look pretty clean (and agree with @RdoubleA's suggestion about exposing this in builders/configs). How will this work when training on multiple datasets? Will we just apply the same set of APIs to each sampler individually?

@andrewkho
Contributor Author

Thanks for the comments y'all, I updated this with a streaming example. Test with:

tune run lora_finetune_single_device --config llama3_2_vision/11B_lora_single_device num_workers=2 pin_memory=False use_torchdata=True parallel_method=thread max_steps_per_epoch=50 dataset.streaming=True

@andrewkho
Contributor Author

@RdoubleA

For a power user, what might they want to tune to optimize performance for their hardware and model setup?

Good question. At a minimum it'd be some global worker setting, but one option that may be worth supporting is allowing power users to define their pipelines entirely in config (not sure if you think this is a bad idea): e.g. by default, use a builder with a few knobs, but also allow the entire dataloader definition to be composable. This is somewhat similar to how you enable users to pass a list of datasets right now, and something similar could be done for mixing. We'd need to be thoughtful about what the useful small atomic units would be, and figure out syntax sugar.
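For illustration, a rough sketch of that "pipeline entirely in config" idea. This is purely hypothetical: cfg_pipeline and build_pipeline_from_config are made-up names, and it assumes torchtune's config.instantiate forwards the extra positional argument to the instantiated node.

from torchtune import config

def build_pipeline_from_config(source_node, cfg_pipeline):
    # Each config entry names one small atomic unit (e.g. a ParallelMapper,
    # PinMemory, or Prefetcher stage) plus its knobs; chain each onto the previous node.
    node = source_node
    for cfg_stage in cfg_pipeline:
        node = config.instantiate(cfg_stage, node)  # assumes the node takes its source first
    return node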

It's also not clear to me how some media transforms/decoding might get optimized, is that just handled by the torchdata nodes automatically?

In terms of optimizing, one thing we're looking at is tf.data's autotune approach, which automatically updates prefetch buffers and worker counts. This is not implemented yet, but we're hoping to do something along these lines, which should help with the too-many-tunable-parameters problem.

@@ -235,3 +243,73 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if packed:
    raise ValueError("Multimodal datasets don't support packing yet.")
return ds


def the_cauldron_dataset_torchdata(
Contributor

How much of this could be pushed into the SFT class so that we can just reuse it for any new datasets?

Contributor Author

A lot of it is probably reusable. I think it should be its own class, or at least a builder, maybe in _sft.py: class SFTDatasetNode or something less terrible-sounding.

Contributor Author

Actually 90% of this builder func could probably live in _sft.py as it doesn't have anything to do with the cauldron

@@ -0,0 +1,1119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor Author

@andrewkho andrewkho Nov 12, 2024

Copied and modified from lora_finetune_distributed

Contributor Author

Maybe rename to _multi_dataset

@@ -0,0 +1,131 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py
Contributor Author

Copied and modified from 11B_lora

@@ -117,7 +128,17 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._data[index]
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:

class PrepareSample:
Contributor Author

Refactoring this into its own callable class that can be used by both the current torch.utils.data.Dataset and the torchdata.nodes version

Contributor

@felipemello1 felipemello1 left a comment

hey Andrew, thank you so much for this PR! This is such a nice feature to have.

I did a first pass. I understand that it's still a draft, but I thought I'd make comments now so it could save you some time when the PR is closer to being ready.

Comment on lines 116 to 117
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)),
Contributor

should we use our utility instead?

def get_world_size_and_rank() -> Tuple[int, int]:

Contributor Author

Needs to be removed

global_streaming=streaming,
global_num_workers=num_workers,
),
prefetch_factor=8,
Contributor

@felipemello1 felipemello1 Nov 13, 2024

would it make sense to add it to setup_data as a default, or expose it in the config? Or is this parameter not commonly changed?

Contributor Author

Updated PR: this is now exposed

Contributor Author

Wait, why isn't this showing the new version? Let me make sure I've pushed


return ds

def _setup_data_td(
Contributor

@felipemello1 felipemello1 Nov 13, 2024

This would require a bit of refactoring, and it's probably not the main point of this PR, but it would be nice to have _setup_dataset and _setup_dataloader as two different methods. It would be easier to read and maintain.

Contributor Author

I've split this into setting up individual datasets (_setup_one_dataset) and the global dataloader/mixer setup.


# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
Contributor

I understand that this was already in the setup_data, but we usually try to fail fast and catch errors like this in the init.

Comment on lines 113 to 114
if load_dataset_kwargs.get("streaming", False):
    self._data = split_dataset_by_node(
Contributor

I think that if we can decouple the dataloader from the dataset, it will be easier to maintain and work with. For example, can we do something like:

MyDatasetDistributed = split_dataset_by_node(MyDataset)

def split_dataset_by_node(dataset, ...):
	assert hasattr(dataset, "_data")

or maybe SFTDataset can have a getter

Contributor Author

This needs to be removed actually

Comment on lines 690 to 697
if len(cfg_datasets) == 1:
    node = next(iter(datasets.values()))  # TODO: multi dataset
else:
    node = MultiDatasetWeightedSampler(
        source_nodes=datasets,
        weights=weights,
    )
Contributor

@felipemello1 felipemello1 Nov 13, 2024

n00b question: I understand that this is the torchdata API, but I wonder whether the len(cfg_datasets) == 1 check should be inside MultiDatasetWeightedSampler, i.e. do we need this if-check here?


log.info("TorchData nodes are initialized")

return node
Contributor

@felipemello1 felipemello1 Nov 13, 2024

'node' feels a bit weird, since we do: self._dataloader = self._setup_dataloader(...)

Should we rename it to dataloader?


# TODO: add multi-dataset mixer
if num_workers == 0:
    _Mapper = Mapper
Contributor

I am a bit torn about whether getting the mapper, sampler, pin memory, etc. should be a utility shared across all recipes, or whether it should be exposed. No strong opinion, just thinking out loud.

if pin_memory:
    node = PinMemory(node)
if num_workers > 0:
    node = Prefetcher(node, 2)
Contributor

is this Prefetcher(node, 2) different from the previous prefetch_factor=8?

Comment on lines 645 to 647
All data related setup happens here. Currently this recipe only supports
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
Contributor

I guess this needs to be updated

prefetch_factor: 2
seed: null

multi_datasets:
Contributor

Personally this is too verbose for me to parse, and even in the recipe there are just too many nested dictionaries. Ideally, I would like to achieve this type of UI in the config for datasets:

datasets:
  - _component_: torchtune.datasets...
    weight: 1.0
    subset: ...
  - _component_: torchtune.datasets...
    weight: 1.0
    ...

or something similar, so all I have to do is specify the dataset I want and its weight. As it is, I have multi_datasets -> datasets -> dataset just to specify the dataset builder. Maybe this is idealistic, but other libraries such as Axolotl are able to do this.

I am aware there's a few challenges to having this:

  1. MultiDatasetSampler requires passing in dictionaries for datasets and weights
  2. weight is not a valid argument for instantiating dataset components

I'm wondering if there's a way we can do this all for the user in a builder. For example:

cfg_datasets: ListConfig  # the list of dataset entries from the config above
weights, datasets = {}, {}
for k, cfg_dataset in enumerate(cfg_datasets):
	weights[k] = cfg_dataset.pop("weight")
	datasets[k] = Prefetcher(config.instantiate(cfg_dataset), prefetch_factor)
dataloader = get_multi_datasets(datasets, weights, cfg_dataloader)

stop_criterion imo should be moved to the dataloader config

)
weights, datasets = {}, {}
cfg_datasets = cfg_multi_datasets.datasets
for k, cfg_and_weight in cfg_datasets.items():
Contributor

All of this logic needs to be in the builder; we do not want to expose torchdata internals in each recipe. I am totally okay with creating a new file in torchtune/data that contains functions that set up torchdata dataloaders and nodes.

Another ideal I'm curious whether we can achieve: can we unify the UX for multi-dataset and single-dataset cases? I.e., if we had a get_dataloader method, you could pass in a single dataset or multiple datasets and the call in the recipe would be the same regardless of what the user specifies in the config.
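As a purely illustrative sketch of that unified entry point (the function name and signature are hypothetical; the node classes are the ones used elsewhere in this PR, and their import path is assumed):

from torchdata.nodes import MultiDatasetWeightedSampler, PinMemory, Prefetcher  # import path assumed

def get_dataloader(datasets: dict, weights: dict, pin_memory: bool = True, prefetch_factor: int = 2):
    # Single- and multi-dataset cases look identical to the recipe: pass a dict of
    # dataset nodes (possibly of size one) plus weights, get back one iterable node.
    if len(datasets) == 1:
        node = next(iter(datasets.values()))
    else:
        node = MultiDatasetWeightedSampler(source_nodes=datasets, weights=weights)
    if pin_memory:
        node = PinMemory(node)
    return Prefetcher(node, prefetch_factor)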

),
)

return node
Contributor

this returns a node, not a dataloader, right? Are users still able to access the underlying Hugging Face data?



@requires_torchdata
def SFTDatasetNode( # noqa[N802]
Contributor

Seems we're moving away from the class-based dataset abstraction toward a function that returns the node configured with user parameters.

Curious whether it would be better UX to just abandon the SFTDataset class (after the full migration) and keep Transform classes for each kind of dataset (Instruct, Chat, MM, SFT, Preference), which would be passed into a generic node builder.

dataset = dataset.shuffle(seed=seed)
node = IterableWrapper(dataset)
else:
sampler = DistributedSampler(
Contributor

I'm confused, which of these nodes are specific to a single dataset vs global for the dataloader?

streaming: bool = False,
shuffle: bool = False,
seed: int = 0,
num_workers: int = 0,
Contributor

Are specific params such as seed, shuffle, num_workers, etc. for each individual dataset a valid use case? My understanding was that you specify these globally at the dataloader level.


# Get global settings
shuffle = cfg_dataloader.shuffle
parallel_method = cfg_dataloader.get("parallel_method", "thread")
Contributor

In general, I'm quite confused about which parameters belong in the "dataset" abstraction and which belong in the "dataloader" abstraction. As it is, it seems you are using these in both. I would prefer to make this distinction very clear, unless I am missing something that needs to be configured per dataset.

Contributor Author

The trouble is that there's no real distinction between dataset and dataloader with torchdata.nodes; they're all just iterators that can be composed together. Currently it's all exposed, but we can introduce restrictions if you think that would be simpler.

@andrewkho force-pushed the andrewkh/torchdata-integration branch from b1b2ab6 to 51d4327 on November 26, 2024 01:37
from typing_extensions import TypeAlias # typing.TypeAlias is only in Python 3.10+


try:
Contributor

these should probably go in torchtune/utils/_import_guard.py

error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
class SFTTransform(Transform):
Contributor

This may be the cleanest way to separate the old way and the torchdata way, but what are your thoughts on keeping the class as SFTDataset but moving all the HF load_dataset logic out? We can keep source as a class attribute but not load it; then, in the recipe or a utility used by the recipe: if we are using torchdata, use the load_hf_dataset you added here; else, use a barebones load utility with the original behavior.

When we eventually migrate to torchdata, we can just remove this switch logic and all the API names will remain the same, but everything is swapped out under the hood.

Contributor Author

I think this is possible, but the tricky part is that right now SFTDataset is a subclass of torch.utils.data.Dataset and the new version is a subclass of torchdata.nodes.BaseNode, and reconciling them may be difficult. We wouldn't be able to just make a switch, because torch.utils.data.DataLoader needs to be swapped out as well for the new get_loader builder wherever it is being called.

Contributor Author

cc @pbontrager for thoughts

@codecov-commenter

Codecov Report

Attention: Patch coverage is 10.28226% with 445 lines in your changes missing coverage. Please review.

Project coverage is 66.62%. Comparing base (f3d8d3c) to head (ef1399a).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
recipes/lora_finetune_distributed_td.py 0.00% 396 Missing ⚠️
torchtune/data/_utils.py 25.49% 38 Missing ⚠️
torchtune/data/_torchdata.py 69.56% 7 Missing ⚠️
torchtune/datasets/_sft.py 86.36% 3 Missing ⚠️
torchtune/datasets/multimodal/_the_cauldron.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #1929       +/-   ##
==========================================
+ Coverage   7.35%   66.62%   +59.26%     
==========================================
  Files        277      325       +48     
  Lines      15422    18444     +3022     
==========================================
+ Hits        1135    12289    +11154     
+ Misses     14287     6155     -8132     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -235,3 +235,17 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if packed:
    raise ValueError("Multimodal datasets don't support packing yet.")
return ds


def the_cauldron_transform(
Contributor

I absolutely love this.

Contributor

Although, we may not need the "column_map" b/c if it's truly just for the cauldron, then we know what the columns will be, no? It would be much clearer IMO to have e.g. image_column, text_column, etc.

num_tokens = 0

loader: Iterable
if self.profile_mode == "model_only":
Contributor

Very noob question but are all these profiling changes purely debug code? Or will we actually need to make changes to how we expose the profiler in order to be able to disentangle time spent on dataloading vs time spent in the model now?

Comment on lines 187 to 205
# Need to lazy import to avoid circular dependency
from torchtune.training._distributed import get_world_size_and_rank
Contributor

A minor point but we should avoid doing this

Contributor Author

Lazy importing? I think the workaround would be to up-level torchtune.training._distributed to eg torchtune.distributed. WDYT?

Contributor

Oh sorry I missed this. Yeah it's literally just an if-else so can always copy-paste. Personally I slightly prefer that to lazy import since it feels hacky to do this from within the same library. But no strong preference here

Contributor

Fyi our hero @joecummings addressed this in #2155. You can now do from torchtune.utils import get_world_size_and_rank at the top of the file


loader = get_dataloader(
    dataset=dataset,
    model_transform=SFTTransform(model_transform=self._tokenizer),
Contributor

What's up with this line? I thought we already pass the transform in load_hf_dataset. And why is it just the tokenizer instead of a multimodal transform?

@@ -10,3 +10,12 @@
_SUPPORTS_FLEX_ATTENTION = (
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)
)


key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "")
assert key not in weights, f"Duplicate dataset name {key}"
weights[key] = float(cfg_dataset.pop("weight"))
datasets[key] = self._setup_one_dataset(
Contributor Author

From call: Let's inline this

Comment on lines 968 to 970
def iter_timer(
iterable: Iterable[T], barrier: bool
) -> Iterator[Tuple[T, float, float, float]]:
Contributor Author

let's drop this entirely

@andrewkho force-pushed the andrewkh/torchdata-integration branch from cacf2b8 to 5ed8382 on December 11, 2024 00:22
@andrewkho changed the title from "[draft] torchdata integration" to "torchdata integration - multi-dataset and streaming support" on Dec 11, 2024
Contributor

@ebsmothers ebsmothers left a comment

Thanks for your patience on the review here! I left a handful of pretty minor comments, but this is looking great. We should also make sure to get CI passing, but after that this looks good to me.

texts_col: str = "texts",
images_col: str = "images",
new_system_prompt: Optional[str] = None,
) -> SFTTransform:
Contributor

Now that we're getting close to landing I am gonna play nit 👑. Let's add a docstring here since it's a public API.


# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: 50
Contributor

Just to confirm: this now becomes mandatory, right?

Contributor Author

Yes that's right, is this an issue for any users?

Contributor Author

Actually, looking at the code, this is only required for the progress bar calculation. Depending on stop_criteria (the default is cycle-until-all-datasets-exhausted), it's non-trivial to know a priori the total number of steps we'll get before StopIteration is thrown, due to randomness. If we don't care about the progress bar, we can drop this requirement.

Contributor

Oh sorry, thought I responded. It's not an issue, but we need to make sure not to break anyone when we move this into our core recipes (since we currently default max_steps_per_epoch=null everywhere). The progress bar I don't care about too much, though we do also need it for the learning rate scheduler, right? But yeah, I agree with your point that we can no longer just infer it in general. Not a blocker for this PR anyway.

log_every_n_steps: 1
log_peak_memory_stats: True

profile_mode: null # dataloader_only | model_only
Contributor

remove?

distributed training and can be run on a single node (1 to 8 GPUs).

Features:
- TorchData. Map and Streaming HuggingFace datasets, and multi-dataset mixing.
Contributor

nit but I feel like you're not selling the features very much here 😛. I think it's OK for now, but eventually we should have a docs page or something we can point to explaining the different features

Contributor Author

Tracking issue: #2156

# Need to lazy import to avoid circular dependency
from torchtune.training._distributed import get_world_size_and_rank

streaming = load_dataset_kwargs.get("streaming", False)
Contributor

Any reason this isn't an explicit argument? Seems important enough to be worth exposing directly

Comment on lines 300 to 301
if packed:
raise ValueError("Multimodal datasets don't support packing yet.")
Contributor

Ultimately this isn't the right place for this check, right? I assume we will need to do it somewhere else (or at least in another way) once we onboard text datasets

Contributor Author

I think you're right. I'm assuming that packing behaviour is going to need to be configurable, so we'll probably need to change this signature entirely anyway. I'll just drop it as an option and catch this in lora_finetune_distributed_multi_dataset for now, wdyt? @ebsmothers

Contributor

Yeah that works. We can revisit when we're migrating text recipes

@joecummings added the triage review label (this issue should be discussed in weekly review) on Dec 13, 2024
@joecummings removed the triage review label on Dec 13, 2024
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
resume_from_checkpoint=self._resume_from_checkpoint,
Contributor

I think you'll need this now that #2006 has landed

Suggested change
- resume_from_checkpoint=self._resume_from_checkpoint,
+ should_load_recipe_state=self._resume_from_checkpoint,

m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
Contributor

From #2139. Sorry about all these merges, it's been a busy week on the repo

Suggested change
- m.lora_a.to_empty(device=lora_device)
- m.lora_b.to_empty(device=lora_device)
+ m.to_empty(device=lora_device)

Comment on lines 506 to 512
is_dora = False
for m in model.modules():
    if hasattr(m, "initialize_dora_magnitude"):
        is_dora = True
        m.initialize_dora_magnitude()
if is_dora:
    load_dora_magnitudes(model)
Contributor

One more

Suggested change
- is_dora = False
- for m in model.modules():
-     if hasattr(m, "initialize_dora_magnitude"):
-         is_dora = True
-         m.initialize_dora_magnitude()
- if is_dora:
-     load_dora_magnitudes(model)
+ for m in model.modules():
+     if hasattr(m, "initialize_dora_magnitude"):
+         m.initialize_dora_magnitude()

"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
Contributor

Can do this now

Suggested change
- _, rank = training.get_world_size_and_rank()
+ _, rank = utils.get_world_size_and_rank()


# TODO: Remove lazy import when we can
# see: https://github.com/pytorch/torchtune/issues/2151
from torchtune.training._distributed import get_world_size_and_rank
Contributor

Realized the comment I responded to was outdated, so just in case you miss it: can now use from torchtune.utils import get_world_size_and_rank at the top of the file

Contributor

@ebsmothers ebsmothers left a comment

OK a few more small merges to do, after that I think we're good. Thanks so much for building out this integration! Having proper multi-dataset and streaming support (not to mention the potential for more performant multimodal dataloading) is gonna be huge for our users.

@andrewkho force-pushed the andrewkh/torchdata-integration branch from 6f64f52 to 139d7a7 on December 13, 2024 23:35
@ebsmothers ebsmothers merged commit 9dae7f1 into pytorch:main Dec 16, 2024
17 checks passed
@andrewkho andrewkho deleted the andrewkh/torchdata-integration branch December 16, 2024 23:03
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

* guard ckpt imports (pytorch#2133)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] add parents=True (pytorch#2136)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] re-add model (pytorch#2135)

Co-authored-by: Felipe Mello <[email protected]>

* Update save sizes into GiB (pytorch#2143)

* [bug fix] remove config download when source is kaggle (pytorch#2144)

Co-authored-by: Felipe Mello <[email protected]>

* [fix] remove "with_suffix" (pytorch#2146)

Co-authored-by: Felipe Mello <[email protected]>

* DoRA fixes (pytorch#2139)



Co-authored-by: Mircea Mironenco <[email protected]>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)

* Small readme, config updates (pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <[email protected]>

* torchdata integration - multi-dataset and streaming support (pytorch#1929)

* Allow higher version of lm-eval (pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
Co-authored-by: Mircea Mironenco <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Andrew Ho <[email protected]>
Co-authored-by: Eugen Hotaj <[email protected]>