Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <[email protected]>
  • Loading branch information
willmj committed Sep 18, 2024
1 parent fb1dcc9 commit bcc17b1
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@
from trl import SFTConfig, SFTTrainer
import transformers

# Local
# First Party
from build.utils import get_highest_checkpoint

# Local
from tuning.config import configs, peft_config
from tuning.config.acceleration_configs import (
AccelerationFrameworkConfig,
Expand All @@ -64,15 +66,14 @@
write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
is_pretokenized_dataset,
validate_data_args,
)
from tuning.utils.merge_model_utils import(
post_process_vLLM_adapters_new_tokens,
)


def train(
model_args: configs.ModelArguments,
Expand Down Expand Up @@ -654,22 +655,25 @@ def main():
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)


# post process lora
if isinstance(tune_config, peft_config.LoraConfig):
try:
checkpoint_dir = job_config.get("save_model_dir")
if not checkpoint_dir:
checkpoint_dir = os.path.join(
training_args.output_dir, get_highest_checkpoint(training_args.output_dir)
training_args.output_dir,
get_highest_checkpoint(training_args.output_dir),
)
print(training_args)
print(f"Post processing LoRA adapters in {checkpoint_dir}")
post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(f"Exception encountered while lora post-processing model: {e}")
write_termination_log(
f"Exception encountered while lora post-processing model: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)


if __name__ == "__main__":
main()

0 comments on commit bcc17b1

Please sign in to comment.