
Commit

refactor
andrewkho committed Nov 26, 2024
1 parent a803d72 commit 51d4327
Showing 7 changed files with 420 additions and 248 deletions.
52 changes: 27 additions & 25 deletions recipes/configs/llama3_2_vision/11B_lora_td.yaml
@@ -49,7 +49,7 @@ checkpointer:
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.

# TorchData Dataset setup
# TorchData setup
dataloader:
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
@@ -60,29 +60,31 @@ dataloader:
prefetch_factor: 2
seed: null

multi_datasets:
stop_criterion: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
datasets:
ocrvqa:
weight: 1.0
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata
subset: ocrvqa
dvqa:
weight: 1.0
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata
subset: dvqa
docvqa:
weight: 1.0
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata
subset: docvqa
tabmwp:
weight: 1.0
dataset:
_component_: torchtune.datasets.multimodal.the_cauldron_dataset_torchdata
subset: tabmwp
datasets:
- source: HuggingFaceM4/the_cauldron
subset: ocrvqa
split: train
transform:
_component_: torchtune.datasets.multimodal.the_cauldron_transform
weight: 1.0
- source: HuggingFaceM4/the_cauldron
subset: dvqa
split: train
transform:
_component_: torchtune.datasets.multimodal.the_cauldron_transform
weight: 1.0
- source: HuggingFaceM4/the_cauldron
subset: docvqa
split: train
transform:
_component_: torchtune.datasets.multimodal.the_cauldron_transform
weight: 1.0
- source: HuggingFaceM4/the_cauldron
subset: tabmwp
split: train
transform:
_component_: torchtune.datasets.multimodal.the_cauldron_transform
weight: 1.0

# torch.utils.data.DataLoader Dataset setup, single dataset only
classic_dataloader:
@@ -98,7 +100,7 @@ use_torchdata: true

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
max_steps_per_epoch: 50
batch_size: 4
gradient_accumulation_steps: 1
optimizer:
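For reference, each entry in the new datasets: list above is instantiated on its own. Below is a minimal sketch of how one entry might translate into a load_hf_dataset call; load_hf_dataset and the_cauldron_transform are named in this diff, while the zero-argument transform call and the literal keyword values are assumptions, not part of the commit.

# Sketch only: building one entry of the datasets: list by hand.
# load_hf_dataset and the_cauldron_transform come from this diff; treating the
# transform component as a zero-argument factory is an assumption.
from torchtune.data._utils import load_hf_dataset
from torchtune.datasets.multimodal import the_cauldron_transform

ocrvqa = load_hf_dataset(
    source="HuggingFaceM4/the_cauldron",
    subset="ocrvqa",
    split="train",
    transform=the_cauldron_transform(),
    streaming=False,           # per-dataset overrides fall back to dataloader-level defaults
    shuffle=True,
    parallel_method="thread",
    num_workers=0,
)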
160 changes: 77 additions & 83 deletions recipes/lora_finetune_distributed_td.py
@@ -21,22 +21,15 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchdata.nodes import (
BaseNode,
Batcher,
Loader,
Mapper,
MultiDatasetWeightedSampler,
ParallelMapper,
PinMemory,
Prefetcher,
T,
)
from torchdata.nodes.samplers.multi_dataset_weighted_sampler import StopCriteria
from torchdata.nodes import Loader, T
from torchdata.nodes.samplers.stop_criteria import StopCriteria
from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.data._torchdata import DatasetType
from torchtune.data._utils import get_dataloader, get_multi_dataset, load_hf_dataset
from torchtune.datasets import ConcatDataset
from torchtune.datasets._sft import SFTTransform
from torchtune.modules.peft import (
DoRALinear,
get_adapter_params,
@@ -320,7 +313,7 @@ def setup(self, cfg: DictConfig) -> None:
if cfg.get("use_torchdata", True):
self._dataloader = self._setup_data_td(
cfg_dataloader=cfg.dataloader,
cfg_multi_datasets=cfg.multi_datasets,
cfg_datasets=cfg.datasets,
batch_size=cfg.batch_size,
)
else:
@@ -618,40 +611,41 @@ def _setup_lr_scheduler(
def _setup_one_dataset(
self,
cfg_dataset: DictConfig,
global_streaming: bool,
global_shuffle: bool,
global_parallel_method: str,
global_streaming: bool,
global_num_workers: int,
) -> BaseNode:
) -> DatasetType:
streaming = cfg_dataset.pop("streaming", global_streaming)
parallel_method = cfg_dataset.pop("parallel_method", global_parallel_method)
shuffle = cfg_dataset.pop("shuffle", global_shuffle)
parallel_method = cfg_dataset.pop("parallel_method", global_parallel_method)
num_workers = cfg_dataset.pop("num_workers", global_num_workers)

return config.instantiate(
cfg_dataset,
self._tokenizer,
# Instantiate dataset transform
assert "transform" in cfg_dataset, "transform must be specified in dataset"
transform = config.instantiate(cfg_dataset.pop("transform"))

log.info(f"Instantiating dataset {cfg_dataset}")
return load_hf_dataset(
**cfg_dataset,
transform=transform,
streaming=streaming,
shuffle=shuffle,
parallel_method=parallel_method,
num_workers=num_workers,
)

return ds

def _setup_data_td(
self,
cfg_dataloader: DictConfig,
cfg_multi_datasets: DictConfig,
cfg_datasets: ListConfig,
batch_size: int,
) -> Loader:
"""
All torchdata-related setup happens here. Currently this recipe supports
both Map and Streaming datasets (from HuggingFace datasets), and mixing multiple
datasets (which can be a mix of Map and Streaming).
"""
world_size, rank = training.get_world_size_and_rank()

# Get global settings
shuffle = cfg_dataloader.shuffle
parallel_method = cfg_dataloader.get("parallel_method", "thread")
@@ -660,69 +654,66 @@ def _setup_data_td(
num_workers = cfg_dataloader.get("num_workers", 0)
pin_memory = cfg_dataloader.get("pin_memory", True)
collate_fn = cfg_dataloader.collate_fn
prefetch_factor = cfg_dataloader.get("prefetch_factor", 2)
prefetch_factor = cfg_dataloader.get("prefetch_factor", 6)

stop_criterion = cfg_multi_datasets.get(
"stop_criterion", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
# Multi-Dataset Stop Criteria
stop_criteria = cfg_dataloader.get(
"stop_criteria", StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
)
weights, datasets = {}, {}
cfg_datasets = cfg_multi_datasets.datasets
for k, cfg_and_weight in cfg_datasets.items():
weights[k] = float(cfg_and_weight.weight)
datasets[k] = Prefetcher(
self._setup_one_dataset(
cfg_dataset=cfg_and_weight.dataset,
global_shuffle=shuffle,
global_parallel_method=parallel_method,
global_streaming=streaming,
global_num_workers=num_workers,
),
prefetch_factor=prefetch_factor,
for idx, cfg_dataset in enumerate(cfg_datasets):
dataset_name = cfg_dataset.pop("name", None)
if dataset_name is None:
dataset_name = cfg_dataset.get("subset", None)
key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "")
assert key not in weights, f"Duplicate dataset name {key}"
weights[key] = float(cfg_dataset.pop("weight"))
datasets[key] = self._setup_one_dataset(
cfg_dataset=cfg_dataset,
global_shuffle=shuffle,
global_parallel_method=parallel_method,
global_streaming=streaming,
global_num_workers=num_workers,
)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

# TODO: add multi-dataset mixer
if num_workers == 0:
_Mapper = Mapper # noqa[N806]
else:
_Mapper = partial( # noqa[N806]
ParallelMapper,
num_workers=num_workers,
method=parallel_method,
collate_fn = (
partial(
_get_component_from_path(collate_fn),
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if len(cfg_datasets) == 1:
node = next(iter(datasets.values()))
else:
node = MultiDatasetWeightedSampler(
source_nodes=datasets,
if not packed
else padded_collate_packed
)
if len(datasets) > 1:
dataset = get_multi_dataset(
datasets=datasets,
weights=weights,
stop_criterion=stop_criterion,
stop_criteria=stop_criteria,
)
node = Batcher(node, batch_size, drop_last=True)
node = _Mapper(
node,
map_fn=(
partial(
collate_fn, # noqa
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed
),
else:
dataset = next(iter(datasets.values()))

loader = get_dataloader(
dataset=dataset,
model_transform=SFTTransform(model_transform=self._tokenizer),
batch_size=batch_size,
collate_fn=collate_fn,
packed=packed,
drop_last=True,
num_workers=num_workers,
parallel_method=parallel_method,
prefetch_factor=prefetch_factor,
pin_memory=pin_memory,
)
if pin_memory:
node = PinMemory(node)
if num_workers > 0:
node = Prefetcher(node, prefetch_factor)

log.info("TorchData nodes are initialized")

return Loader(node)
return loader

def _setup_data(
self,
@@ -801,24 +792,27 @@ def save_checkpoint(

intermediate_checkpoint = epoch + 1 < self.total_epochs

if self._is_rank_zero:
log.info(
"Saving checkpoint. This may take some time. Retrieving full model state dict..."
)
start = time.perf_counter()
utils.log_rank_zero(
log,
"Saving checkpoint. This may take some time. Retrieving full model state dict...",
)
start = time.perf_counter()

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.get_full_model_state_dict(
self._model,
state_dict = self._model.state_dict()
if self._save_adapter_weights_only:
state_dict = get_adapter_state_dict(state_dict, device=None)

cpu_state_dict = training.gather_cpu_state_dict(
state_dict,
self._is_rank_zero,
device=self._device,
trainable_only=self._save_adapter_weights_only,
)
if self._is_rank_zero:
log.info(
f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
)
utils.log_rank_zero(
log,
f"Getting full model state dict took {time.perf_counter() - start:.2f} secs",
)

if intermediate_checkpoint:
if self._is_rank_zero:
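To summarize the refactor of _setup_data_td above: per-dataset nodes are collected into name-keyed datasets/weights dicts, mixed with get_multi_dataset, and wrapped by get_dataloader. A condensed sketch of that flow follows; ds_a, ds_b, my_collate_fn, tokenizer, and loss_fn are hypothetical placeholders, and the helper calls mirror the diff rather than define its API.

# Condensed sketch of the new pipeline in _setup_data_td (this diff).
# ds_a, ds_b, my_collate_fn, tokenizer, and loss_fn are placeholders.
from functools import partial

from torchdata.nodes.samplers.stop_criteria import StopCriteria
from torchtune.data._utils import get_dataloader, get_multi_dataset
from torchtune.datasets._sft import SFTTransform

datasets = {"0_ocrvqa": ds_a, "1_dvqa": ds_b}
weights = {"0_ocrvqa": 1.0, "1_dvqa": 1.0}

# Mix the per-dataset nodes according to their weights.
dataset = get_multi_dataset(
    datasets=datasets,
    weights=weights,
    stop_criteria=StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
)

# Tokenize via SFTTransform, then batch, collate, pin, and prefetch.
loader = get_dataloader(
    dataset=dataset,
    model_transform=SFTTransform(model_transform=tokenizer),
    batch_size=4,
    collate_fn=partial(
        my_collate_fn,
        padding_idx=tokenizer.pad_id,
        ignore_idx=loss_fn.ignore_index,
    ),
    packed=False,
    drop_last=True,
    num_workers=0,
    parallel_method="thread",
    prefetch_factor=2,
    pin_memory=True,
)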
43 changes: 43 additions & 0 deletions torchtune/data/_torchdata.py
@@ -0,0 +1,43 @@
import functools
from typing import Any, Callable, Iterable, Iterator, Mapping

from typing_extensions import TypeAlias # typing.TypeAlias is only in Python 3.10+


try:
from torchdata.nodes import BaseNode, Loader # noqa

_TORCHDATA_INSTALLED = True
DatasetType: TypeAlias = BaseNode[Mapping[str, Any]] # type: ignore
except ImportError as e:
# If we fail to import torchdata, define some stubs to make typechecker happy
_TORCHDATA_INSTALLED = False
DatasetType: TypeAlias = Iterator[Mapping[str, Any]] # type: ignore

class Loader(Iterable):
def __init__(self, *args, **kwargs):
assert_torchdata_installed()


MIN_VERSION = "0.10.0"


def assert_torchdata_installed():
if not _TORCHDATA_INSTALLED:
raise ImportError(
f"torchdata is not installed, or the current version is too old. "
f"Please (re-)install it with `pip install torchdata>={MIN_VERSION}`. "
)


def requires_torchdata(func: Callable) -> Callable:
"""
Decorator to check if torchdata is installed and raise an ImportError if not.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
assert_torchdata_installed()
return func(*args, **kwargs)

return wrapper
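
A short usage note for the guard above: decorating a torchdata-dependent function with requires_torchdata turns a missing (or too old) install into the explicit ImportError rather than a failure deeper in the call. The wrapped function below is a hypothetical example, not part of this commit.

# Hypothetical example of the decorator defined above; not part of this commit.
from torchtune.data._torchdata import Loader, requires_torchdata


@requires_torchdata
def build_loader(node) -> Loader:
    # Only runs when the torchdata.nodes import above succeeded;
    # otherwise assert_torchdata_installed() raises ImportError first.
    return Loader(node)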