
Update input args to max_seq_length and training_data_path (foundation-model-stack#94)

* update model_max_length to max_seq_length to match SFTTrainer

Signed-off-by: Anh-Uong <[email protected]>

* update data_path to training_data_path and tests

Signed-off-by: Anh-Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
anhuong committed Apr 3, 2024
1 parent a715a83 commit 1945a52
Showing 5 changed files with 20 additions and 23 deletions.
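
For downstream callers still passing the old keyword names, the mapping introduced by this commit is summarized in the sketch below (the `migrate_kwargs` helper is hypothetical and only illustrates the old-to-new renames; it is not part of the repository):

```python
# Quick-reference sketch of the renames in this commit; the helper is
# illustrative only and not part of the codebase.
RENAMED_ARGS = {
    "data_path": "training_data_path",     # DataArguments field / --data_path flag
    "model_max_length": "max_seq_length",  # TrainingArguments field, matches SFTTrainer
}

def migrate_kwargs(kwargs: dict) -> dict:
    """Return a copy of kwargs with the renamed keys swapped in."""
    return {RENAMED_ARGS.get(key, key): value for key, value in kwargs.items()}

print(migrate_kwargs({"data_path": "train.jsonl", "model_max_length": 4096}))
# {'training_data_path': 'train.jsonl', 'max_seq_length': 4096}
```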
6 changes: 3 additions & 3 deletions README.md
@@ -62,7 +62,7 @@ export CUDA_VISIBLE_DEVICES=0

python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
@@ -90,7 +90,7 @@ torchrun \
--master_port=1234 \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--bf16 True \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
@@ -120,7 +120,7 @@ For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply re
```bash
python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 40 \
--per_device_train_batch_size 4 \
2 changes: 1 addition & 1 deletion examples/prompt_tuning_twitter_complaints/README.md
@@ -41,7 +41,7 @@ torchrun \
--master_port=1234 \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--peft_method pt \
--tokenizer_name_or_path $MODEL_PATH \
12 changes: 6 additions & 6 deletions tests/test_sft_trainer.py
@@ -35,7 +35,7 @@
MODEL_NAME = "Maykeye/TinyLLama-v0"
BASE_PEFT_KWARGS = {
"model_name_or_path": MODEL_NAME,
"data_path": TWITTER_COMPLAINTS_DATA,
"training_data_path": TWITTER_COMPLAINTS_DATA,
"num_train_epochs": 5,
"per_device_train_batch_size": 4,
"per_device_eval_batch_size": 4,
@@ -51,7 +51,7 @@
"dataset_text_field": "output",
"use_flash_attn": False,
"torch_dtype": "float32",
"model_max_length": 4096,
"max_seq_length": 4096,
"peft_method": "pt",
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
@@ -80,12 +80,12 @@ def test_helper_causal_lm_train_kwargs():
assert model_args.use_flash_attn == False
assert model_args.torch_dtype == "float32"

- assert data_args.data_path == TWITTER_COMPLAINTS_DATA
+ assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA
assert data_args.response_template == "\n### Label:"
assert data_args.dataset_text_field == "output"

assert training_args.num_train_epochs == 5
- assert training_args.model_max_length == 4096
+ assert training_args.max_seq_length == 4096
assert training_args.save_strategy == "epoch"

assert tune_config.prompt_tuning_init == "RANDOM"
@@ -105,10 +105,10 @@ def test_run_train_requires_output_dir():
sft_trainer.train(model_args, data_args, training_args, tune_config)


- def test_run_train_fails_data_path_not_exist():
+ def test_run_train_fails_training_data_path_not_exist():
"""Check fails when data path not found."""
updated_output_path = copy.deepcopy(BASE_PEFT_KWARGS)
- updated_output_path["data_path"] = "fake/path"
+ updated_output_path["training_data_path"] = "fake/path"
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
updated_output_path
)
4 changes: 2 additions & 2 deletions tuning/config/configs.py
@@ -42,7 +42,7 @@ class ModelArguments:

@dataclass
class DataArguments:
- data_path: str = field(
+ training_data_path: str = field(
default=None, metadata={"help": "Path to the training data in JSONL format."}
)
response_template: str = field(
@@ -61,7 +61,7 @@ class DataArguments:
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
# optim: str = field(default=DEFAULT_OPTIMIZER)
- model_max_length: int = field(
+ max_seq_length: int = field(
default=DEFAULT_CONTEXT_LENGTH,
metadata={
"help": "Maximum sequence length. Sequences will be right padded \
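
A minimal sketch of how the renamed fields surface on the command line, assuming the dataclasses above are parsed with `transformers.HfArgumentParser` (the wiring and values here are illustrative placeholders, not a copy of `sft_trainer.py`):

```python
# Illustrative only: parse the renamed flags into the dataclasses defined in
# tuning/config/configs.py (shown in the diff above).
import transformers

from tuning.config.configs import DataArguments, TrainingArguments

parser = transformers.HfArgumentParser((DataArguments, TrainingArguments))
data_args, train_args = parser.parse_args_into_dataclasses(
    args=[
        "--training_data_path", "twitter_complaints.jsonl",  # was --data_path
        "--max_seq_length", "4096",                          # was --model_max_length
        "--output_dir", "/tmp/out",                          # required by TrainingArguments
    ]
)
assert data_args.training_data_path == "twitter_complaints.jsonl"
assert train_args.max_seq_length == 4096
```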
19 changes: 8 additions & 11 deletions tuning/sft_trainer.py
@@ -159,16 +159,14 @@ def train(
response_template_ids = tokenizer.encode(
data_args.response_template, add_special_tokens=False
)[2:]
- # TODO: This is actually max_seq_length and not model_max_length. we should not override
- # model_max_length as in current main. We need to change name of this parameter we expose
- # to users.
- model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
- logger.info("Model max length %s, model_max_length")
- if train_args.model_max_length > tokenizer.model_max_length:
+
+ max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
+ logger.info("Max sequence length is %s", max_seq_length)
+ if train_args.max_seq_length > tokenizer.model_max_length:
logger.warning(
"model_max_length %s exceeds tokenizer.model_max_length \
"max_seq_length %s exceeds tokenizer.model_max_length \
%s, using tokenizer.model_max_length %s",
- train_args.model_max_length,
+ train_args.max_seq_length,
tokenizer.model_max_length,
tokenizer.model_max_length,
)
@@ -197,8 +195,7 @@ def train(
)

# load the data by parsing JSON
- # TODO: update arg from data_path to training_data_path since we also have validation_data_path
- data_files = {"train": data_args.data_path}
+ data_files = {"train": data_args.training_data_path}
if data_args.validation_data_path:
data_files["validation"] = data_args.validation_data_path

@@ -256,7 +253,7 @@ def train(
data_collator=data_collator,
dataset_text_field=data_args.dataset_text_field,
args=train_args,
- max_seq_length=model_max_length,
+ max_seq_length=max_seq_length,
callbacks=callbacks,
peft_config=peft_config,
)
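
As the hunk above shows, the renamed value is still capped at the tokenizer's own limit before being handed to `SFTTrainer`. A self-contained sketch of that clamping behaviour (names and values are illustrative, not repo code):

```python
# Illustrative sketch of the max_seq_length clamping performed in train() above.
def resolve_max_seq_length(requested: int, tokenizer_limit: int) -> int:
    """Return the effective sequence length, never exceeding the tokenizer limit."""
    if requested > tokenizer_limit:
        print(
            f"max_seq_length {requested} exceeds tokenizer.model_max_length "
            f"{tokenizer_limit}, using tokenizer.model_max_length {tokenizer_limit}"
        )
    return min(requested, tokenizer_limit)

# A tokenizer limited to 2048 tokens caps a 4096-token request.
assert resolve_max_seq_length(4096, 2048) == 2048
# A request within the tokenizer limit passes through unchanged.
assert resolve_max_seq_length(1024, 2048) == 1024
```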
