diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index d243727e0..4ea8ce0d9 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -335,6 +335,7 @@ def test_parse_arguments(job_config):
         _,
         _,
         _,
+        _,
     ) = sft_trainer.parse_arguments(parser, job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"
@@ -348,7 +349,7 @@ def test_parse_arguments_defaults(job_config):
     assert "torch_dtype" not in job_config_defaults
     assert job_config_defaults["use_flash_attn"] is False
     assert "save_strategy" not in job_config_defaults
-    model_args, _, training_args, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    model_args, _, training_args, _, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_defaults
     )
     assert str(model_args.torch_dtype) == "torch.bfloat16"
@@ -360,14 +361,14 @@ def test_parse_arguments_peft_method(job_config):
     parser = sft_trainer.get_parser()
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_pt
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)

     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
         parser, job_config_lora
     )
     assert isinstance(tune_config, peft_config.LoraConfig)