diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 23fbec7a8b..d75227aeac 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -409,11 +409,9 @@ def get_recipe_checkpoint_path( if recipe_checkpoint: recipe_checkpoint_path = os.path.join(output_dir, recipe_checkpoint) else: - # Look for the recipe_state in /recipe_state - tentative_folder_path = os.path.join(output_dir, RECIPE_STATE_DIRNAME) - for file_name in os.listdir(tentative_folder_path): - if file_name.startswith("recipe_state" + "."): - recipe_checkpoint_path = os.path.join(tentative_folder_path, file_name) + recipe_checkpoint_path = os.path.join( + output_dir, RECIPE_STATE_DIRNAME, "recipe_state.pt" + ) # TODO: improve this msg if not recipe_checkpoint_path or not os.path.exists(recipe_checkpoint_path):