diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index a0690f9002..506dc6f546 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -701,10 +701,7 @@ def _setup_data( return sampler, dataloader - def save_checkpoint_async( - self, - epoch: int, - ) -> None: + def save_checkpoint_async(self) -> None: """ Checkpoint the state of the recipe. The constructed checkpoint state dict contains the following information: @@ -746,7 +743,7 @@ def save_checkpoint_async( dcp_saver.save_checkpoint( ckpt_dict, - epoch=epoch, + epoch=self.epochs_run, save_async=True, ) @@ -755,10 +752,7 @@ def save_checkpoint_async( f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs" ) - def save_checkpoint( - self, - epoch: int, - ) -> None: + def save_checkpoint(self) -> None: """ Checkpoint the state of the recipe. The constructed checkpoint state dict contains the following information: @@ -772,12 +766,12 @@ def save_checkpoint( # final dict passed onto the checkpointer checkpoint_dict = {} - intermediate_checkpoint = epoch + 1 < self.total_epochs + intermediate_checkpoint = self.epochs_run < self.total_epochs # If async checkpointing is enabled, intermediate checkpoints will # be saved asynchronously. if intermediate_checkpoint and self._enable_async_checkpointing: - self.save_checkpoint_async(epoch) + self.save_checkpoint_async() return if self._is_rank_zero: @@ -845,7 +839,7 @@ def save_checkpoint( self._checkpointer.save_checkpoint( checkpoint_dict, - epoch=epoch, + epoch=self.epochs_run, intermediate_checkpoint=intermediate_checkpoint, ) @@ -1023,7 +1017,8 @@ def train(self) -> None: self._profiler.step() self.epochs_run += 1 - self.save_checkpoint(epoch=curr_epoch) + # Save the checkpoint for the current epoch + self.save_checkpoint() self._profiler.stop()