address comments
andrewkho committed Dec 11, 2024
1 parent 30a2e3c commit 5ed8382
Showing 2 changed files with 27 additions and 75 deletions.
93 changes: 23 additions & 70 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -571,33 +571,6 @@ def _setup_lr_scheduler(
         utils.log_rank_zero(log, "Learning rate scheduler is initialized.")
         return lr_scheduler
 
-    def _setup_one_dataset(
-        self,
-        cfg_dataset: DictConfig,
-        global_streaming: bool,
-        global_shuffle: bool,
-        global_parallel_method: str,
-        global_num_workers: int,
-    ) -> DatasetType:
-        streaming = cfg_dataset.pop("streaming", global_streaming)
-        shuffle = cfg_dataset.pop("shuffle", global_shuffle)
-        parallel_method = cfg_dataset.pop("parallel_method", global_parallel_method)
-        num_workers = cfg_dataset.pop("num_workers", global_num_workers)
-
-        # Instantiate dataset transform
-        assert "transform" in cfg_dataset, "transform must be specified in dataset"
-        transform = config.instantiate(cfg_dataset.pop("transform"))
-
-        utils.log_rank_zero(log, f"Instantiating dataset {cfg_dataset}")
-        return load_hf_dataset(
-            **cfg_dataset,
-            transform=transform,
-            streaming=streaming,
-            shuffle=shuffle,
-            parallel_method=parallel_method,
-            num_workers=num_workers,
-        )
-
     def _setup_data(
         self,
         cfg_dataloader: DictConfig,
@@ -630,13 +603,26 @@ def _setup_data(
             dataset_name = cfg_dataset.get("subset", None)
             key = f"{idx}" + (f"_{dataset_name}" if dataset_name else "")
             assert key not in weights, f"Duplicate dataset name {key}"
+
+            utils.log_rank_zero(log, f"Instantiating dataset {cfg_dataset}")
+            # Handle dataset-specific overrides, fallback to cfg_dataloader settings
+            ds_streaming = cfg_dataset.pop("streaming", streaming)
+            ds_shuffle = cfg_dataset.pop("shuffle", shuffle)
+            ds_parallel_method = cfg_dataset.pop("parallel_method", parallel_method)
+            ds_num_workers = cfg_dataset.pop("num_workers", num_workers)
+
+            # Instantiate dataset transform
+            assert "transform" in cfg_dataset, "transform must be specified in dataset"
+            transform = config.instantiate(cfg_dataset.pop("transform"))
+
             weights[key] = float(cfg_dataset.pop("weight"))
-            datasets[key] = self._setup_one_dataset(
-                cfg_dataset=cfg_dataset,
-                global_shuffle=shuffle,
-                global_parallel_method=parallel_method,
-                global_streaming=streaming,
-                global_num_workers=num_workers,
+            datasets[key] = load_hf_dataset(
+                **cfg_dataset,
+                transform=transform,
+                streaming=ds_streaming,
+                shuffle=ds_shuffle,
+                parallel_method=ds_parallel_method,
+                num_workers=ds_num_workers,
             )
 
         # Instantiate collate_fn
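The hunk above folds the former _setup_one_dataset helper into the per-dataset loop: each dataset entry may override the dataloader-wide settings, and pop() both reads the override and strips the key so the leftover entries can be forwarded as **cfg_dataset. Below is a minimal runnable sketch of this pop-with-fallback pattern, using plain dicts in place of OmegaConf DictConfig objects; all config names and values here are hypothetical.

# Pop-with-fallback: per-dataset keys override dataloader-level defaults, and
# popping removes them so the leftovers can be splatted into the constructor.
defaults = {"streaming": True, "shuffle": True, "num_workers": 4}

dataset_cfgs = [
    {"source": "dataset_a", "weight": 1.0},                    # inherits every default
    {"source": "dataset_b", "weight": 0.5, "shuffle": False},  # overrides shuffle only
]

for cfg in dataset_cfgs:
    streaming = cfg.pop("streaming", defaults["streaming"])
    shuffle = cfg.pop("shuffle", defaults["shuffle"])
    num_workers = cfg.pop("num_workers", defaults["num_workers"])
    weight = float(cfg.pop("weight"))
    # Only dataset-constructor kwargs remain in cfg (here just "source").
    print(cfg, streaming, shuffle, num_workers, weight)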
@@ -810,14 +796,8 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-            epoch_total_dl_time, epoch_dl_time, epoch_step_time = 0.0, 0.0, 0.0
-            pbar = tqdm(
-                total=self._steps_per_epoch,
-                disable=not (rank == 0),
-            )
-            for idx, (batch, dl_t0, dl_dt, dl_dt_global) in enumerate(
-                iter_timer(self._dataloader, barrier=True)
-            ):
+            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
                     and (idx // self._gradient_accumulation_steps)
@@ -902,17 +882,11 @@ def train(self) -> None:
                     and self._is_rank_zero
                 ):
                     time_per_step = time.perf_counter() - t0
-                    step_time = time.perf_counter() - dl_t0
-                    epoch_total_dl_time += dl_dt_global
-                    epoch_dl_time += dl_dt
-                    epoch_step_time += step_time
                     log_dict = {
                         "loss": loss_to_log,
                         "lr": self._optimizer.param_groups[0]["lr"],
                         "tokens_per_second_per_gpu": num_tokens
                         / (time_per_step * world_size),
-                        "twfb": round(dl_dt, 5),
-                        "twfb_pct": round(dl_dt / (time.perf_counter() - dl_t0), 5),
                     }
                     if self._log_peak_memory_stats:
                         log_dict.update(
@@ -949,11 +923,7 @@ def train(self) -> None:
                 self._profiler.step()
 
             if self._is_rank_zero:
-                total_twfb_pct = round(epoch_total_dl_time / epoch_step_time, 5)
-                log.info(
-                    f"End of epoch {self.epochs_run}! "
-                    f"{total_twfb_pct=}, {epoch_step_time=}, {epoch_dl_time=}, {epoch_total_dl_time=}",
-                )
+                log.info(f"End of epoch {self.epochs_run}!")
             self.epochs_run += 1
             self.save_checkpoint(epoch=curr_epoch)
@@ -965,23 +935,6 @@ def cleanup(self) -> None:
         destroy_process_group()
 
 
-def iter_timer(
-    iterable: Iterable[T], barrier: bool
-) -> Iterator[Tuple[T, float, float, float]]:
-    it = iter(iterable)
-    while True:
-        t0 = time.perf_counter()
-        try:
-            x = next(it)
-        except StopIteration:
-            break
-        dt = time.perf_counter() - t0
-        if barrier:
-            torch.distributed.barrier()
-        dt_global = time.perf_counter() - t0
-        yield x, t0, dt, dt_global
-
-
 @config.parse
 def recipe_main(cfg: DictConfig) -> None:
     """
@@ -996,11 +949,11 @@ def recipe_main(cfg: DictConfig) -> None:
         "Distributed finetune recipe should be run via a distributed launcher."
         "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
     )
-    init_process_group("cuda:nccl,cpu:gloo")
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
         training.set_torch_num_threads()
+    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
 
     config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)
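For context on the init_process_group change above: NCCL only handles CUDA tensors, so a CPU-only run needs gloo, while the deleted call used a device-to-backend mapping string that registers both backends on one process group (supported since torch 2.0). A hedged sketch of the two styles follows; it assumes the usual torchrun environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are set, and `device` stands in for cfg.device.

import torch.distributed as dist

device = "cpu"  # hypothetical; the recipe reads this from cfg.device

# New style: pick exactly one backend based on the training device.
dist.init_process_group(backend="gloo" if device == "cpu" else "nccl")

# Old style (the deleted line): one mapping string, one process group that
# can carry both CPU (gloo) and CUDA (nccl) collectives.
# dist.init_process_group("cuda:nccl,cpu:gloo")

dist.destroy_process_group()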
9 changes: 4 additions & 5 deletions torchtune/utils/_import_guard.py
@@ -4,18 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import importlib
+
 import torch
 
 # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
 _SUPPORTS_FLEX_ATTENTION = (
     torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5)
 )
 
 
-_TORCHDATA_MIN_VERSION = "0.10.0"
-try:
-    from torchdata.nodes import BaseNode, Loader  # noqa
-
+if importlib.util.find_spec("torchdata.nodes") is not None:
     _TORCHDATA_INSTALLED = True
-except ImportError as e:
+else:
     _TORCHDATA_INSTALLED = False
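A short sketch contrasting the two guard styles in the hunk above: find_spec consults import metadata without executing the target module, while the removed try/except imported torchdata.nodes eagerly. One caveat worth hedging: for a dotted name, find_spec still imports the parent package, so the outer check below (an addition of this sketch, not part of the commit) avoids a ModuleNotFoundError when torchdata itself is absent.

import importlib.util

# New style: metadata lookup; torchdata.nodes itself is never executed.
if importlib.util.find_spec("torchdata") is not None:
    _torchdata_installed = importlib.util.find_spec("torchdata.nodes") is not None
else:
    _torchdata_installed = False

# Old style: eager import with the failure swallowed.
try:
    from torchdata.nodes import Loader  # noqa: F401
    _torchdata_importable = True
except ImportError:
    _torchdata_importable = False

print(_torchdata_installed)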
