
Commit

Merge pull request #153 from dushyantbehl/fix-pr89-launch-training
fix: launch_training.py arguments with new tracker api
anhuong authored May 10, 2024
2 parents bc1c038 + e46e58c commit 4be666d
Showing 3 changed files with 44 additions and 8 deletions.
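In effect, the launcher now parses two additional tracker dataclasses out of the job config, wraps them in a TrackerConfigFactory, and hands them to sft_trainer.train() through the tracker_configs keyword argument instead of the old positional call. Below is a minimal sketch of the new wiring pieced together from the diff; the get_job_config() call and its no-argument signature are assumptions for illustration and are not part of this change.

# Sketch only: mirrors the shape of build/launch_training.py after this commit.
from tuning import sft_trainer
from tuning.config.tracker_configs import TrackerConfigFactory
from build.utils import process_launch_training_args, get_job_config

job_config = get_job_config()  # assumed helper: loads the JSON job config from the environment
(
    model_args,
    data_args,
    training_args,
    tune_config,
    merge_model,
    file_logger_config,  # new: file logging tracker config
    aim_config,          # new: Aim tracker config
) = process_launch_training_args(job_config)

# The two tracker configs are bundled and passed as a single keyword argument.
tracker_config_args = TrackerConfigFactory(
    file_logger_config=file_logger_config, aim_config=aim_config
)
sft_trainer.train(
    model_args=model_args,
    data_args=data_args,
    train_args=training_args,
    peft_config=tune_config,
    tracker_configs=tracker_config_args,
)

The same seven-element tuple is what the updated tests in tests/build/test_utils.py now unpack.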
build/launch_training.py: 17 changes (15 additions & 2 deletions)

@@ -28,6 +28,7 @@
 # Local
 from tuning import sft_trainer
 from tuning.utils.merge_model_utils import create_merged_model
+from tuning.config.tracker_configs import TrackerConfigFactory
 from build.utils import process_launch_training_args, get_job_config

@@ -62,12 +63,23 @@ def main():
         training_args,
         tune_config,
         merge_model,
+        file_logger_config,
+        aim_config,
     ) = process_launch_training_args(job_config)

     original_output_dir = training_args.output_dir
     with tempfile.TemporaryDirectory() as tempdir:
         training_args.output_dir = tempdir
-        sft_trainer.train(model_args, data_args, training_args, tune_config)
+        tracker_config_args = TrackerConfigFactory(
+            file_logger_config=file_logger_config, aim_config=aim_config
+        )
+        sft_trainer.train(
+            model_args=model_args,
+            data_args=data_args,
+            train_args=training_args,
+            peft_config=tune_config,
+            tracker_configs=tracker_config_args,
+        )

         if merge_model:
             export_path = os.getenv(

@@ -108,7 +120,8 @@ def main():

         # copy over any loss logs
         train_logs_filepath = os.path.join(
-            training_args.output_dir, sft_trainer.TRAINING_LOGS_FILENAME
+            training_args.output_dir,
+            tracker_config_args.file_logger_config.training_logs_filename,
         )
         if os.path.exists(train_logs_filepath):
             shutil.copy(train_logs_filepath, original_output_dir)

build/utils.py: 23 changes (20 additions & 3 deletions)

@@ -25,7 +25,7 @@
 from accelerate.commands.launch import launch_command_parser

 # Local
-from tuning.config import configs, peft_config
+from tuning.config import configs, peft_config, tracker_configs


 def txt_to_obj(txt):

@@ -67,6 +67,8 @@ def process_launch_training_args(job_config_dict):
         training_args: configs.TrainingArguments
         tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig
         merge_model: bool
+        file_logger_config: tracker_configs.FileLoggingTrackerConfig
+        aim_config: tracker_configs.AimConfig
     """
     parser = transformers.HfArgumentParser(
         dataclass_types=(

@@ -75,6 +77,8 @@
             configs.TrainingArguments,
             peft_config.LoraConfig,
             peft_config.PromptTuningConfig,
+            tracker_configs.FileLoggingTrackerConfig,
+            tracker_configs.AimConfig,
         )
     )

@@ -84,6 +88,8 @@
         training_args,
         lora_config,
         prompt_tuning_config,
+        file_logger_config,
+        aim_config,
     ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

     peft_method_parsed = job_config_dict.get("peft_method")

@@ -98,14 +104,25 @@

     logging.info(
         "Parameters used to launch training: \
-        model_args %s, data_args %s, training_args %s, tune_config %s",
+        model_args %s, data_args %s, training_args %s, tune_config %s \
+        file_logger_config %s aim_config %s",
         model_args,
         data_args,
         training_args,
         tune_config,
+        file_logger_config,
+        aim_config,
     )

-    return model_args, data_args, training_args, tune_config, merge_model
+    return (
+        model_args,
+        data_args,
+        training_args,
+        tune_config,
+        merge_model,
+        file_logger_config,
+        aim_config,
+    )


 def process_accelerate_launch_args(job_config_dict):

tests/build/test_utils.py: 12 changes (9 additions & 3 deletions)

@@ -46,6 +46,8 @@ def test_process_launch_training_args(job_config):
         training_args,
         tune_config,
         merge_model,
+        _,
+        _,
     ) = process_launch_training_args(job_config_copy)
     assert str(model_args.torch_dtype) == "torch.bfloat16"
     assert data_args.dataset_text_field == "output"

@@ -59,7 +61,7 @@ def test_process_launch_training_args_defaults(job_config):
     assert "torch_dtype" not in job_config_defaults
     assert job_config_defaults["use_flash_attn"] is False
     assert "save_strategy" not in job_config_defaults
-    model_args, _, training_args, _, _ = process_launch_training_args(
+    model_args, _, training_args, _, _, _, _ = process_launch_training_args(
         job_config_defaults
     )
     assert str(model_args.torch_dtype) == "torch.bfloat16"

@@ -70,13 +72,17 @@
 def test_process_launch_training_args_peft_method(job_config):
     job_config_pt = copy.deepcopy(job_config)
     job_config_pt["peft_method"] = "pt"
-    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
+    _, _, _, tune_config, merge_model, _, _ = process_launch_training_args(
+        job_config_pt
+    )
     assert isinstance(tune_config, PromptTuningConfig)
     assert merge_model is False

     job_config_lora = copy.deepcopy(job_config)
     job_config_lora["peft_method"] = "lora"
-    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
+    _, _, _, tune_config, merge_model, _, _ = process_launch_training_args(
+        job_config_lora
+    )
     assert isinstance(tune_config, LoraConfig)
     assert merge_model is True
