diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 48aa57e84..559fca84b 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -743,10 +743,8 @@ def save_checkpoint( index_file_name = TORCH_INDEX_FNAME index_path = Path.joinpath( - self._output_dir, - f"epoch_{epoch}", - index_file_name, - ).with_suffix(".json") + self._output_dir, f"epoch_{epoch}", index_file_name + ) index_data = { "metadata": {"total_size": total_size}, diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 82f60fc7d..770a3f889 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -32,8 +32,8 @@ # https://github.com/huggingface/peft/blob/d13d7a401ccf4808aaaf76480fea09a4cf4ac1f5/src/peft/config.py#L259C21-L259C32 ADAPTER_CONFIG_FNAME = "adapter_config" ADAPTER_MODEL_FNAME = "adapter_model" -SAFETENSOR_INDEX_FNAME = "model.safetensors.index" -TORCH_INDEX_FNAME = "pytorch_model.bin.index" +SAFETENSOR_INDEX_FNAME = "model.safetensors.index.json" +TORCH_INDEX_FNAME = "pytorch_model.bin.index.json" # standardize checkpointing SHARD_FNAME = "ft-model-{cpt_idx}-of-{num_shards}"