torchdata integration - multi-dataset and streaming support #1929
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1929
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 active SEV. If your PR is affected, please view it below.
✅ No failures as of commit 139d7a7 with merge base c2c6f4a.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for this prototype, this is helpful to see! The way I'd imagine we'd expose the torchdata dataloader would be from a builder function with a few knobs exposed with reasonable defaults:
def build_dataloader(ds, num_workers, pin_memory, prefetch, in_memory, parallel_method, ...):
    # Put together all the nodes here

# In config
dataloader:
  num_workers: ...
  ...
For a power user, what might they want to tune to optimize performance for their hardware and model setup?
It's also not clear to me how some media transforms/decoding might get optimized, is that just handled by the torchdata nodes automatically?
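For reference, a rough sketch of what such a knob-driven builder could look like, composed from the same torchdata.nodes primitives this prototype already uses (IterableWrapper, Mapper/ParallelMapper, Batcher, PinMemory, Prefetcher). The name build_dataloader, the exact signatures, and the assumption that ds is an iterable of raw samples exposing a _prepare_sample transform are all illustrative, not the final API:

```python
from functools import partial

from torchdata.nodes import (
    Batcher,
    IterableWrapper,
    Mapper,
    ParallelMapper,
    PinMemory,
    Prefetcher,
)


def build_dataloader(
    ds,
    collate_fn,
    batch_size: int = 4,
    num_workers: int = 0,
    pin_memory: bool = False,
    prefetch_factor: int = 2,
    parallel_method: str = "thread",
):
    # Pick a serial or parallel mapper, mirroring the switch used in this PR.
    if num_workers == 0:
        _mapper = Mapper
    else:
        _mapper = partial(
            ParallelMapper, num_workers=num_workers, method=parallel_method
        )

    # Wrap the dataset so it conforms to the BaseNode iterator API, then chain
    # transform -> batch -> collate -> pin -> prefetch.
    node = IterableWrapper(ds)
    node = _mapper(node, map_fn=ds._prepare_sample)
    node = Batcher(node, batch_size=batch_size)
    node = _mapper(node, map_fn=collate_fn)
    if pin_memory:
        node = PinMemory(node)
    if num_workers > 0:
        node = Prefetcher(node, prefetch_factor)
    return node
```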
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
node = _Mapper(node, map_fn=ds._prepare_sample)
is this where the transform would get parallelized?
Yes, see
_Mapper = partial(
    ParallelMapper,
    num_workers=num_workers,
    method=parallel_method,
    in_order=True,
)
)
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
So does this mean we can keep our own Dataset abstractions? What if we went entirely with Iterable datasets?
Technically yes, but this was just a way to get it working quickly; we can also refactor this into a separate function. For an IterableDataset, you could wrap it in IterableWrapper and then call everything underneath here, i.e.
node = IterableWrapper(my_iterable_dataset)
node = _Mapper(node, map_fn=ds._prepare_sample)
...
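For instance, a streaming HF dataset could slot into the same pipeline like this (a sketch building on the snippet above; `_Mapper`, `ds`, `rank`, and `world_size` are assumed to come from the surrounding setup, and the dataset/subset names are only examples):

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torchdata.nodes import IterableWrapper

# Stream the HF dataset and shard it across ranks before wrapping it.
hf_ds = load_dataset(
    "HuggingFaceM4/the_cauldron", "ocrvqa", split="train", streaming=True
)
hf_ds = split_dataset_by_node(hf_ds, rank=rank, world_size=world_size)

node = IterableWrapper(hf_ds)
node = _Mapper(node, map_fn=ds._prepare_sample)
# ...then the same Batcher / PinMemory / Prefetcher stages as the map-style path
```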
    batch = next(dl_iter)
    dl_dt = time.perf_counter() - dl_t0
    idx += 1
except StopIteration:
Why not create a context manager that handles this?
No good reason, hacky code is hacky :) Just wanted to see some rough numbers
# Map style set up =======
node = IterableWrapper(sampler)
node = _Mapper(node, map_fn=ds._data.__getitem__)
# Cut here for Streaming/Iterable dataset instead =====
Have we tried this out on an HF dataset with streaming = True yet? (I assume it won't work out of the box yet?)
see update :D almost out-of-the-box
Why do we need to wrap with IterableWrapper if the underlying dataset is an IterableDataset? Also, in our case the underlying dataset would be an HF dataset class.
The wrapper makes it conform to BaseNode's API so we can hold a pointer to the iterator, as well as unify the state management. Subclasses need to define .iterator() instead of .__iter__().
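Roughly, conforming to that API looks like this (a sketch based on the description above; the torchdata.nodes.BaseNode interface was in flux at the time and may expose additional state-management hooks):

```python
from typing import Any, Iterator

from torchdata.nodes import BaseNode


class SquareNode(BaseNode):
    """Toy node: yields the square of every item produced by its source."""

    def __init__(self, source: BaseNode):
        super().__init__()
        self.source = source

    # Subclasses implement iterator() rather than __iter__(); BaseNode keeps a
    # handle to the resulting iterator so its state can be tracked and unified.
    def iterator(self) -> Iterator[Any]:
        for item in self.source:
            yield item * item
```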
Thanks for creating this PR! I like the IterableWrapper/Batcher/etc APIs, they look pretty clean (and agree with @RdoubleA's suggestion about exposing this in builders/configs). How will this work when training on multiple datasets? Will we just apply the same set of APIs to each sampler individually?
Thanks for the comments y'all, I updated this with a streaming example. Test with:
Good question. At a minimum it'd be some global worker setting, but one option that may be worth supporting is allowing power users to define their entire pipelines in config (not sure if you think this is a bad idea). E.g. by default use a builder with a few knobs, but also allow the entire dataloader definition to be composable, somewhat similar to how you enable users to pass a list of datasets right now. Something similar could be done for mixing. We'd need to be thoughtful about what the useful small atomic units would be, and figure out the syntax sugar.
In terms of optimizing, one thing we're looking at is tf.data's autotune approach, which automatically updates prefetch buffers and workers. This is not implemented yet, but we're hoping to do something along these lines, which should help with the too-many-tunable-parameters problem.
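A very rough illustration of the "fully composable in config" idea (not implemented in this PR; the _component_ entries and the positional use of config.instantiate are assumptions):

```python
from torchdata.nodes import IterableWrapper
from torchtune import config

# In config (sketch):
# dataloader:
#   - _component_: torchdata.nodes.ParallelMapper
#     num_workers: 4
#   - _component_: torchdata.nodes.Batcher
#     batch_size: 8
#   - _component_: torchdata.nodes.Prefetcher
#     prefetch_factor: 2


def build_from_config(cfg_dataloader, dataset):
    node = IterableWrapper(dataset)
    for cfg_node in cfg_dataloader:
        # Each configured node wraps the previous one as its source.
        node = config.instantiate(cfg_node, node)
    return node
```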
@@ -235,3 +243,73 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
    if packed:
        raise ValueError("Multimodal datasets don't support packing yet.")
    return ds
def the_cauldron_dataset_torchdata( |
How much of this could be pushed into the SFT class so that we can just reuse it for any new datasets?
A lot of it is probably re-usable. I think it should be its own class or at least a builder, maybe something like _sft.py: class SFTDatasetNode, or something less terrible sounding.
Actually 90% of this builder func could probably live in _sft.py as it doesn't have anything to do with the cauldron
@@ -0,0 +1,1119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copied and modified from lora_finetune_distributed
Maybe rename to _multi_dataset
@@ -0,0 +1,131 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed_td.py
Copied and modified from 11B_lora
torchtune/datasets/_sft.py
Outdated
@@ -117,7 +128,17 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
    sample = self._data[index]
    return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:

class PrepareSample:
Refactoring this into its own Callable class that can be used by both the current torch.utils.data.Dataset and the torchdata.nodes version.
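Sketch of the idea (the actual PrepareSample in this diff may differ), so the map-style __getitem__ path and a torchdata Mapper node can share the same per-sample logic:

```python
from typing import Any, Callable, Dict, Mapping


class PrepareSample:
    """Per-sample transform: build messages, then apply the model transform."""

    def __init__(self, message_transform: Callable, model_transform: Callable):
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __call__(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        transformed = self._message_transform(sample)
        return self._model_transform(transformed)


# Map-style:  SFTDataset.__getitem__ -> self._prepare_sample(self._data[index])
# torchdata:  node = _Mapper(node, map_fn=prepare_sample)
```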
hey Andrew, thank you so much for this PR! This is such a nice feature to have.
I did a first pass. I understand that it's still a draft, but I thought I'd make comments so it could save you some time for when the PR is closer to being ready.
torchtune/datasets/_sft.py
Outdated
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)),
should we use our utility instead?
torchtune/torchtune/training/_distributed.py
Line 150 in 1eb7785
def get_world_size_and_rank() -> Tuple[int, int]:
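For reference, the suggestion would look something like this (illustrative only):

```python
from datasets.distributed import split_dataset_by_node
from torchtune import training

# Use the shared utility instead of reading RANK/WORLD_SIZE env vars directly.
world_size, rank = training.get_world_size_and_rank()
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
```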
Needs to be removed
    global_streaming=streaming,
    global_num_workers=num_workers,
),
prefetch_factor=8,
would it make sense to add it to setup_data as a default or exposed in the config? Or is this parameter not commonly changed?
Updated PR: this is now exposed
Wait, why isn't this showing the new version? Let me make sure I've pushed
return ds

def _setup_data_td(
This would require a bit of refactoring, and it's probably not the main point of this PR, but it would be nice to have _setup_dataset and _setup_dataloader as two different methods. It should be easier to read and maintain.
I've split this into setting up individual datasets (_setup_one_dataset) and the global dataloader/mixer set up.
# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
    raise RuntimeError("left_pad_sequence collator is only for inference.")
I understand that this was already in the setup_data, but we usually try to fail fast and catch errors like this in the init.
torchtune/datasets/_sft.py
Outdated
if load_dataset_kwargs.get("streaming", False):
    self._data = split_dataset_by_node(
I think that if we can decouple the dataloader from the dataset, it will be easier to maintain and work with. For example, can we do something like:

MyDatasetDistributed = split_dataset_by_node(MyDataset)

def split_dataset_by_node(...):
    assert hasattr(dataset, "_data")

or maybe SFTDataset can have a getter?
This needs to be removed actually
if len(cfg_datasets) == 1:
    node = next(iter(datasets.values()))  # TODO: multi dataset
else:
    node = MultiDatasetWeightedSampler(
        source_nodes=datasets,
        weights=weights,
    )
n00b question: I understand that this is the torchdata API, but I wonder whether the 'len(cfg_datasets) == 1' check should live inside MultiDatasetWeightedSampler, i.e. do we need this if-check here?
log.info("TorchData nodes are initialized")

return node
'node' feels a bit weird, since we do: self._dataloader = self._setup_dataloader(...)
Should we rename it to dataloader?
# TODO: add multi-dataset mixer
if num_workers == 0:
    _Mapper = Mapper
I am a bit torn about whether getting the mapper, sampler, pin memory, etc. should be a utility shared across all recipes, or whether it should be exposed. No strong opinion, just thinking out loud.
if pin_memory:
    node = PinMemory(node)
if num_workers > 0:
    node = Prefetcher(node, 2)
Is this Prefetcher(node, 2) different from the previous prefetch_factor=8?
All data related setup happens here. Currently this recipe only supports
Map-style Datasets which fit into memory and an option for random shuffling.
Samplers, iterable datasets, and streaming datasets are not supported.
I guess this needs to be updated.
prefetch_factor: 2
seed: null

multi_datasets:
Personally this is too verbose for me to parse, and even in the recipe there are just too many nested dictionaries. Ideally, I would like to achieve this type of UI in the config for datasets:

datasets:
  - _component_: torchtune.datasets...
    weight: 1.0
    subset: ...
  - _component_: torchtune.datasets...
    weight: 1.0
  ...

or something similar, so all I have to do is specify the dataset I want and the weight. As it is, I have multi_datasets -> datasets -> dataset just to specify the dataset builder. Maybe this is idealistic, but other libraries such as Axolotl are able to do this.
I am aware there are a few challenges to having this:
- MultiDatasetSampler requires passing in dictionaries for datasets and weights
- weight is not a valid argument for instantiating dataset components
I'm wondering if there's a way we can do this all for the user in a builder. For example:
datasets: ListConfig
for cfg_dataset in datasets:
    weights[k] = cfg_dataset.pop("weight")
    dataset[k] = Prefetcher(config.instantiate(cfg_dataset), prefetch_factor)
dataloader = get_multi_datasets(datasets, weights, cfg_dataloader)
stop_criterion imo should be moved to the dataloader config
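One possible shape for such a builder (get_multi_dataloader is a hypothetical name; MultiDatasetWeightedSampler is the mixing node introduced in this PR, and the exact knobs are assumptions):

```python
from torchtune import config

# MultiDatasetWeightedSampler: the weighted mixing node added in this PR
# (its import path is assumed here).


def get_multi_dataloader(cfg_datasets, cfg_dataloader):
    weights, datasets = {}, {}
    for idx, cfg_dataset in enumerate(cfg_datasets):
        # Pop config-only keys so the remainder instantiates cleanly.
        weights[f"ds_{idx}"] = float(cfg_dataset.pop("weight", 1.0))
        datasets[f"ds_{idx}"] = config.instantiate(cfg_dataset)

    if len(datasets) == 1:
        node = next(iter(datasets.values()))
    else:
        node = MultiDatasetWeightedSampler(source_nodes=datasets, weights=weights)

    # ...then apply the shared Batcher/PinMemory/Prefetcher stages driven by
    # cfg_dataloader (batch_size, num_workers, pin_memory, prefetch_factor, ...).
    return node
```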
)
weights, datasets = {}, {}
cfg_datasets = cfg_multi_datasets.datasets
for k, cfg_and_weight in cfg_datasets.items():
All of this logic needs to be in the builder; we do not want to expose torchdata internals in each recipe. I am totally okay with creating a new file in torchtune/data that contains functions that set up torchdata dataloaders and nodes.
Another ideal I'm curious whether we can achieve: can we unify the UX for multi datasets and single datasets? I.e., if we had a get_dataloader method, you could pass in a single dataset or a multi dataset and the call in the recipe would be the same regardless of what the user specifies in the config.
torchtune/datasets/_sft.py
Outdated
),
)

return node
This returns a node, not a dataloader, right? Are users still able to access the underlying Hugging Face data?
torchtune/datasets/_sft.py
Outdated
@requires_torchdata
def SFTDatasetNode(  # noqa[N802]
Seems like we're moving away from the class-based dataset abstraction toward a function that returns the node configured with user parameters.
Curious whether it would be better UX to just abandon the SFTDataset class (after the full migration) and keep Transform classes for each kind of dataset (Instruct, Chat, MM, SFT, Preference), which are passed into a generic node builder.
torchtune/datasets/_sft.py
Outdated
    dataset = dataset.shuffle(seed=seed)
    node = IterableWrapper(dataset)
else:
    sampler = DistributedSampler(
I'm confused, which of these nodes are specific to a single dataset vs global for the dataloader?
streaming: bool = False,
shuffle: bool = False,
seed: int = 0,
num_workers: int = 0,
Are specific params such as seed, shuffle, num_workers, etc. for each individual dataset a valid use case? My understanding was that you specify these globally at the dataloader level.
# Get global settings
shuffle = cfg_dataloader.shuffle
parallel_method = cfg_dataloader.get("parallel_method", "thread")
In general, I'm quite confused about which parameters belong in the "dataset" abstraction and which belong in the "dataloader" abstraction. As it is, it seems you are using these in both. I would prefer to make this distinction very clear, unless I am missing something you may need to configure per dataset.
The trouble is that there's no real distinction between datasets and dataloaders with torchdata.nodes; they're all just iterators that can be composed together. Currently everything is exposed, but we can introduce restrictions if you think that would be simpler.
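For example (illustrative; hf_dataset and prepare_sample are placeholders), every stage is just another node, so the dataset/dataloader boundary is wherever we choose to draw it:

```python
from torchdata.nodes import Batcher, IterableWrapper, Mapper, Prefetcher

node = IterableWrapper(hf_dataset)          # the "dataset"
node = Mapper(node, map_fn=prepare_sample)  # per-sample transform
node = Batcher(node, batch_size=4)          # "dataloader"-ish stages...
node = Prefetcher(node, 2)                  # ...same abstraction all the way down
```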
Force-pushed from b1b2ab6 to 51d4327
torchtune/data/_torchdata.py
Outdated
from typing_extensions import TypeAlias  # typing.TypeAlias is only in Python 3.10+

try:
these should probably go in torchtune/utils/_import_guard.py
error_message = (
    "model_transform returned the following keys: "
    f"{keys_str}. Must return 'tokens' and 'mask' as keys."
class SFTTransform(Transform):
This may be the cleanest way to separate the old way and the torchdata way, but what are your thoughts on keeping the class as SFTDataset and moving all the HF load-dataset logic out? We can keep source as a class attribute but not load it. Then, in the recipe (or a utility used in the recipe): if we are using torchdata, use the load_hf_dataset you added here; else, use a barebones load utility with the original behavior.
When we eventually migrate to torchdata, we can just remove this switch logic and all the API names will remain the same, with everything swapped out under the hood.
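A hedged sketch of that switch (helper and argument names are illustrative; the real load_hf_dataset and SFTDataset signatures may differ):

```python
from torchtune.data import load_hf_dataset  # added in this PR (import path assumed)
from torchtune.datasets import SFTDataset


def build_sft_data(
    source: str,
    message_transform,
    model_transform,
    use_torchdata: bool,
    **load_kwargs,
):
    if use_torchdata:
        # New path: a torchdata.nodes pipeline wrapping the (possibly streaming) HF dataset.
        return load_hf_dataset(
            source=source,
            message_transform=message_transform,
            model_transform=model_transform,
            **load_kwargs,
        )
    # Original path: a plain map-style torch.utils.data.Dataset.
    return SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=model_transform,
        **load_kwargs,
    )
```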
I think this is possible, but the tricky part is that right now SFTDataset is a subclass of torch.utils.data.Dataset and the new version is a subclass of torchdata.nodes.BaseNode, and reconciling them may be difficult. We wouldn't be able to just flip a switch, because the torch.utils.data.DataLoader needs to be swapped out as well for the new get_loader builder wherever that is being called.
cc @pbontrager for thoughts
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1929      +/-   ##
==========================================
+ Coverage    7.35%   66.62%   +59.26%
==========================================
  Files         277      325       +48
  Lines       15422    18444     +3022
==========================================
+ Hits         1135    12289    +11154
+ Misses      14287     6155     -8132

☔ View full report in Codecov by Sentry.
@@ -235,3 +235,17 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
    if packed:
        raise ValueError("Multimodal datasets don't support packing yet.")
    return ds
def the_cauldron_transform( |
I absolutely love this.
Although, we may not need the "column_map" b/c if it's truly just for the cauldron, then we know what the columns will be, no? It would be much clearer IMO to have e.g. image_column, text_column, etc.
num_tokens = 0

loader: Iterable
if self.profile_mode == "model_only":
Very noob question but are all these profiling changes purely debug code? Or will we actually need to make changes to how we expose the profiler in order to be able to disentangle time spent on dataloading vs time spent in the model now?
torchtune/data/_utils.py
Outdated
# Need to lazy import to avoid circular dependency
from torchtune.training._distributed import get_world_size_and_rank
A minor point but we should avoid doing this
Lazy importing? I think the workaround would be to up-level torchtune.training._distributed to e.g. torchtune.distributed. WDYT?
Oh sorry I missed this. Yeah it's literally just an if-else so can always copy-paste. Personally I slightly prefer that to lazy import since it feels hacky to do this from within the same library. But no strong preference here
FYI, our hero @joecummings addressed this in #2155. You can now do from torchtune.utils import get_world_size_and_rank at the top of the file.
loader = get_dataloader(
    dataset=dataset,
    model_transform=SFTTransform(model_transform=self._tokenizer),
What's up with this line? I thought we already pass the transform in load_hf_dataset. And why is it just the tokenizer instead of a multimodal transform?
@@ -10,3 +10,12 @@
_SUPPORTS_FLEX_ATTENTION = (
    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)
)
nit: Ed says we shouldn't do this anymore :/
https://fb.workplace.com/groups/pytorch.dev/permalink/1677770972801376/
key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "")
assert key not in weights, f"Duplicate dataset name {key}"
weights[key] = float(cfg_dataset.pop("weight"))
datasets[key] = self._setup_one_dataset(
From call: Let's inline this
def iter_timer(
    iterable: Iterable[T], barrier: bool
) -> Iterator[Tuple[T, float, float, float]]:
let's drop this entirely
Force-pushed from cacf2b8 to 5ed8382
Thanks for your patience on the review here! I left a handful of pretty minor comments, but this is looking great. We should also make sure to get CI passing, but after that this looks good to me.
texts_col: str = "texts",
images_col: str = "images",
new_system_prompt: Optional[str] = None,
) -> SFTTransform:
Now that we're getting close to landing I'm gonna play nit 👑. Let's add a docstring here since it's a public API.
# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: 50
Just to confirm: this now becomes mandatory, right?
Yes that's right, is this an issue for any users?
Actually, looking at the code, this is only required for the progress bar calculation. Depending on stop_criteria (the default is cycle-until-all-datasets-exhausted), it's non-trivial to know a priori the total number of steps we'll get before StopIteration is thrown, due to randomness. If we don't care about the progress bar we can drop this requirement.
Oh sorry, I thought I responded. It's not an issue, but we need to make sure not to break anyone when we move this into our core recipes (since we currently default max_steps_per_epoch=null everywhere). The progress bar I don't care about too much, though we do also need it for the learning rate scheduler, right? But yeah, I agree with your point that we can no longer just infer it in general. Not a blocker for this PR anyways.
log_every_n_steps: 1
log_peak_memory_stats: True

profile_mode: null # dataloader_only | model_only
remove?
distributed training and can be run on a single node (1 to 8 GPUs).

Features:
    - TorchData. Map and Streaming HuggingFace datasets, and multi-dataset mixing.
nit but I feel like you're not selling the features very much here 😛. I think it's OK for now, but eventually we should have a docs page or something we can point to explaining the different features
Opened #2156 to track this.
torchtune/data/_utils.py
Outdated
# Need to lazy import to avoid circular dependency
from torchtune.training._distributed import get_world_size_and_rank

streaming = load_dataset_kwargs.get("streaming", False)
Any reason this isn't an explicit argument? Seems important enough to be worth exposing directly
torchtune/data/_utils.py
Outdated
if packed:
    raise ValueError("Multimodal datasets don't support packing yet.")
Ultimately this isn't the right place for this check, right? I assume we will need to do it somewhere else (or at least in another way) once we onboard text datasets
I think you're right. I'm assuming that packer behaviour is going to need to be configurable, so we'll probably need to change this signature entirely anyways. I'll just drop it as an option and catch this in lora_finetune_distributed_multi_dataset for now, wdyt? @ebsmothers
Yeah that works. We can revisit when we're migrating text recipes
""" | ||
self._checkpointer = config.instantiate( | ||
cfg_checkpointer, | ||
resume_from_checkpoint=self._resume_from_checkpoint, |
I think you'll need this now that #2006 has landed
- resume_from_checkpoint=self._resume_from_checkpoint,
+ should_load_recipe_state=self._resume_from_checkpoint,
m.lora_a.to_empty(device=lora_device)
m.lora_b.to_empty(device=lora_device)
From #2139. Sorry about all these merges, it's been a busy week on the repo
- m.lora_a.to_empty(device=lora_device)
- m.lora_b.to_empty(device=lora_device)
+ m.to_empty(device=lora_device)
is_dora = False
for m in model.modules():
    if hasattr(m, "initialize_dora_magnitude"):
        is_dora = True
        m.initialize_dora_magnitude()
if is_dora:
    load_dora_magnitudes(model)
One more
- is_dora = False
- for m in model.modules():
-     if hasattr(m, "initialize_dora_magnitude"):
-         is_dora = True
-         m.initialize_dora_magnitude()
- if is_dora:
-     load_dora_magnitudes(model)
+ for m in model.modules():
+     if hasattr(m, "initialize_dora_magnitude"):
+         m.initialize_dora_magnitude()
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." | ||
) | ||
|
||
_, rank = training.get_world_size_and_rank() |
Can do this now
- _, rank = training.get_world_size_and_rank()
+ _, rank = utils.get_world_size_and_rank()
torchtune/data/_utils.py
Outdated
# TODO: Remove lazy import when we can
# see: https://github.com/pytorch/torchtune/issues/2151
from torchtune.training._distributed import get_world_size_and_rank
Realized the comment I responded to was outdated, so just in case you miss it: you can now use from torchtune.utils import get_world_size_and_rank at the top of the file.
OK a few more small merges to do, after that I think we're good. Thanks so much for building out this integration! Having proper multi-dataset and streaming support (not to mention the potential for more performant multimodal dataloading) is gonna be huge for our users.
Force-pushed from 6f64f52 to 139d7a7
* Llama 3.3 70B (pytorch#2124)
* Llama 3.3 readme updates (pytorch#2125)
* update configs (pytorch#2107) Co-authored-by: Felipe Mello <[email protected]>
* Reduce logging output for distributed KD (pytorch#2120)
* Support Early Exit Loss and/or Layer Dropout (pytorch#1076) Co-authored-by: ebsmothers <[email protected]>
* Update checkpointing directory (pytorch#2074) Co-authored-by: Felipe Mello <[email protected]> Co-authored-by: vancoyendall <[email protected]>
* pass correct arg (pytorch#2127) Co-authored-by: Felipe Mello <[email protected]>
* update configs (pytorch#2128) Co-authored-by: Felipe Mello <[email protected]>
* fix qat_lora_test (pytorch#2131) Co-authored-by: Felipe Mello <[email protected]>
* guard ckpt imports (pytorch#2133) Co-authored-by: Felipe Mello <[email protected]>
* [bug fix] add parents=True (pytorch#2136) Co-authored-by: Felipe Mello <[email protected]>
* [bug fix] re-add model (pytorch#2135) Co-authored-by: Felipe Mello <[email protected]>
* Update save sizes into GiB (pytorch#2143)
* [bug fix] remove config download when source is kaggle (pytorch#2144) Co-authored-by: Felipe Mello <[email protected]>
* [fix] remove "with_suffix" (pytorch#2146) Co-authored-by: Felipe Mello <[email protected]>
* DoRA fixes (pytorch#2139) Co-authored-by: Mircea Mironenco <[email protected]>
* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)
* Small readme, config updates (pytorch#2157)
* Using `FormattedCheckpointFiles` in configs (pytorch#2147)
* Move ``get_world_size_and_rank`` to utils (pytorch#2155)
* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006) Co-authored-by: Saurabh Mishra <[email protected]>
* torchdata integration - multi-dataset and streaming support (pytorch#1929)
* Allow higher version of lm-eval (pytorch#2165)
* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)
* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)
---------
Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
Co-authored-by: Mircea Mironenco <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Andrew Ho <[email protected]>
Co-authored-by: Eugen Hotaj <[email protected]>
Note! This requires torchdata nightly to be installed to work correctly.
Test multi-dataset training command:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 compile=True dataloader.num_workers=4
Test multi-dataset command with dataloader_only mode:
tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td dataloader.pin_memory=True use_torchdata=True dataloader.parallel_method=thread max_steps_per_epoch=1000 profile_mode=dataloader_only compile=True dataloader.num_workers=4
Benchmarking on 8xA100 for 200 steps, full model training, batch_size: 4, gradient_accumulation_steps: 1.
TWFB: Sum of Time Waiting For Batch (max across all ranks) divided by sum of step times
Single Datasets are run with OCRVQA
Multi-Dataset was run with:
ocrvqa, docvqa, dvqa, tabmwp with equal weighting for all datasets.
Multi-dataset runs much slower; my guess is that one of the datasets (dvqa?) requires more padding than ocrvqa.
Launch commands:
0: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=false max_steps_per_epoch=200 compile=True
1: (same as 2 but with 3 of the datasets removed in the config)
2: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True
3: tune run --nnodes=1 --nproc_per_node=8 lora_finetune_distributed_td --config llama3_2_vision/11B_lora_td use_torchdata=true max_steps_per_epoch=200 compile=True dataloader.streaming=true
Please have a look at the code setup and how this composability can help with streaming datasets and multi-dataset mixing. You can think of this as an approach to replace torch.utils.data.DataLoader while introducing more flexible parallelism schemes; e.g., instead of just one multiprocess worker setting, you could do multi-threading, pipeline parallelism, etc. It also enables more powerful composability IMO.
I have done some single-device benchmarking on a machine with an A100 40GB, both with and without the model; performance is on par with or better than the standard dataloader.
TODO: fill in below
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
*
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.