Skip to content

Commit

Permalink
[DCP] Minor refactor to avoid passing the epoch in the checkpoint sav…
Browse files Browse the repository at this point in the history
…e APIs
  • Loading branch information
Saurabh Mishra committed Nov 19, 2024
1 parent 1b1342e commit cec9587
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,7 @@ def _setup_data(

return sampler, dataloader

def save_checkpoint_async(
self,
epoch: int,
) -> None:
def save_checkpoint_async(self) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
Expand Down Expand Up @@ -746,7 +743,7 @@ def save_checkpoint_async(

dcp_saver.save_checkpoint(
ckpt_dict,
epoch=epoch,
epoch=self.epochs_run,
save_async=True,
)

Expand All @@ -755,10 +752,7 @@ def save_checkpoint_async(
f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
)

def save_checkpoint(
self,
epoch: int,
) -> None:
def save_checkpoint(self) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
Expand All @@ -772,12 +766,12 @@ def save_checkpoint(
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
intermediate_checkpoint = self.epochs_run < self.total_epochs

# If async checkpointing is enabled, intermediate checkpoints will
# be saved asynchronously.
if intermediate_checkpoint and self._enable_async_checkpointing:
self.save_checkpoint_async(epoch)
self.save_checkpoint_async()
return

if self._is_rank_zero:
Expand Down Expand Up @@ -845,7 +839,7 @@ def save_checkpoint(

self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
epoch=self.epochs_run,
intermediate_checkpoint=intermediate_checkpoint,
)

Expand Down Expand Up @@ -1023,7 +1017,8 @@ def train(self) -> None:
self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
# Save the checkpoint for the current epoch
self.save_checkpoint()

self._profiler.stop()

Expand Down

0 comments on commit cec9587

Please sign in to comment.