
refactor: move removal of lm_head to save method #313

Closed

Conversation

@anhuong (Collaborator) commented Aug 23, 2024

Description of the change

  • Perform the removal of the duplicated lm_head weight in granite models with llama architecture in sft_trainer itself, instead of only in the accelerate_launch script. The lm_head weight is still removed from the model written to save_model_dir OR the last checkpoint.
  • Now, on a call to sft_trainer.save(), the lm_head weight is removed automatically (a rough sketch of the idea follows below).
  • The checkpoint no longer needs to be reloaded just to remove lm_head.
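
The snippet below is only a minimal sketch of that idea, not the repository's actual save() implementation: the helper name save_without_duplicate_lm_head, the key names lm_head.weight / model.embed_tokens.weight, and the use of save_pretrained are assumptions for illustration.

```python
# Sketch (assumed, not the repo's save()): drop a duplicated lm_head weight
# from the state dict before writing the model to disk.
import torch


def save_without_duplicate_lm_head(model, output_dir: str) -> None:
    """Save `model` to `output_dir`, omitting lm_head.weight if it merely
    duplicates the input embedding weight."""
    state_dict = model.state_dict()
    lm_head = state_dict.get("lm_head.weight")
    embeddings = state_dict.get("model.embed_tokens.weight")
    if lm_head is not None and embeddings is not None and torch.equal(lm_head, embeddings):
        # The lm_head weight is an exact copy of the embedding weight, so the
        # checkpoint only needs to store the shared tensor once.
        del state_dict["lm_head.weight"]
    model.save_pretrained(output_dir, state_dict=state_dict)
```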

Related issue number

How to verify the PR

Currently testing
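
One possible manual check (an assumption, not taken from the PR): after a training run with save_model_dir set, confirm that the saved weights file no longer contains lm_head.weight. The file path below is illustrative.

```python
# Check whether a saved safetensors file still contains lm_head.weight.
from safetensors import safe_open


def has_lm_head(checkpoint_file: str) -> bool:
    """Return True if `checkpoint_file` contains an lm_head.weight tensor."""
    with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
        return "lm_head.weight" in f.keys()


print(has_lm_head("save_model_dir/model.safetensors"))  # expect False after this PR
```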

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

Comment on lines -616 to +663
```
-    if training_args.save_model_dir:
-        try:
+    try:
+        if training_args.save_model_dir:
             save(
                 path=training_args.save_model_dir,
                 trainer=trainer,
                 log_level=training_args.log_level,
             )
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error(traceback.format_exc())
-            write_termination_log(
-                f"Failed to save model to {training_args.save_model_dir}: {e}"
-            )
+        else:
+            # if granite with llama arch, remove lm_head in last checkpoint
+            save(
+                path=get_highest_checkpoint(training_args.output_dir),
+                trainer=trainer,
+                log_level=training_args.log_level,
+            )
```
@anhuong (Collaborator, Author) commented:

This currently removes lm_head from the model being saved at save_model_dir OR the last checkpoint. A few questions...

  • Do we need to remove it from the last checkpoint? This covers the case where someone doesn't specify save_model_dir and preserves current behavior, but it is a little odd for the final checkpoint to differ from the rest of the checkpoints.
  • Do we want to remove it from each checkpoint? That would require loading each checkpoint again and resaving it, unless SFTTrainer.trainer still has access to the checkpoints, but AFAIK it only has access to the last one (see the sketch after this list).
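
To illustrate the cost behind that second question, here is a rough sketch of what removing lm_head from every checkpoint would involve: each checkpoint-* directory under output_dir would have to be reloaded and rewritten, since the trainer only holds the final model in memory. The directory layout and the save_without_duplicate_lm_head helper (sketched earlier) are assumptions, not the repo's code.

```python
# Sketch: strip the duplicated lm_head from every saved checkpoint by
# reloading each one from disk and resaving it.
import os

from transformers import AutoModelForCausalLM


def strip_lm_head_from_all_checkpoints(output_dir: str) -> None:
    for name in sorted(os.listdir(output_dir)):
        ckpt_path = os.path.join(output_dir, name)
        if not (name.startswith("checkpoint-") and os.path.isdir(ckpt_path)):
            continue
        # Reload the checkpoint, then rewrite it without the duplicated
        # lm_head weight (hypothetical helper defined above).
        model = AutoModelForCausalLM.from_pretrained(ckpt_path)
        save_without_duplicate_lm_head(model, ckpt_path)
```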

@anhuong force-pushed the refactor-remove-lmhead branch from b54eed5 to ee0299e on August 23, 2024 at 18:46
@anhuong (Collaborator, Author) commented Sep 5, 2024

Will close in favor of PR #333.

@anhuong closed this on Sep 5, 2024