Skip to content

Commit

Permalink
add a few validations that the number of checkpoints equals the number of epochs
Browse files Browse the repository at this point in the history
Signed-off-by: Anh Uong <[email protected]>
  • Loading branch information
anhuong committed Nov 1, 2024
1 parent 69cdfb9 commit 5f3ff51
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_resume_training_from_checkpoint():

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)
_validate_num_checkpoints(tempdir, train_args.num_train_epochs)

# Get trainer state of latest checkpoint
init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
Expand All @@ -100,6 +101,7 @@ def test_resume_training_from_checkpoint():
train_args.num_train_epochs += 5
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
_validate_training(tempdir)
_validate_num_checkpoints(tempdir, train_args.num_train_epochs)

# Get trainer state of latest checkpoint
final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
Expand Down Expand Up @@ -415,6 +417,7 @@ def test_run_causallm_pt_and_inference():

# validate peft tuning configs
_validate_training(tempdir)
_validate_num_checkpoints(tempdir, train_args.num_train_epochs)
checkpoint_path = _get_checkpoint(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)

Expand Down Expand Up @@ -638,6 +641,7 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):

# validate lora tuning configs
_validate_training(tempdir)
_validate_num_checkpoints(tempdir, train_args.num_train_epochs)
checkpoint_path = _get_checkpoint(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA")
Expand Down Expand Up @@ -709,6 +713,7 @@ def test_run_causallm_ft_and_inference(dataset_path):
data_args.training_data_path = dataset_path

_test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, data_args, tempdir)
_validate_num_checkpoints(tempdir, TRAIN_ARGS.num_train_epochs)
_test_run_inference(checkpoint_path=_get_checkpoint(tempdir))


Expand Down Expand Up @@ -813,6 +818,12 @@ def _validate_logfile(log_file_path, check_eval=False):
if check_eval:
assert "validation_loss" in train_log_contents

def _validate_num_checkpoints(dir_path, expected_num):
checkpoints = [
d for d in os.listdir(dir_path)
if d.startswith("checkpoint")
]
assert len(checkpoints) == expected_num

def _get_adapter_config(dir_path):
with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f:
Expand Down

0 comments on commit 5f3ff51

Please sign in to comment.