From a38b35fe53cd95acdb834affc059e78e8c208af3 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Fri, 29 Mar 2024 22:14:16 -0700 Subject: [PATCH 01/10] Allow for default params to be set Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- build/launch_training.py | 58 +++++++++--------------------------- tuning/utils/config_utils.py | 46 +++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 45 deletions(-) diff --git a/build/launch_training.py b/build/launch_training.py index 6fd142cc9..caba03e5d 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -28,12 +28,11 @@ # First Party import logging + +# Local from tuning import sft_trainer -from tuning.config import configs, peft_config from tuning.utils.merge_model_utils import create_merged_model - -# Third Party -import transformers +from tuning.utils.config_utils import post_process_job_config def txt_to_obj(txt): @@ -67,60 +66,31 @@ def main(): logging.basicConfig(level=LOGLEVEL) logging.info("Attempting to launch training script") - parser = transformers.HfArgumentParser( - dataclass_types=( - configs.ModelArguments, - configs.DataArguments, - configs.TrainingArguments, - peft_config.LoraConfig, - peft_config.PromptTuningConfig, - ) - ) - peft_method_parsed = "pt" + json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH") json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR") # accepts either path to JSON file or encoded string config if json_path: - ( - model_args, - data_args, - training_args, - lora_config, - prompt_tuning_config, - ) = parser.parse_json_file(json_path, allow_extra_keys=True) - - contents = "" with open(json_path, "r", encoding="utf-8") as f: - contents = json.load(f) - peft_method_parsed = contents.get("peft_method") - logging.debug("Input params parsed: %s", contents) + job_config_dict = json.load(f) elif json_env_var: job_config_dict = txt_to_obj(json_env_var) - logging.debug("Input params parsed: %s", job_config_dict) - - ( - model_args, - data_args, - training_args, - lora_config, - prompt_tuning_config, - ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) - - peft_method_parsed = job_config_dict.get("peft_method") else: raise ValueError( "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \ or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'." 
) - tune_config = None - merge_model = False - if peft_method_parsed == "lora": - tune_config = lora_config - merge_model = True - elif peft_method_parsed == "pt": - tune_config = prompt_tuning_config + logging.debug("Input params parsed: %s", job_config_dict) + + ( + model_args, + data_args, + training_args, + tune_config, + merge_model, + ) = post_process_job_config(job_config_dict) logging.debug( "Parameters used to launch training: \ diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index fc7b7b46f..8f43ffbba 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -17,9 +17,16 @@ # Third Party from peft import LoraConfig, PromptTuningConfig +import transformers # Local -from tuning.config import peft_config +from tuning.config import configs, peft_config + +JOB_CONFIG_DEFAULTS_MAP = { + "torch_dtype": "bfloat16", + "save_strategy": "epoch", + "use_flash_attn": True, +} def update_config(config, **kwargs): @@ -87,3 +94,40 @@ def get_hf_peft_config(task_type, tuning_config): hf_peft_config = None # full parameter tuning return hf_peft_config + + +def post_process_job_config(job_config_dict): + parser = transformers.HfArgumentParser( + dataclass_types=( + configs.ModelArguments, + configs.DataArguments, + configs.TrainingArguments, + peft_config.LoraConfig, + peft_config.PromptTuningConfig, + ) + ) + peft_method_parsed = "pt" + + for key, val in JOB_CONFIG_DEFAULTS_MAP.items(): + if key not in job_config_dict: + job_config_dict[key] = val + + ( + model_args, + data_args, + training_args, + lora_config, + prompt_tuning_config, + ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) + + peft_method_parsed = job_config_dict.get("peft_method") + + tune_config = None + merge_model = False + if peft_method_parsed == "lora": + tune_config = lora_config + merge_model = True + elif peft_method_parsed == "pt": + tune_config = prompt_tuning_config + + return model_args, data_args, training_args, tune_config, merge_model From 9f2f8b4e618fb272f317260d49d9eb1a2c15b4bf Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Fri, 29 Mar 2024 22:30:20 -0700 Subject: [PATCH 02/10] Add tests Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- tests/dummy_job_config.json | 34 ++++++++++++++ tests/utils/test_config_utils.py | 78 ++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 tests/dummy_job_config.json create mode 100644 tests/utils/test_config_utils.py diff --git a/tests/dummy_job_config.json b/tests/dummy_job_config.json new file mode 100644 index 000000000..2bc1ea9f8 --- /dev/null +++ b/tests/dummy_job_config.json @@ -0,0 +1,34 @@ +{ + "accelerate_launch_args": { + "use_fsdp": true, + "env": ["env1", "env2"], + "dynamo_use_dynamic": true, + "num_machines": 1, + "num_processes":2, + "main_process_port": 1234, + "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", + "fsdp_sharding_strategy": 1, + "fsdp_state_dict_type": "FULL_STATE_DICT", + "fsdp_cpu_ram_efficient_loading": true, + "fsdp_sync_module_states": true, + "config_file": "fixtures/accelerate_fsdp_defaults.yaml" + }, + "multi_gpu": true, + "model_name_or_path": "bigscience/bloom-560m", + "training_data_path": "data/twitter_complaints_small.json", + "output_dir": "bloom-twitter", + "num_train_epochs": 5.0, + "per_device_train_batch_size": 4, + "per_device_eval_batch_size": 4, + "gradient_accumulation_steps": 4, + "learning_rate": 0.03, + "weight_decay": 0.000001, + 
"lr_scheduler_type": "cosine", + "logging_steps": 1.0, + "packing": false, + "include_tokens_per_second": true, + "response_template": "### Label:", + "dataset_text_field": "output", + "use_flash_attn": false, + "tokenizer_name_or_path": "bigscience/bloom-560m" + } \ No newline at end of file diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py new file mode 100644 index 000000000..8b38b4f97 --- /dev/null +++ b/tests/utils/test_config_utils.py @@ -0,0 +1,78 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import copy +import json +import os + +# Third Party +import pytest + +# Local +from tuning.config.peft_config import LoraConfig, PromptTuningConfig +from tuning.utils.config_utils import post_process_job_config + +HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( + os.path.dirname(__file__), "..", "dummy_job_config.json" +) + + +# Note: job_config dict gets modified during post_process_job_config +@pytest.fixture(scope="session") +def job_config(): + with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: + dummy_job_config_dict = json.load(f) + return dummy_job_config_dict + + +def test_post_process_job_config(job_config): + job_config_copy = copy.deepcopy(job_config) + ( + model_args, + data_args, + training_args, + tune_config, + merge_model, + ) = post_process_job_config(job_config_copy) + assert model_args.torch_dtype == "bfloat16" + assert data_args.dataset_text_field == "output" + assert training_args.output_dir == "bloom-twitter" + assert tune_config == None + assert merge_model == False + + +def test_post_process_job_config_defaults(job_config): + job_config_defaults = copy.deepcopy(job_config) + assert "torch_dtype" not in job_config_defaults + assert job_config_defaults["use_flash_attn"] == False + assert "save_strategy" not in job_config_defaults + model_args, _, training_args, _, _ = post_process_job_config(job_config_defaults) + assert model_args.torch_dtype == "bfloat16" + assert model_args.use_flash_attn == False + assert training_args.save_strategy.value == "epoch" + + +def test_post_process_job_config_peft_method(job_config): + job_config_pt = copy.deepcopy(job_config) + job_config_pt["peft_method"] = "pt" + _, _, _, tune_config, merge_model = post_process_job_config(job_config_pt) + assert type(tune_config) == PromptTuningConfig + assert merge_model == False + + job_config_lora = copy.deepcopy(job_config) + job_config_lora["peft_method"] = "lora" + _, _, _, tune_config, merge_model = post_process_job_config(job_config_lora) + assert type(tune_config) == LoraConfig + assert merge_model == True From 568f47febcae963b342b6cea7ec623a0c012c3f1 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Sat, 30 Mar 2024 00:15:27 -0700 Subject: [PATCH 03/10] Simplifying default params logic Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- tests/utils/test_config_utils.py | 4 ++-- tuning/config/configs.py | 9 
+++++++++ tuning/utils/config_utils.py | 10 ---------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py index 8b38b4f97..bce6bdf43 100644 --- a/tests/utils/test_config_utils.py +++ b/tests/utils/test_config_utils.py @@ -46,7 +46,7 @@ def test_post_process_job_config(job_config): tune_config, merge_model, ) = post_process_job_config(job_config_copy) - assert model_args.torch_dtype == "bfloat16" + assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" assert training_args.output_dir == "bloom-twitter" assert tune_config == None @@ -59,7 +59,7 @@ def test_post_process_job_config_defaults(job_config): assert job_config_defaults["use_flash_attn"] == False assert "save_strategy" not in job_config_defaults model_args, _, training_args, _, _ = post_process_job_config(job_config_defaults) - assert model_args.torch_dtype == "bfloat16" + assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn == False assert training_args.save_strategy.value == "epoch" diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 525e0c6ef..9bcdfe0c3 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -72,3 +72,12 @@ class TrainingArguments(transformers.TrainingArguments): default=False, metadata={"help": "Packing to be enabled in SFT Trainer, default is False"}, ) + save_strategy: str = field( + default="epoch", + metadata={ + "help": "The checkpoint save strategy to adopt during training. \ + Possible values are 'no'(no save is done during training), \ + 'epoch' (save is done at the end of each epoch), \ + 'steps' (save is done every `save_steps`)" + }, + ) diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 8f43ffbba..ffac3c1ca 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -22,12 +22,6 @@ # Local from tuning.config import configs, peft_config -JOB_CONFIG_DEFAULTS_MAP = { - "torch_dtype": "bfloat16", - "save_strategy": "epoch", - "use_flash_attn": True, -} - def update_config(config, **kwargs): if isinstance(config, (tuple, list)): @@ -108,10 +102,6 @@ def post_process_job_config(job_config_dict): ) peft_method_parsed = "pt" - for key, val in JOB_CONFIG_DEFAULTS_MAP.items(): - if key not in job_config_dict: - job_config_dict[key] = val - ( model_args, data_args, From ee5ff7f7f19b35397cafd868dea5e88fe25d27f6 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Sun, 31 Mar 2024 16:29:15 -0700 Subject: [PATCH 04/10] Setting use_flash_attn default Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- tuning/config/configs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 9bcdfe0c3..f64c40e07 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -34,8 +34,8 @@ class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") use_flash_attn: bool = field( - default=True, - metadata={"help": "Use Flash attention v2 from transformers, default is True"}, + default=False, + metadata={"help": "Use Flash attention v2 from transformers, default is False"}, ) torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16 From 51cd4dfde7cc59cce929495e26a9d83bcb680636 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: 
Sun, 31 Mar 2024 16:50:24 -0700 Subject: [PATCH 05/10] Formatting Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- build/launch_training.py | 9 --------- tests/utils/__init__.py | 13 +++++++++++++ tuning/utils/config_utils.py | 20 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 tests/utils/__init__.py diff --git a/build/launch_training.py b/build/launch_training.py index caba03e5d..f602ff190 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -92,15 +92,6 @@ def main(): merge_model, ) = post_process_job_config(job_config_dict) - logging.debug( - "Parameters used to launch training: \ - model_args %s, data_args %s, training_args %s, tune_config %s", - model_args, - data_args, - training_args, - tune_config, - ) - original_output_dir = training_args.output_dir with tempfile.TemporaryDirectory() as tempdir: training_args.output_dir = tempdir diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index ffac3c1ca..95f44b15f 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -14,6 +14,7 @@ # Standard from dataclasses import asdict +import logging # Third Party from peft import LoraConfig, PromptTuningConfig @@ -91,6 +92,16 @@ def get_hf_peft_config(task_type, tuning_config): def post_process_job_config(job_config_dict): + """Return parsed config for tuning to pass to SFT Trainer + Args: + job_config_dict: dict + Return: + model_args: configs.ModelArguments + data_args: configs.DataArguments + training_args: configs.TrainingArguments + tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig + merge_model: bool + """ parser = transformers.HfArgumentParser( dataclass_types=( configs.ModelArguments, @@ -120,4 +131,13 @@ def post_process_job_config(job_config_dict): elif peft_method_parsed == "pt": tune_config = prompt_tuning_config + logging.debug( + "Parameters used to launch training: \ + model_args %s, data_args %s, training_args %s, tune_config %s", + model_args, + data_args, + training_args, + tune_config, + ) + return model_args, data_args, training_args, tune_config, merge_model From bb8114a8bb2fdbaba34e5332219adc0832631207 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 2 Apr 2024 00:41:20 -0700 Subject: [PATCH 06/10] Address review comments Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- build/launch_training.py | 4 +- build/utils.py | 74 ++++++++++++++++++++++++++++++++++++ tuning/config/configs.py | 9 +++++ tuning/utils/config_utils.py | 56 +-------------------------- 4 files changed, 86 insertions(+), 57 deletions(-) create mode 100644 build/utils.py diff --git 
a/build/launch_training.py b/build/launch_training.py index f602ff190..07dfaf0c7 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -32,7 +32,7 @@ # Local from tuning import sft_trainer from tuning.utils.merge_model_utils import create_merged_model -from tuning.utils.config_utils import post_process_job_config +from build.utils import process_launch_training_args def txt_to_obj(txt): @@ -90,7 +90,7 @@ def main(): training_args, tune_config, merge_model, - ) = post_process_job_config(job_config_dict) + ) = process_launch_training_args(job_config_dict) original_output_dir = training_args.output_dir with tempfile.TemporaryDirectory() as tempdir: diff --git a/build/utils.py b/build/utils.py new file mode 100644 index 000000000..0d8c2a155 --- /dev/null +++ b/build/utils.py @@ -0,0 +1,74 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging + +# Third Party +import transformers + +# Local +from tuning.config import configs, peft_config + + +def process_launch_training_args(job_config_dict): + """Return parsed config for tuning to pass to SFT Trainer + Args: + job_config_dict: dict + Return: + model_args: configs.ModelArguments + data_args: configs.DataArguments + training_args: configs.TrainingArguments + tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig + merge_model: bool + """ + parser = transformers.HfArgumentParser( + dataclass_types=( + configs.ModelArguments, + configs.DataArguments, + configs.TrainingArguments, + peft_config.LoraConfig, + peft_config.PromptTuningConfig, + ) + ) + peft_method_parsed = "pt" + + ( + model_args, + data_args, + training_args, + lora_config, + prompt_tuning_config, + ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) + + peft_method_parsed = job_config_dict.get("peft_method") + + tune_config = None + merge_model = False + if peft_method_parsed == "lora": + tune_config = lora_config + merge_model = True + elif peft_method_parsed == "pt": + tune_config = prompt_tuning_config + + logging.debug( + "Parameters used to launch training: \ + model_args %s, data_args %s, training_args %s, tune_config %s", + model_args, + data_args, + training_args, + tune_config, + ) + + return model_args, data_args, training_args, tune_config, merge_model diff --git a/tuning/config/configs.py b/tuning/config/configs.py index f64c40e07..73ca0b0ef 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -81,3 +81,12 @@ class TrainingArguments(transformers.TrainingArguments): 'steps' (save is done every `save_steps`)" }, ) + logging_strategy: str = field( + default="epoch", + metadata={ + "help": "The logging strategy to adopt during training. 
\ + Possible values are 'no'(no logging is done during training), \ + 'epoch' (logging is done at the end of each epoch), \ + 'steps' (logging is done every `logging_steps`)" + }, + ) diff --git a/tuning/utils/config_utils.py b/tuning/utils/config_utils.py index 95f44b15f..fc7b7b46f 100644 --- a/tuning/utils/config_utils.py +++ b/tuning/utils/config_utils.py @@ -14,14 +14,12 @@ # Standard from dataclasses import asdict -import logging # Third Party from peft import LoraConfig, PromptTuningConfig -import transformers # Local -from tuning.config import configs, peft_config +from tuning.config import peft_config def update_config(config, **kwargs): @@ -89,55 +87,3 @@ def get_hf_peft_config(task_type, tuning_config): hf_peft_config = None # full parameter tuning return hf_peft_config - - -def post_process_job_config(job_config_dict): - """Return parsed config for tuning to pass to SFT Trainer - Args: - job_config_dict: dict - Return: - model_args: configs.ModelArguments - data_args: configs.DataArguments - training_args: configs.TrainingArguments - tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig - merge_model: bool - """ - parser = transformers.HfArgumentParser( - dataclass_types=( - configs.ModelArguments, - configs.DataArguments, - configs.TrainingArguments, - peft_config.LoraConfig, - peft_config.PromptTuningConfig, - ) - ) - peft_method_parsed = "pt" - - ( - model_args, - data_args, - training_args, - lora_config, - prompt_tuning_config, - ) = parser.parse_dict(job_config_dict, allow_extra_keys=True) - - peft_method_parsed = job_config_dict.get("peft_method") - - tune_config = None - merge_model = False - if peft_method_parsed == "lora": - tune_config = lora_config - merge_model = True - elif peft_method_parsed == "pt": - tune_config = prompt_tuning_config - - logging.debug( - "Parameters used to launch training: \ - model_args %s, data_args %s, training_args %s, tune_config %s", - model_args, - data_args, - training_args, - tune_config, - ) - - return model_args, data_args, training_args, tune_config, merge_model From 908a3b5513676ef0fb4f89a4524f77f7a859aaac Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 2 Apr 2024 01:00:35 -0700 Subject: [PATCH 07/10] Moving tests Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- pytest.ini | 4 ++++ tests/build/__init__.py | 13 ++++++++++++ .../test_utils.py} | 20 ++++++++++--------- 3 files changed, 28 insertions(+), 9 deletions(-) create mode 100644 pytest.ini create mode 100644 tests/build/__init__.py rename tests/{utils/test_config_utils.py => build/test_utils.py} (77%) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..a27fc867e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +# Register tests from `build` dir, removing `build` from norecursedirs default list, +# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs +norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch} \ No newline at end of file diff --git a/tests/build/__init__.py b/tests/build/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/build/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/utils/test_config_utils.py b/tests/build/test_utils.py similarity index 77% rename from tests/utils/test_config_utils.py rename to tests/build/test_utils.py index bce6bdf43..51df3da7d 100644 --- a/tests/utils/test_config_utils.py +++ b/tests/build/test_utils.py @@ -22,14 +22,14 @@ # Local from tuning.config.peft_config import LoraConfig, PromptTuningConfig -from tuning.utils.config_utils import post_process_job_config +from build.utils import process_launch_training_args HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( os.path.dirname(__file__), "..", "dummy_job_config.json" ) -# Note: job_config dict gets modified during post_process_job_config +# Note: job_config dict gets modified during process_launch_training_args @pytest.fixture(scope="session") def job_config(): with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: @@ -37,7 +37,7 @@ def job_config(): return dummy_job_config_dict -def test_post_process_job_config(job_config): +def test_process_launch_training_args(job_config): job_config_copy = copy.deepcopy(job_config) ( model_args, @@ -45,7 +45,7 @@ def test_post_process_job_config(job_config): training_args, tune_config, merge_model, - ) = post_process_job_config(job_config_copy) + ) = process_launch_training_args(job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" assert training_args.output_dir == "bloom-twitter" @@ -53,26 +53,28 @@ def test_post_process_job_config(job_config): assert merge_model == False -def test_post_process_job_config_defaults(job_config): +def test_process_launch_training_args_defaults(job_config): job_config_defaults = copy.deepcopy(job_config) assert "torch_dtype" not in job_config_defaults assert job_config_defaults["use_flash_attn"] == False assert "save_strategy" not in job_config_defaults - model_args, _, training_args, _, _ = post_process_job_config(job_config_defaults) + model_args, _, training_args, _, _ = process_launch_training_args( + job_config_defaults + ) assert str(model_args.torch_dtype) == "torch.bfloat16" assert model_args.use_flash_attn == False assert training_args.save_strategy.value == "epoch" -def test_post_process_job_config_peft_method(job_config): +def test_process_launch_training_args_peft_method(job_config): job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, tune_config, merge_model = post_process_job_config(job_config_pt) + _, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt) assert type(tune_config) == PromptTuningConfig assert merge_model == False job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, tune_config, merge_model = post_process_job_config(job_config_lora) + _, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora) assert type(tune_config) == LoraConfig assert merge_model == True From 7bb43b7b56f595071fd5f52c2ca350a32c0a6df6 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 2 Apr 2024 12:51:38 -0700 Subject: [PATCH 08/10] 
Address review comments Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- tests/{ => build}/dummy_job_config.json | 0 tests/build/test_utils.py | 2 +- tuning/config/configs.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename tests/{ => build}/dummy_job_config.json (100%) diff --git a/tests/dummy_job_config.json b/tests/build/dummy_job_config.json similarity index 100% rename from tests/dummy_job_config.json rename to tests/build/dummy_job_config.json diff --git a/tests/build/test_utils.py b/tests/build/test_utils.py index 51df3da7d..8f1975da0 100644 --- a/tests/build/test_utils.py +++ b/tests/build/test_utils.py @@ -25,7 +25,7 @@ from build.utils import process_launch_training_args HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( - os.path.dirname(__file__), "..", "dummy_job_config.json" + os.path.dirname(__file__), "dummy_job_config.json" ) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 73ca0b0ef..7519f000c 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -34,8 +34,8 @@ class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") use_flash_attn: bool = field( - default=False, - metadata={"help": "Use Flash attention v2 from transformers, default is False"}, + default=True, + metadata={"help": "Use Flash attention v2 from transformers, default is True"}, ) torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16 From 4cf6e5ad886a9adcc29a7265c5529312c2a9ce79 Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 2 Apr 2024 13:28:29 -0700 Subject: [PATCH 09/10] Address review comment Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- build/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/build/utils.py b/build/utils.py index 0d8c2a155..983977388 100644 --- a/build/utils.py +++ b/build/utils.py @@ -42,7 +42,6 @@ def process_launch_training_args(job_config_dict): peft_config.PromptTuningConfig, ) ) - peft_method_parsed = "pt" ( model_args, From 9a7936873542c5505c9cfe3a6822c63d16a65dad Mon Sep 17 00:00:00 2001 From: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> Date: Tue, 2 Apr 2024 13:46:24 -0700 Subject: [PATCH 10/10] Fix merge conflicts Signed-off-by: Thara Palanivel <130496890+tharapalanivel@users.noreply.github.com> --- build/launch_training.py | 2 +- build/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build/launch_training.py b/build/launch_training.py index 07dfaf0c7..5879a11d4 100644 --- a/build/launch_training.py +++ b/build/launch_training.py @@ -65,7 +65,7 @@ def main(): LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper() logging.basicConfig(level=LOGLEVEL) - logging.info("Attempting to launch training script") + logging.info("Initializing launch training script") json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH") json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR") diff --git a/build/utils.py b/build/utils.py index 983977388..090cd922a 100644 --- a/build/utils.py +++ b/build/utils.py @@ -61,7 +61,7 @@ def process_launch_training_args(job_config_dict): elif peft_method_parsed == "pt": tune_config = prompt_tuning_config - logging.debug( + logging.info( "Parameters used to launch training: \ model_args %s, data_args %s, training_args %s, tune_config %s", model_args,
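
Below is a minimal usage sketch of the process_launch_training_args helper that this series introduces in build/utils.py, mirroring how build/launch_training.py and tests/build/test_utils.py call it. The config file path is a hypothetical placeholder (any dict shaped like tests/build/dummy_job_config.json would do), and the sketch assumes the build package is importable, as enabled by the pytest.ini change in patch 07.

# Sketch: parse a job config dict the way build/launch_training.py does.
import json

from build.utils import process_launch_training_args

# Hypothetical path; the real launcher reads SFT_TRAINER_CONFIG_JSON_PATH or
# SFT_TRAINER_CONFIG_JSON_ENV_VAR instead of a hard-coded file.
with open("job_config.json", "r", encoding="utf-8") as f:
    job_config_dict = json.load(f)

# peft_method selects the tuning config: "lora" yields a LoraConfig with
# merge_model=True, "pt" yields a PromptTuningConfig, and any other value
# falls back to full fine tuning (tune_config is None, merge_model is False).
(
    model_args,
    data_args,
    training_args,
    tune_config,
    merge_model,
) = process_launch_training_args(job_config_dict)

The returned dataclasses are then handed to sft_trainer, with merge_model deciding whether create_merged_model is called on the LoRA checkpoints after training.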