Setting default values in training job config (#104)
* Allow for default params to be set

Signed-off-by: Thara Palanivel <[email protected]>

* Add tests

Signed-off-by: Thara Palanivel <[email protected]>

* Simplifying default params logic

Signed-off-by: Thara Palanivel <[email protected]>

* Setting use_flash_attn default

Signed-off-by: Thara Palanivel <[email protected]>

* Formatting

Signed-off-by: Thara Palanivel <[email protected]>

* Address review comments

Signed-off-by: Thara Palanivel <[email protected]>

* Moving tests

Signed-off-by: Thara Palanivel <[email protected]>

* Address review comments

Signed-off-by: Thara Palanivel <[email protected]>

* Address review comment

Signed-off-by: Thara Palanivel <[email protected]>

* Fix merge conflicts

Signed-off-by: Thara Palanivel <[email protected]>

---------

Signed-off-by: Thara Palanivel <[email protected]>
tharapalanivel authored Apr 2, 2024
1 parent 2df20ba commit 3785363
Showing 8 changed files with 244 additions and 49 deletions.
58 changes: 9 additions & 49 deletions build/launch_training.py
@@ -28,12 +28,11 @@

# First Party
import logging

# Local
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.utils.merge_model_utils import create_merged_model

# Third Party
import transformers
from build.utils import process_launch_training_args


def txt_to_obj(txt):
@@ -68,69 +67,30 @@ def main():

    logging.info("Initializing launch training script")

-    parser = transformers.HfArgumentParser(
-        dataclass_types=(
-            configs.ModelArguments,
-            configs.DataArguments,
-            configs.TrainingArguments,
-            peft_config.LoraConfig,
-            peft_config.PromptTuningConfig,
-        )
-    )
-    peft_method_parsed = "pt"
    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

    # accepts either path to JSON file or encoded string config
    if json_path:
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_json_file(json_path, allow_extra_keys=True)
-
-        contents = ""
        with open(json_path, "r", encoding="utf-8") as f:
-            contents = json.load(f)
-        peft_method_parsed = contents.get("peft_method")
-        logging.debug("Input params parsed: %s", contents)
+            job_config_dict = json.load(f)
    elif json_env_var:
        job_config_dict = txt_to_obj(json_env_var)
-        logging.debug("Input params parsed: %s", job_config_dict)
-
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)
-
-        peft_method_parsed = job_config_dict.get("peft_method")
    else:
        raise ValueError(
            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )

-    tune_config = None
-    merge_model = False
-    if peft_method_parsed == "lora":
-        tune_config = lora_config
-        merge_model = True
-    elif peft_method_parsed == "pt":
-        tune_config = prompt_tuning_config
-
-    logging.info(
-        "Parameters used to launch training: \
-        model_args %s, data_args %s, training_args %s, tune_config %s",
+    logging.debug("Input params parsed: %s", job_config_dict)
+
+    (
        model_args,
        data_args,
        training_args,
        tune_config,
-    )
+        merge_model,
+    ) = process_launch_training_args(job_config_dict)

    original_output_dir = training_args.output_dir
    with tempfile.TemporaryDirectory() as tempdir:
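For readers following the diff: the launcher's input contract is unchanged — the job config arrives either as a file path in SFT_TRAINER_CONFIG_JSON_PATH or as an encoded payload in SFT_TRAINER_CONFIG_JSON_ENV_VAR — and all argument parsing is now delegated to process_launch_training_args. A minimal sketch of driving the script through the file-path route (the config path is illustrative; the encoded-string route via txt_to_obj is not shown in this diff):

    import os
    import subprocess

    # Illustrative only: point the launcher at a JSON job config file and run it.
    env = dict(os.environ, SFT_TRAINER_CONFIG_JSON_PATH="/path/to/job_config.json")
    subprocess.run(["python", "build/launch_training.py"], check=True, env=env)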
73 changes: 73 additions & 0 deletions build/utils.py
@@ -0,0 +1,73 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import logging

# Third Party
import transformers

# Local
from tuning.config import configs, peft_config


def process_launch_training_args(job_config_dict):
"""Return parsed config for tuning to pass to SFT Trainer
Args:
job_config_dict: dict
Return:
model_args: configs.ModelArguments
data_args: configs.DataArguments
training_args: configs.TrainingArguments
tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig
merge_model: bool
"""
parser = transformers.HfArgumentParser(
dataclass_types=(
configs.ModelArguments,
configs.DataArguments,
configs.TrainingArguments,
peft_config.LoraConfig,
peft_config.PromptTuningConfig,
)
)

(
model_args,
data_args,
training_args,
lora_config,
prompt_tuning_config,
) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

peft_method_parsed = job_config_dict.get("peft_method")

tune_config = None
merge_model = False
if peft_method_parsed == "lora":
tune_config = lora_config
merge_model = True
elif peft_method_parsed == "pt":
tune_config = prompt_tuning_config

logging.info(
"Parameters used to launch training: \
model_args %s, data_args %s, training_args %s, tune_config %s",
model_args,
data_args,
training_args,
tune_config,
)

return model_args, data_args, training_args, tune_config, merge_model
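A quick usage sketch of the new helper, assuming the remaining dataclass fields all have defaults (the values below are illustrative; the bfloat16 and per-epoch defaults come from the dataclasses in tuning/config and are exercised by the tests added below):

    from build.utils import process_launch_training_args

    job_config = {
        "model_name_or_path": "bigscience/bloom-560m",
        "training_data_path": "data/twitter_complaints_small.json",
        "output_dir": "bloom-twitter",
        "peft_method": "lora",  # omit this key to fine-tune without PEFT
        # torch_dtype, use_flash_attn, save_strategy, ... are left unset: defaults apply
    }

    model_args, data_args, training_args, tune_config, merge_model = (
        process_launch_training_args(job_config)
    )
    print(model_args.torch_dtype)             # torch.bfloat16 by default
    print(type(tune_config).__name__)         # LoraConfig when peft_method == "lora"
    print(merge_model)                        # True only on the LoRA path
    print(training_args.save_strategy.value)  # "epoch", the new default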
4 changes: 4 additions & 0 deletions pytest.ini
@@ -0,0 +1,4 @@
[pytest]
# Register tests from `build` dir, removing `build` from norecursedirs default list,
# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs
norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch}
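With `build` removed from pytest's default norecursedirs pattern (which otherwise skips any directory named build), a pytest run from the repository root should now also collect the new tests under tests/build/, e.g. python -m pytest tests/build/test_utils.py.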
13 changes: 13 additions & 0 deletions tests/build/__init__.py
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
34 changes: 34 additions & 0 deletions tests/build/dummy_job_config.json
@@ -0,0 +1,34 @@
{
"accelerate_launch_args": {
"use_fsdp": true,
"env": ["env1", "env2"],
"dynamo_use_dynamic": true,
"num_machines": 1,
"num_processes":2,
"main_process_port": 1234,
"fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": 1,
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_cpu_ram_efficient_loading": true,
"fsdp_sync_module_states": true,
"config_file": "fixtures/accelerate_fsdp_defaults.yaml"
},
"multi_gpu": true,
"model_name_or_path": "bigscience/bloom-560m",
"training_data_path": "data/twitter_complaints_small.json",
"output_dir": "bloom-twitter",
"num_train_epochs": 5.0,
"per_device_train_batch_size": 4,
"per_device_eval_batch_size": 4,
"gradient_accumulation_steps": 4,
"learning_rate": 0.03,
"weight_decay": 0.000001,
"lr_scheduler_type": "cosine",
"logging_steps": 1.0,
"packing": false,
"include_tokens_per_second": true,
"response_template": "### Label:",
"dataset_text_field": "output",
"use_flash_attn": false,
"tokenizer_name_or_path": "bigscience/bloom-560m"
}
80 changes: 80 additions & 0 deletions tests/build/test_utils.py
@@ -0,0 +1,80 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import copy
import json
import os

# Third Party
import pytest

# Local
from tuning.config.peft_config import LoraConfig, PromptTuningConfig
from build.utils import process_launch_training_args

HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
os.path.dirname(__file__), "dummy_job_config.json"
)


# Note: job_config dict gets modified during process_launch_training_args
@pytest.fixture(scope="session")
def job_config():
with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
dummy_job_config_dict = json.load(f)
return dummy_job_config_dict


def test_process_launch_training_args(job_config):
job_config_copy = copy.deepcopy(job_config)
(
model_args,
data_args,
training_args,
tune_config,
merge_model,
) = process_launch_training_args(job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
assert training_args.output_dir == "bloom-twitter"
assert tune_config == None
assert merge_model == False


def test_process_launch_training_args_defaults(job_config):
job_config_defaults = copy.deepcopy(job_config)
assert "torch_dtype" not in job_config_defaults
assert job_config_defaults["use_flash_attn"] == False
assert "save_strategy" not in job_config_defaults
model_args, _, training_args, _, _ = process_launch_training_args(
job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn == False
assert training_args.save_strategy.value == "epoch"


def test_process_launch_training_args_peft_method(job_config):
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
assert type(tune_config) == PromptTuningConfig
assert merge_model == False

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
assert type(tune_config) == LoraConfig
assert merge_model == True
13 changes: 13 additions & 0 deletions tests/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18 changes: 18 additions & 0 deletions tuning/config/configs.py
@@ -72,3 +72,21 @@ class TrainingArguments(transformers.TrainingArguments):
default=False,
metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
)
save_strategy: str = field(
default="epoch",
metadata={
"help": "The checkpoint save strategy to adopt during training. \
Possible values are 'no'(no save is done during training), \
'epoch' (save is done at the end of each epoch), \
'steps' (save is done every `save_steps`)"
},
)
logging_strategy: str = field(
default="epoch",
metadata={
"help": "The logging strategy to adopt during training. \
Possible values are 'no'(no logging is done during training), \
'epoch' (logging is done at the end of each epoch), \
'steps' (logging is done every `logging_steps`)"
},
)
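To illustrate how these dataclass defaults become the job-config defaults: build/utils.py parses the job config with transformers.HfArgumentParser, so any key the config omits falls back to the field defaults above. A minimal sketch (the tiny config dict is illustrative and assumes the remaining fields of this TrainingArguments subclass all have defaults):

    import transformers

    from tuning.config import configs

    parser = transformers.HfArgumentParser(dataclass_types=(configs.TrainingArguments,))
    # A job config that never mentions save_strategy or logging_strategy...
    (training_args,) = parser.parse_dict({"output_dir": "/tmp/out"}, allow_extra_keys=True)
    # ...still checkpoints and logs once per epoch.
    print(training_args.save_strategy.value, training_args.logging_strategy.value)  # epoch epoch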
