Setting default values in training job config #104

Merged
merged 12 commits on Apr 2, 2024
59 changes: 10 additions & 49 deletions build/launch_training.py
@@ -28,12 +28,11 @@

# First Party
import logging

# Local
from tuning import sft_trainer
from tuning.config import configs, peft_config
from tuning.utils.merge_model_utils import create_merged_model

# Third Party
import transformers
from build.utils import process_launch_training_args


def txt_to_obj(txt):
@@ -67,69 +66,31 @@ def main():
    logging.basicConfig(level=LOGLEVEL)

    logging.info("Attempting to launch training script")
-    parser = transformers.HfArgumentParser(
-        dataclass_types=(
-            configs.ModelArguments,
-            configs.DataArguments,
-            configs.TrainingArguments,
-            peft_config.LoraConfig,
-            peft_config.PromptTuningConfig,
-        )
-    )
-    peft_method_parsed = "pt"

    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

    # accepts either path to JSON file or encoded string config
    if json_path:
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_json_file(json_path, allow_extra_keys=True)
-
-        contents = ""
        with open(json_path, "r", encoding="utf-8") as f:
-            contents = json.load(f)
-            peft_method_parsed = contents.get("peft_method")
-        logging.debug("Input params parsed: %s", contents)
+            job_config_dict = json.load(f)
    elif json_env_var:
        job_config_dict = txt_to_obj(json_env_var)
Comment on lines 75 to 78
Collaborator:
Note to self: perhaps we should refactor this into build/utils as well since accelerate_launch and launch_training use the same method to parse the JSON

Collaborator (Author):
Good point! I'll add that in the follow-up tests PR that is in the queue next
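
For illustration, one possible shape for such a shared helper in build/utils (a sketch only; the name get_job_config and the decode_env_var parameter are hypothetical and not part of this PR):

# Hypothetical sketch of the suggested refactor, not code from this PR.
# Standard
import json
import os


def get_job_config(decode_env_var=json.loads):
    """Return the job config dict from a JSON file path or an encoded env var.

    decode_env_var stands in for the existing txt_to_obj helper in
    build/launch_training.py, which would move alongside this function.
    """
    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

    if json_path:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)
    if json_env_var:
        return decode_env_var(json_env_var)
    raise ValueError(
        "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' "
        "or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
    )

Both accelerate_launch and launch_training could then call get_job_config() and pass the returned dict on to their own processing functions.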

-        logging.debug("Input params parsed: %s", job_config_dict)
-
-        (
-            model_args,
-            data_args,
-            training_args,
-            lora_config,
-            prompt_tuning_config,
-        ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)
-
-        peft_method_parsed = job_config_dict.get("peft_method")
    else:
        raise ValueError(
            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )

-    tune_config = None
-    merge_model = False
-    if peft_method_parsed == "lora":
-        tune_config = lora_config
-        merge_model = True
-    elif peft_method_parsed == "pt":
-        tune_config = prompt_tuning_config

-    logging.debug(
-        "Parameters used to launch training: \
-        model_args %s, data_args %s, training_args %s, tune_config %s",
+    logging.debug("Input params parsed: %s", job_config_dict)
+
+    (
        model_args,
        data_args,
        training_args,
        tune_config,
-    )
+        merge_model,
+    ) = process_launch_training_args(job_config_dict)

    original_output_dir = training_args.output_dir
    with tempfile.TemporaryDirectory() as tempdir:
73 changes: 73 additions & 0 deletions build/utils.py
@@ -0,0 +1,73 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import logging

# Third Party
import transformers

# Local
from tuning.config import configs, peft_config


def process_launch_training_args(job_config_dict):
    """Return parsed config for tuning to pass to SFT Trainer
    Args:
        job_config_dict: dict
    Return:
        model_args: configs.ModelArguments
        data_args: configs.DataArguments
        training_args: configs.TrainingArguments
        tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig
        merge_model: bool
    """
    parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )

    (
        model_args,
        data_args,
        training_args,
        lora_config,
        prompt_tuning_config,
    ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

    peft_method_parsed = job_config_dict.get("peft_method")

    tune_config = None
    merge_model = False
    if peft_method_parsed == "lora":
        tune_config = lora_config
        merge_model = True
    elif peft_method_parsed == "pt":
        tune_config = prompt_tuning_config

    logging.debug(
        "Parameters used to launch training: \
        model_args %s, data_args %s, training_args %s, tune_config %s",
        model_args,
        data_args,
        training_args,
        tune_config,
    )

    return model_args, data_args, training_args, tune_config, merge_model
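
As a usage sketch (mirroring the tests added below, and assuming it is run from the repository root so the fixture path resolves), the function can be driven directly from a job config dict:

# Usage sketch only; see tests/build/test_utils.py for the actual coverage.
import json

from build.utils import process_launch_training_args

with open("tests/build/dummy_job_config.json", "r", encoding="utf-8") as f:
    job_config = json.load(f)
job_config["peft_method"] = "lora"  # selects LoraConfig and enables model merging

model_args, data_args, training_args, tune_config, merge_model = (
    process_launch_training_args(job_config)
)
print(training_args.output_dir)  # "bloom-twitter", taken from the job config
print(merge_model)               # True for the "lora" peft_method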
4 changes: 4 additions & 0 deletions pytest.ini
@@ -0,0 +1,4 @@
[pytest]
# Register tests from `build` dir, removing `build` from norecursedirs default list,
# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs
norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch}
13 changes: 13 additions & 0 deletions tests/build/__init__.py
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
34 changes: 34 additions & 0 deletions tests/build/dummy_job_config.json
@@ -0,0 +1,34 @@
{
    "accelerate_launch_args": {
        "use_fsdp": true,
        "env": ["env1", "env2"],
        "dynamo_use_dynamic": true,
        "num_machines": 1,
        "num_processes":2,
        "main_process_port": 1234,
        "fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_sharding_strategy": 1,
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_cpu_ram_efficient_loading": true,
        "fsdp_sync_module_states": true,
        "config_file": "fixtures/accelerate_fsdp_defaults.yaml"
    },
    "multi_gpu": true,
    "model_name_or_path": "bigscience/bloom-560m",
    "training_data_path": "data/twitter_complaints_small.json",
    "output_dir": "bloom-twitter",
    "num_train_epochs": 5.0,
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 0.03,
    "weight_decay": 0.000001,
    "lr_scheduler_type": "cosine",
    "logging_steps": 1.0,
    "packing": false,
    "include_tokens_per_second": true,
    "response_template": "### Label:",
    "dataset_text_field": "output",
    "use_flash_attn": false,
    "tokenizer_name_or_path": "bigscience/bloom-560m"
}
80 changes: 80 additions & 0 deletions tests/build/test_utils.py
@@ -0,0 +1,80 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import copy
import json
import os

# Third Party
import pytest

# Local
from tuning.config.peft_config import LoraConfig, PromptTuningConfig
from build.utils import process_launch_training_args

HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
    os.path.dirname(__file__), "dummy_job_config.json"
)


# Note: job_config dict gets modified during process_launch_training_args
@pytest.fixture(scope="session")
def job_config():
    with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
        dummy_job_config_dict = json.load(f)
    return dummy_job_config_dict


def test_process_launch_training_args(job_config):
    job_config_copy = copy.deepcopy(job_config)
    (
        model_args,
        data_args,
        training_args,
        tune_config,
        merge_model,
    ) = process_launch_training_args(job_config_copy)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert data_args.dataset_text_field == "output"
    assert training_args.output_dir == "bloom-twitter"
    assert tune_config == None
    assert merge_model == False


def test_process_launch_training_args_defaults(job_config):
    job_config_defaults = copy.deepcopy(job_config)
    assert "torch_dtype" not in job_config_defaults
    assert job_config_defaults["use_flash_attn"] == False
    assert "save_strategy" not in job_config_defaults
    model_args, _, training_args, _, _ = process_launch_training_args(
        job_config_defaults
    )
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.use_flash_attn == False
    assert training_args.save_strategy.value == "epoch"


def test_process_launch_training_args_peft_method(job_config):
    job_config_pt = copy.deepcopy(job_config)
    job_config_pt["peft_method"] = "pt"
    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt)
    assert type(tune_config) == PromptTuningConfig
    assert merge_model == False

    job_config_lora = copy.deepcopy(job_config)
    job_config_lora["peft_method"] = "lora"
    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
    assert type(tune_config) == LoraConfig
    assert merge_model == True
13 changes: 13 additions & 0 deletions tests/utils/__init__.py
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18 changes: 18 additions & 0 deletions tuning/config/configs.py
@@ -72,3 +72,21 @@ class TrainingArguments(transformers.TrainingArguments):
        default=False,
        metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
    )
    save_strategy: str = field(
default="epoch",
metadata={
"help": "The checkpoint save strategy to adopt during training. \
Possible values are 'no'(no save is done during training), \
'epoch' (save is done at the end of each epoch), \
'steps' (save is done every `save_steps`)"
},
)
logging_strategy: str = field(
default="epoch",
metadata={
"help": "The logging strategy to adopt during training. \
Possible values are 'no'(no logging is done during training), \
'epoch' (logging is done at the end of each epoch), \
'steps' (logging is done every `logging_steps`)"
},
)
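
As a quick sanity check of the new defaults (a sketch, not part of this diff; it assumes the tuning.config.configs module from this repo is importable), a job config that sets neither strategy falls back to epoch-based saving and logging:

# Sketch: confirm the new "epoch" defaults for save_strategy and logging_strategy.
from transformers import HfArgumentParser

from tuning.config import configs

parser = HfArgumentParser(configs.TrainingArguments)
(training_args,) = parser.parse_dict(
    {"output_dir": "bloom-twitter"}, allow_extra_keys=True
)
print(training_args.save_strategy.value)     # "epoch" unless overridden
print(training_args.logging_strategy.value)  # "epoch" unless overridden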