Skip to content

Commit

Permalink
[bug fix] remove config download when source is kaggle (#2144)
Browse files Browse the repository at this point in the history
Co-authored-by: Felipe Mello <[email protected]>
  • Loading branch information
felipemello1 and Felipe Mello authored Dec 10, 2024
1 parent d839f69 commit 5370e0d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
7 changes: 0 additions & 7 deletions torchtune/_cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,6 @@ def _download_from_kaggle(self, args: argparse.Namespace) -> None:
try:
output_dir = model_download(model_handle)

# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, REPO_ID_FNAME + ".json")
with open(file_path, "w") as json_file:
json.dump({"repo_id": args.repo_id}, json_file, indent=4)

print(
"Successfully downloaded model repo and wrote to the following locations:",
*list(Path(output_dir).iterdir()),
Expand Down
2 changes: 1 addition & 1 deletion torchtune/training/checkpointing/_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def save_checkpoint(
state_dict[
training.ADAPTER_CONFIG
] = convert_weights.tune_to_peft_adapter_config(
state_dict[training.ADAPTER_CONFIG],
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)

Expand Down

0 comments on commit 5370e0d

Please sign in to comment.