Skip to content

Commit

Permalink
Add FT unit test and refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Thara Palanivel <[email protected]>
  • Loading branch information
tharapalanivel committed Mar 11, 2024
1 parent f288cd4 commit f958f15
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
BASE_LORA_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS)
BASE_LORA_KWARGS["peft_method"] = "lora"

# Kwargs for full fine-tuning: reuse the PEFT kwargs but with an empty
# peft_method — presumably sft_trainer treats "" as "no PEFT adapter",
# i.e. full fine-tuning (see test_run_causallm_ft_and_inference below).
BASE_FT_KWARGS = copy.deepcopy(BASE_PEFT_KWARGS)
BASE_FT_KWARGS["peft_method"] = ""


def test_helper_causal_lm_train_kwargs():
"""Check happy path kwargs passed and parsed properly."""
Expand Down Expand Up @@ -119,17 +122,12 @@ def test_run_causallm_pt_and_inference():
BASE_PEFT_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)

# validate peft tuning configs
_validate_training(tempdir)
checkpoint_path = os.path.join(tempdir, "checkpoint-5")
adapter_config = _get_adapter_config(checkpoint_path)
assert adapter_config.get("task_type") == "CAUSAL_LM"
assert adapter_config.get("peft_type") == "PROMPT_TUNING"
assert (
adapter_config.get("tokenizer_name_or_path")
== BASE_PEFT_KWARGS["tokenizer_name_or_path"]
)
_validate_adapter_config(adapter_config, "PROMPT_TUNING", BASE_PEFT_KWARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
Expand Down Expand Up @@ -171,13 +169,12 @@ def test_run_causallm_lora_and_inference():
BASE_LORA_KWARGS
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)

# validate peft tuning configs
# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = os.path.join(tempdir, "checkpoint-5")
adapter_config = _get_adapter_config(checkpoint_path)
assert adapter_config.get("task_type") == "CAUSAL_LM"
assert adapter_config.get("peft_type") == "LORA"
_validate_adapter_config(adapter_config, "LORA", BASE_LORA_KWARGS)
for module in ["q_proj", "v_proj"]: # default target_modules used
assert module in adapter_config.get("target_modules")

Expand All @@ -203,10 +200,12 @@ def test_run_train_lora_target_modules():
lora_target_modules
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)

# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = os.path.join(tempdir, "checkpoint-5")
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA", BASE_LORA_KWARGS)
for module in lora_target_modules["target_modules"]:
assert module in adapter_config.get("target_modules")

Expand All @@ -222,10 +221,12 @@ def test_run_train_lora_target_modules_all_linear():
lora_target_modules
)
sft_trainer.train(model_args, data_args, training_args, tune_config)
_validate_training(tempdir)

# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = os.path.join(tempdir, "checkpoint-5")
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA", BASE_LORA_KWARGS)
llama_expected_modules = [
"o_proj",
"q_proj",
Expand All @@ -239,13 +240,52 @@ def test_run_train_lora_target_modules_all_linear():
assert module in adapter_config.get("target_modules")


def test_run_causallm_ft_and_inference():
    """Check if we can bootstrap and fully fine-tune causallm models.

    Trains with an empty peft_method (full fine-tuning), validates that a
    checkpoint and loss log were produced, then loads the checkpoint and
    runs inference on it.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        BASE_FT_KWARGS["output_dir"] = tempdir
        model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
            BASE_FT_KWARGS
        )
        sft_trainer.train(model_args, data_args, training_args, tune_config)

        # validate ft training artifacts; unlike the PEFT tests, full
        # fine-tuning produces no adapter_config.json, so there is no
        # adapter config (and no PROMPT_TUNING peft_type) to validate here.
        _validate_training(tempdir)
        checkpoint_path = os.path.join(tempdir, "checkpoint-5")

        # Load the model
        loaded_model = TunedCausalLM.load(checkpoint_path)

        # Run inference on the text
        output_inference = loaded_model.run(
            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
        )
        assert len(output_inference) > 0
        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference


def _validate_training(tempdir):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
loss_file_path = "{}/train_loss.jsonl".format(tempdir)
assert os.path.exists(loss_file_path)
assert os.path.getsize(loss_file_path) > 0
train_loss_file_path = "{}/train_loss.jsonl".format(tempdir)
assert os.path.exists(train_loss_file_path) == True
assert os.path.getsize(train_loss_file_path) > 0


def _get_adapter_config(dir_path):
with open(os.path.join(dir_path, "adapter_config.json")) as f:
return json.load(f)


def _validate_adapter_config(adapter_config, peft_type, base_kwargs):
assert adapter_config.get("task_type") == "CAUSAL_LM"
assert adapter_config.get("peft_type") == peft_type
assert (
(
adapter_config.get("tokenizer_name_or_path")
== base_kwargs["tokenizer_name_or_path"]
)
if peft_type == "PROMPT_TUNING"
else True
)

0 comments on commit f958f15

Please sign in to comment.