feat: send logs to a file specified in LOG_FILE environment variable #327

Draft · wants to merge 4 commits into base: wca
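The captured diff below covers the README, test data, and test updates; the implementation of the headline feature, writing logs to a path taken from the LOG_FILE environment variable, is not part of what the scrape preserved. A minimal sketch of how such a feature could be wired up, assuming the "sft_trainer" logger from transformers.utils.logging that the test changes reference; the handler setup and format string are illustrative, not the PR's actual code:

# Hypothetical sketch, not the PR's implementation: attach a file handler to the
# "sft_trainer" logger when LOG_FILE is set, so records are mirrored to that file.
import logging as std_logging
import os

from transformers.utils import logging

logger = logging.get_logger("sft_trainer")

log_file = os.environ.get("LOG_FILE")
if log_file:
    handler = std_logging.FileHandler(log_file)
    handler.setFormatter(
        std_logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")
    )
    logger.addHandler(handler)
    logger.setLevel(std_logging.INFO)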
442 changes: 441 additions & 1 deletion README.md

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion pyproject.toml
@@ -36,7 +36,6 @@ dependencies = [
"trl>=0.9.3,<1.0",
"peft>=0.8.0,<0.13",
"datasets>=2.15.0,<3.0",
"fire>=0.5.0,<1.0",
"simpleeval>=0.9.13,<1.0",
]

520 changes: 520 additions & 0 deletions tests/data/twitter_complaints_small.jsonl

Large diffs are not rendered by default.

35 changes: 21 additions & 14 deletions tests/test_sft_trainer.py
@@ -24,6 +24,7 @@
# Third Party
from datasets.exceptions import DatasetGenerationError
from transformers.trainer_callback import TrainerCallback
from transformers.utils import logging
import pytest
import torch
import transformers
@@ -48,15 +49,13 @@
from tuning.config import configs, peft_config
from tuning.config.tracker_configs import FileLoggingTrackerConfig

MODEL_ARGS = configs.ModelArguments(
model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
)
DATA_ARGS = configs.DataArguments(
MODEL_ARGS = configs.ModelDataArguments(
model_name_or_path=MODEL_NAME,
use_flash_attn=False,
torch_dtype="float32",
training_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
response_template="\n### Label:",
dataset_text_field="output",
)
TRAIN_ARGS = configs.TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
@@ -72,6 +71,8 @@
save_strategy="epoch",
output_dir="tmp",
)
DATA_ARGS = MODEL_ARGS
TRAIN_ARGS = MODEL_ARGS
PEFT_PT_ARGS = peft_config.PromptTuningConfig(
prompt_tuning_init="RANDOM",
num_virtual_tokens=8,
@@ -80,6 +81,8 @@

PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)

logger = logging.get_logger("sft_trainer")


def test_run_train_requires_output_dir():
"""Check fails when output dir not provided."""
@@ -118,8 +121,6 @@ def test_parse_arguments(job_config):
job_config_copy = copy.deepcopy(job_config)
(
model_args,
data_args,
training_args,
_,
tune_config,
_,
@@ -129,8 +130,8 @@
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
assert training_args.output_dir == "bloom-twitter"
assert model_args.dataset_text_field == "output"
assert model_args.output_dir == "bloom-twitter"
assert tune_config is None


@@ -140,19 +141,19 @@ def test_parse_arguments_defaults(job_config):
assert "torch_dtype" not in job_config_defaults
assert job_config_defaults["use_flash_attn"] is False
assert "save_strategy" not in job_config_defaults
model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
model_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
assert training_args.save_strategy.value == "epoch"
assert model_args.save_strategy.value == "epoch"


def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)
@@ -514,6 +515,9 @@ def test_run_causallm_ft_pretokenized(dataset_path):
# update the training data path to tokenized data
data_formatting_args.training_data_path = dataset_path

# call this to do some post processing over the data arguments
data_formatting_args.__post_init__()

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

@@ -861,9 +865,12 @@ def test_pretokenized_dataset(dataset_path):
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
data_args = copy.deepcopy(DATA_ARGS)
data_args.input_feature = "input"
data_args.output_feature = "output"
data_args.dataset_text_field = None
data_args.response_template = None
data_args.training_data_path = dataset_path
data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL
data_args.__post_init__()
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_PT_ARGS)
_validate_training(tempdir)

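The test hunks above mutate copied argument objects (for example data_args.training_data_path) and then invoke __post_init__() by hand. A minimal sketch of why that manual call is needed, assuming the config classes are dataclasses whose __post_init__ derives or validates state from the current field values; the class and fields below are illustrative placeholders, not the project's actual tuning.config definitions:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ExampleDataArguments:
    # Illustrative fields only, not the real configs.
    training_data_path: Optional[str] = None
    is_pretokenized: bool = field(init=False, default=False)

    def __post_init__(self):
        # Derived state is computed when the object is constructed...
        self.is_pretokenized = bool(
            self.training_data_path
            and self.training_data_path.endswith(".jsonl")
        )


args = ExampleDataArguments(training_data_path="data/raw.txt")
args.training_data_path = "data/pretokenized.jsonl"
# ...so after mutating a field in place, derived state must be refreshed explicitly.
args.__post_init__()
assert args.is_pretokenized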
4 changes: 2 additions & 2 deletions tests/trainercontroller/test_tuning_trainercontroller.py
@@ -41,7 +41,7 @@
class InputData:
"""Stores the operation handler instance and corresponding action"""

args: config.TrainingArguments
args: config.ModelDataArguments
states: List[TrainerState]
metrics: dict

@@ -58,7 +58,7 @@ def _setup_data() -> InputData:
# Test data to mimic the fields of trainer loop log-lines
# trainer arguments and the initial state
return InputData(
args=config.TrainingArguments(
args=config.ModelDataArguments(
output_dir="",
logging_strategy=IntervalStrategy.STEPS,
logging_steps=1,