diff --git a/build/launch_training.py b/build/launch_training.py
index 69d3605e2..5879a11d4 100644
--- a/build/launch_training.py
+++ b/build/launch_training.py
@@ -28,12 +28,11 @@
 
 # First Party
 import logging
+
+# Local
 from tuning import sft_trainer
-from tuning.config import configs, peft_config
 from tuning.utils.merge_model_utils import create_merged_model
-
-# Third Party
-import transformers
+from build.utils import process_launch_training_args
 
 
 def txt_to_obj(txt):
@@ -68,69 +67,30 @@ def main():
     logging.info("Initializing launch training script")
 
-    parser = transformers.HfArgumentParser(
-        dataclass_types=(
-            configs.ModelArguments,
-            configs.DataArguments,
-            configs.TrainingArguments,
-            peft_config.LoraConfig,
-            peft_config.PromptTuningConfig,
-        )
-    )
-    peft_method_parsed = "pt"
     json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
     json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
     # accepts either path to JSON file or encoded string config
     if json_path:
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_json_file(json_path, allow_extra_keys=True)
-
-        contents = ""
         with open(json_path, "r", encoding="utf-8") as f:
-            contents = json.load(f)
-            peft_method_parsed = contents.get("peft_method")
-        logging.debug("Input params parsed: %s", contents)
+            job_config_dict = json.load(f)
     elif json_env_var:
         job_config_dict = txt_to_obj(json_env_var)
-        logging.debug("Input params parsed: %s", job_config_dict)
-
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)
-
-        peft_method_parsed = job_config_dict.get("peft_method")
     else:
         raise ValueError(
            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )
 
-    tune_config = None
-    merge_model = False
-    if peft_method_parsed == "lora":
-        tune_config = lora_config
-        merge_model = True
-    elif peft_method_parsed == "pt":
-        tune_config = prompt_tuning_config
-
-    logging.info(
-        "Parameters used to launch training: \
-        model_args %s, data_args %s, training_args %s, tune_config %s",
+    logging.debug("Input params parsed: %s", job_config_dict)
+
+    (
         model_args,
         data_args,
         training_args,
         tune_config,
-    )
+        merge_model,
+    ) = process_launch_training_args(job_config_dict)
 
     original_output_dir = training_args.output_dir
     with tempfile.TemporaryDirectory() as tempdir:
diff --git a/build/utils.py b/build/utils.py
new file mode 100644
index 000000000..090cd922a
--- /dev/null
+++ b/build/utils.py
@@ -0,0 +1,73 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Standard +import logging + +# Third Party +import transformers + +# Local +from tuning.config import configs, peft_config + + +def process_launch_training_args(job_config_dict): + """Return parsed config for tuning to pass to SFT Trainer + Args: + job_config_dict: dict + Return: + model_args: configs.ModelArguments + data_args: configs.DataArguments + training_args: configs.TrainingArguments + tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig + merge_model: bool + """ + parser = transformers.HfArgumentParser( + dataclass_types=( + configs.ModelArguments, + configs.DataArguments, + configs.TrainingArguments, + peft_config.LoraConfig, + peft_config.PromptTuningConfig, + ) + ) + + ( + model_args, + data_args, + training_args, + lora_config, + prompt_tuning_config, + ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) + + peft_method_parsed = job_config_dict.get("peft_method") + + tune_config = None + merge_model = False + if peft_method_parsed == "lora": + tune_config = lora_config + merge_model = True + elif peft_method_parsed == "pt": + tune_config = prompt_tuning_config + + logging.info( + "Parameters used to launch training: \ + model_args %s, data_args %s, training_args %s, tune_config %s", + model_args, + data_args, + training_args, + tune_config, + ) + + return model_args, data_args, training_args, tune_config, merge_model diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..a27fc867e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +# Register tests from `build` dir, removing `build` from norecursedirs default list, +# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs +norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch} \ No newline at end of file diff --git a/tests/build/__init__.py b/tests/build/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/build/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/build/dummy_job_config.json b/tests/build/dummy_job_config.json
new file mode 100644
index 000000000..2bc1ea9f8
--- /dev/null
+++ b/tests/build/dummy_job_config.json
@@ -0,0 +1,34 @@
+{
+    "accelerate_launch_args": {
+        "use_fsdp": true,
+        "env": ["env1", "env2"],
+        "dynamo_use_dynamic": true,
+        "num_machines": 1,
+        "num_processes":2,
+        "main_process_port": 1234,
+        "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP",
+        "fsdp_sharding_strategy": 1,
+        "fsdp_state_dict_type": "FULL_STATE_DICT",
+        "fsdp_cpu_ram_efficient_loading": true,
+        "fsdp_sync_module_states": true,
+        "config_file": "fixtures/accelerate_fsdp_defaults.yaml"
+    },
+    "multi_gpu": true,
+    "model_name_or_path": "bigscience/bloom-560m",
+    "training_data_path": "data/twitter_complaints_small.json",
+    "output_dir": "bloom-twitter",
+    "num_train_epochs": 5.0,
+    "per_device_train_batch_size": 4,
+    "per_device_eval_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "learning_rate": 0.03,
+    "weight_decay": 0.000001,
+    "lr_scheduler_type": "cosine",
+    "logging_steps": 1.0,
+    "packing": false,
+    "include_tokens_per_second": true,
+    "response_template": "### Label:",
+    "dataset_text_field": "output",
+    "use_flash_attn": false,
+    "tokenizer_name_or_path": "bigscience/bloom-560m"
+}
\ No newline at end of file
diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py
new file mode 100644
index 000000000..8f1975da0
--- /dev/null
+++ b/tests/build/test_utils.py
@@ -0,0 +1,80 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import copy
+import json
+import os
+
+# Third Party
+import pytest
+
+# Local
+from tuning.config.peft_config import LoraConfig, PromptTuningConfig
+from build.utils import process_launch_training_args
+
+HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
+    os.path.dirname(__file__), "dummy_job_config.json"
+)
+
+
+# Note: job_config dict gets modified during process_launch_training_args
+@pytest.fixture(scope="session")
+def job_config():
+    with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
+        dummy_job_config_dict = json.load(f)
+    return dummy_job_config_dict
+
+
+def test_process_launch_training_args(job_config):
+    job_config_copy = copy.deepcopy(job_config)
+    (
+        model_args,
+        data_args,
+        training_args,
+        tune_config,
+        merge_model,
+    ) = process_launch_training_args(job_config_copy)
+    assert str(model_args.torch_dtype) == "torch.bfloat16"
+    assert data_args.dataset_text_field == "output"
+    assert training_args.output_dir == "bloom-twitter"
+    assert tune_config == None
+    assert merge_model == False
+
+
+def test_process_launch_training_args_defaults(job_config):
+    job_config_defaults = copy.deepcopy(job_config)
+    assert "torch_dtype" not in job_config_defaults
+    assert job_config_defaults["use_flash_attn"] == False
+    assert "save_strategy" not in job_config_defaults
+    model_args, _, training_args, _, _ = process_launch_training_args(
+        job_config_defaults
+    )
+    assert str(model_args.torch_dtype) == "torch.bfloat16"
+    assert model_args.use_flash_attn == False
+    assert training_args.save_strategy.value == "epoch"
+
+
+def test_process_launch_training_args_peft_method(job_config):
+    job_config_pt = copy.deepcopy(job_config)
+    job_config_pt["peft_method"] = "pt"
+    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
+    assert type(tune_config) == PromptTuningConfig
+    assert merge_model == False
+
+    job_config_lora = copy.deepcopy(job_config)
+    job_config_lora["peft_method"] = "lora"
+    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
+    assert type(tune_config) == LoraConfig
+    assert merge_model == True
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 000000000..38a9531ef
--- /dev/null
+++ b/tests/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 525e0c6ef..7519f000c 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -72,3 +72,21 @@ class TrainingArguments(transformers.TrainingArguments):
         default=False,
         metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
     )
+    save_strategy: str = field(
+        default="epoch",
+        metadata={
+            "help": "The checkpoint save strategy to adopt during training. \
+            Possible values are 'no' (no save is done during training), \
+            'epoch' (save is done at the end of each epoch), \
+            'steps' (save is done every `save_steps`)"
+        },
+    )
+    logging_strategy: str = field(
+        default="epoch",
+        metadata={
+            "help": "The logging strategy to adopt during training. \
+            Possible values are 'no' (no logging is done during training), \
+            'epoch' (logging is done at the end of each epoch), \
+            'steps' (logging is done every `logging_steps`)"
+        },
+    )
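
A short sketch (again, not part of the diff) of what the overridden defaults mean in practice, mirroring the assertion in test_process_launch_training_args_defaults: when a job config omits save_strategy or logging_strategy, the parsed TrainingArguments now checkpoints and logs per epoch instead of using the upstream transformers default of "steps".

# Sketch assuming only the imports already used in build/utils.py.
from transformers import HfArgumentParser

from tuning.config import configs

parser = HfArgumentParser(dataclass_types=(configs.TrainingArguments,))
(training_args,) = parser.parse_dict({"output_dir": "bloom-twitter"})

print(training_args.save_strategy.value)     # "epoch" (upstream default is "steps")
print(training_args.logging_strategy.value)  # "epoch" (upstream default is "steps")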