Update input args to max_seq_length and training_data_path #94

Merged: 2 commits, Mar 14, 2024
6 changes: 3 additions & 3 deletions README.md
@@ -62,7 +62,7 @@ export CUDA_VISIBLE_DEVICES=0

python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 4 \
@@ -90,7 +90,7 @@ torchrun \
--master_port=1234 \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--bf16 True \
--output_dir $OUTPUT_PATH \
--num_train_epochs 5 \
@@ -120,7 +120,7 @@ For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply re
```bash
python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 40 \
--per_device_train_batch_size 4 \
2 changes: 1 addition & 1 deletion examples/prompt_tuning_twitter_complaints/README.md
@@ -41,7 +41,7 @@ torchrun \
--master_port=1234 \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
- --data_path $DATA_PATH \
+ --training_data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--peft_method pt \
--tokenizer_name_or_path $MODEL_PATH \
12 changes: 6 additions & 6 deletions tests/test_sft_trainer.py
@@ -35,7 +35,7 @@
MODEL_NAME = "Maykeye/TinyLLama-v0"
BASE_PEFT_KWARGS = {
"model_name_or_path": MODEL_NAME,
- "data_path": TWITTER_COMPLAINTS_DATA,
+ "training_data_path": TWITTER_COMPLAINTS_DATA,
"num_train_epochs": 5,
"per_device_train_batch_size": 4,
"per_device_eval_batch_size": 4,
@@ -51,7 +51,7 @@
"dataset_text_field": "output",
"use_flash_attn": False,
"torch_dtype": "float32",
- "model_max_length": 4096,
+ "max_seq_length": 4096,
"peft_method": "pt",
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
@@ -80,12 +80,12 @@ def test_helper_causal_lm_train_kwargs():
assert model_args.use_flash_attn == False
assert model_args.torch_dtype == "float32"

- assert data_args.data_path == TWITTER_COMPLAINTS_DATA
+ assert data_args.training_data_path == TWITTER_COMPLAINTS_DATA
assert data_args.response_template == "\n### Label:"
assert data_args.dataset_text_field == "output"

assert training_args.num_train_epochs == 5
- assert training_args.model_max_length == 4096
+ assert training_args.max_seq_length == 4096
assert training_args.save_strategy == "epoch"

assert tune_config.prompt_tuning_init == "RANDOM"
@@ -105,10 +105,10 @@ def test_run_train_requires_output_dir():
sft_trainer.train(model_args, data_args, training_args, tune_config)


- def test_run_train_fails_data_path_not_exist():
+ def test_run_train_fails_training_data_path_not_exist():
"""Check fails when data path not found."""
updated_output_path = copy.deepcopy(BASE_PEFT_KWARGS)
- updated_output_path["data_path"] = "fake/path"
+ updated_output_path["training_data_path"] = "fake/path"
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
updated_output_path
)
4 changes: 2 additions & 2 deletions tuning/config/configs.py
@@ -42,7 +42,7 @@ class ModelArguments:

@dataclass
class DataArguments:
- data_path: str = field(
+ training_data_path: str = field(
default=None, metadata={"help": "Path to the training data in JSONL format."}
)
response_template: str = field(
@@ -61,7 +61,7 @@ class DataArguments:
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
# optim: str = field(default=DEFAULT_OPTIMIZER)
- model_max_length: int = field(
+ max_seq_length: int = field(
default=DEFAULT_CONTEXT_LENGTH,
metadata={
"help": "Maximum sequence length. Sequences will be right padded \
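For context, a minimal sketch of how the renamed fields surface as command-line flags, assuming the usual `transformers.HfArgumentParser` pattern (the parser wiring is not shown in this diff, and the default of 4096 stands in for `DEFAULT_CONTEXT_LENGTH`):

```python
# Hedged sketch: the HfArgumentParser wiring is assumed, not taken from this PR.
from dataclasses import dataclass, field

import transformers


@dataclass
class DataArguments:
    training_data_path: str = field(
        default=None, metadata={"help": "Path to the training data in JSONL format."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    max_seq_length: int = field(
        default=4096,  # illustrative stand-in for DEFAULT_CONTEXT_LENGTH
        metadata={"help": "Maximum sequence length."},
    )


if __name__ == "__main__":
    # e.g. python parse_args_sketch.py --output_dir /tmp/out \
    #        --training_data_path train.jsonl --max_seq_length 2048
    parser = transformers.HfArgumentParser((DataArguments, TrainingArguments))
    data_args, training_args = parser.parse_args_into_dataclasses()
    print(data_args.training_data_path, training_args.max_seq_length)
```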
19 changes: 8 additions & 11 deletions tuning/sft_trainer.py
@@ -159,16 +159,14 @@ def train(
response_template_ids = tokenizer.encode(
data_args.response_template, add_special_tokens=False
)[2:]
- # TODO: This is actually max_seq_length and not model_max_length. we should not override
- # model_max_length as in current main. We need to change name of this parameter we expose
- # to users.
- model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
- logger.info("Model max length %s, model_max_length")
- if train_args.model_max_length > tokenizer.model_max_length:

+ max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
+ logger.info("Max sequence length is %s", max_seq_length)
+ if train_args.max_seq_length > tokenizer.model_max_length:
logger.warning(
- "model_max_length %s exceeds tokenizer.model_max_length \
+ "max_seq_length %s exceeds tokenizer.model_max_length \
%s, using tokenizer.model_max_length %s",
- train_args.model_max_length,
+ train_args.max_seq_length,
tokenizer.model_max_length,
tokenizer.model_max_length,
)
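A minimal sketch of the clamping behaviour in the hunk above, pulled out as a standalone helper (the function name is illustrative, not part of the PR):

```python
import logging

logger = logging.getLogger(__name__)


def resolve_max_seq_length(requested: int, tokenizer) -> int:
    """Cap the user-supplied max_seq_length at the tokenizer's own limit."""
    if requested > tokenizer.model_max_length:
        logger.warning(
            "max_seq_length %s exceeds tokenizer.model_max_length %s, "
            "using tokenizer.model_max_length %s",
            requested,
            tokenizer.model_max_length,
            tokenizer.model_max_length,
        )
    return min(requested, tokenizer.model_max_length)
```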
@@ -197,8 +195,7 @@
)

# load the data by parsing JSON
- # TODO: update arg from data_path to training_data_path since we also have validation_data_path
- data_files = {"train": data_args.data_path}
+ data_files = {"train": data_args.training_data_path}
if data_args.validation_data_path:
data_files["validation"] = data_args.validation_data_path
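A small self-contained sketch, assuming the Hugging Face `datasets` library, of how a `data_files` dict built this way loads the JSONL splits (file paths are illustrative):

```python
from datasets import load_dataset

training_data_path = "twitter_complaints_train.jsonl"  # illustrative path
validation_data_path = None  # would come from data_args.validation_data_path

data_files = {"train": training_data_path}
if validation_data_path:
    data_files["validation"] = validation_data_path

# Each entry in data_files becomes a split in the resulting DatasetDict.
json_dataset = load_dataset("json", data_files=data_files)
print(json_dataset["train"])
```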

@@ -256,7 +253,7 @@ def train(
data_collator=data_collator,
dataset_text_field=data_args.dataset_text_field,
args=train_args,
- max_seq_length=model_max_length,
+ max_seq_length=max_seq_length,
callbacks=callbacks,
peft_config=peft_config,
)
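A hedged end-to-end sketch of the wiring this hunk touches: the clamped value is forwarded to trl's `SFTTrainer` as `max_seq_length`. Arguments beyond those visible in the diff are assumptions based on typical `SFTTrainer` usage at the time, and the model name and data path are illustrative:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_name = "Maykeye/TinyLLama-v0"  # illustrative; borrowed from the test constants
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_dataset = load_dataset(
    "json", data_files={"train": "twitter_complaints_train.jsonl"}  # illustrative path
)["train"]

# Clamp the requested length exactly as the hunk above does.
max_seq_length = min(4096, tokenizer.model_max_length)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="output",    # matches the tests' dataset_text_field
    max_seq_length=max_seq_length,  # the clamped value, not the raw CLI flag
    args=TrainingArguments(output_dir="/tmp/sft_out", num_train_epochs=1),
)
trainer.train()
```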