diff --git a/build/launch_training.py b/build/launch_training.py
index 143987ad6..af02575bf 100644
--- a/build/launch_training.py
+++ b/build/launch_training.py
@@ -28,6 +28,7 @@
 # Local
 from tuning import sft_trainer
 from tuning.utils.merge_model_utils import create_merged_model
+from tuning.config.tracker_configs import TrackerConfigFactory
 from build.utils import process_launch_training_args, get_job_config
 
 
@@ -62,12 +63,23 @@ def main():
         training_args,
         tune_config,
         merge_model,
+        file_logger_config,
+        aim_config,
     ) = process_launch_training_args(job_config)
 
     original_output_dir = training_args.output_dir
     with tempfile.TemporaryDirectory() as tempdir:
         training_args.output_dir = tempdir
-        sft_trainer.train(model_args, data_args, training_args, tune_config)
+        tracker_config_args = TrackerConfigFactory(
+            file_logger_config=file_logger_config, aim_config=aim_config
+        )
+        sft_trainer.train(
+            model_args=model_args,
+            data_args=data_args,
+            train_args=training_args,
+            peft_config=tune_config,
+            tracker_configs=tracker_config_args,
+        )
 
         if merge_model:
             export_path = os.getenv(
@@ -108,7 +120,8 @@ def main():
 
         # copy over any loss logs
         train_logs_filepath = os.path.join(
-            training_args.output_dir, sft_trainer.TRAINING_LOGS_FILENAME
+            training_args.output_dir,
+            tracker_config_args.file_logger_config.training_logs_filename,
         )
         if os.path.exists(train_logs_filepath):
             shutil.copy(train_logs_filepath, original_output_dir)
diff --git a/build/utils.py b/build/utils.py
index 416661531..7025d0978 100644
--- a/build/utils.py
+++ b/build/utils.py
@@ -25,7 +25,7 @@
 from accelerate.commands.launch import launch_command_parser
 
 # Local
-from tuning.config import configs, peft_config
+from tuning.config import configs, peft_config, tracker_configs
 
 
 def txt_to_obj(txt):
@@ -67,6 +67,8 @@ def process_launch_training_args(job_config_dict):
         training_args: configs.TrainingArguments
         tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig
         merge_model: bool
+        file_logger_config: tracker_configs.FileLoggingTrackerConfig
+        aim_config: tracker_configs.AimConfig
     """
     parser = transformers.HfArgumentParser(
         dataclass_types=(
@@ -75,6 +77,8 @@
             configs.TrainingArguments,
             peft_config.LoraConfig,
             peft_config.PromptTuningConfig,
+            tracker_configs.FileLoggingTrackerConfig,
+            tracker_configs.AimConfig,
         )
     )
 
@@ -84,6 +88,8 @@
         training_args,
         lora_config,
         prompt_tuning_config,
+        file_logger_config,
+        aim_config,
     ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)
 
     peft_method_parsed = job_config_dict.get("peft_method")
@@ -98,14 +104,25 @@
 
     logging.info(
         "Parameters used to launch training: \
-        model_args %s, data_args %s, training_args %s, tune_config %s",
+        model_args %s, data_args %s, training_args %s, tune_config %s \
+        file_logger_config %s aim_config %s",
         model_args,
         data_args,
         training_args,
         tune_config,
+        file_logger_config,
+        aim_config,
     )
 
-    return model_args, data_args, training_args, tune_config, merge_model
+    return (
+        model_args,
+        data_args,
+        training_args,
+        tune_config,
+        merge_model,
+        file_logger_config,
+        aim_config,
+    )
 
 
 def process_accelerate_launch_args(job_config_dict):
diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py
index da4302153..1bfaabba4 100644
--- a/tests/build/test_utils.py
+++ b/tests/build/test_utils.py
@@ -46,6 +46,8 @@ def test_process_launch_training_args(job_config):
         training_args,
         tune_config,
         merge_model,
+        _,
+        _,
     ) = process_launch_training_args(job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -59,7 +61,7 @@ def test_process_launch_training_args_defaults(job_config):
     assert "torch_dtype" not in job_config_defaults
     assert job_config_defaults["use_flash_attn"] is False
     assert "save_strategy" not in job_config_defaults
-    model_args, _, training_args, _, _ = process_launch_training_args(
+    model_args, _, training_args, _, _, _, _ = process_launch_training_args(
         job_config_defaults
     )
     assert str(model_args.torch_dtype) == "torch.bfloat16"
@@ -70,13 +72,17 @@ def test_process_launch_training_args_peft_method(job_config):
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
+    _, _, _, tune_config, merge_model, _, _ = process_launch_training_args(
+        job_config_pt
+    )
     assert isinstance(tune_config, PromptTuningConfig)
     assert merge_model is False
 
     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
+    _, _, _, tune_config, merge_model, _, _ = process_launch_training_args(
+        job_config_lora
+    )
     assert isinstance(tune_config, LoraConfig)
     assert merge_model is True