Setting default values in training job config (#104)
* Allow for default params to be set
* Add tests
* Simplifying default params logic
* Setting use_flash_attn default
* Formatting
* Address review comments
* Moving tests
* Address review comments
* Address review comment
* Fix merge conflicts

Signed-off-by: Thara Palanivel <[email protected]>
1 parent 2df20ba · commit 3785363
Showing 8 changed files with 244 additions and 49 deletions.
@@ -0,0 +1,73 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import logging

# Third Party
import transformers

# Local
from tuning.config import configs, peft_config


def process_launch_training_args(job_config_dict):
    """Return parsed config for tuning to pass to SFT Trainer
    Args:
        job_config_dict: dict
    Return:
        model_args: configs.ModelArguments
        data_args: configs.DataArguments
        training_args: configs.TrainingArguments
        tune_config: peft_config.LoraConfig | peft_config.PromptTuningConfig
        merge_model: bool
    """
    parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )

    (
        model_args,
        data_args,
        training_args,
        lora_config,
        prompt_tuning_config,
    ) = parser.parse_dict(job_config_dict, allow_extra_keys=True)

    peft_method_parsed = job_config_dict.get("peft_method")

    tune_config = None
    merge_model = False
    if peft_method_parsed == "lora":
        tune_config = lora_config
        merge_model = True
    elif peft_method_parsed == "pt":
        tune_config = prompt_tuning_config

    logging.info(
        "Parameters used to launch training: \
        model_args %s, data_args %s, training_args %s, tune_config %s",
        model_args,
        data_args,
        training_args,
        tune_config,
    )

    return model_args, data_args, training_args, tune_config, merge_model
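For orientation, a minimal usage sketch of the helper above. The dict keys here are illustrative assumptions, not a complete job config; only "peft_method" is read directly by the function, everything else is consumed by HfArgumentParser.parse_dict:

# Usage sketch (assumed values; "data/train.json" is a hypothetical path).
from build.utils import process_launch_training_args

job_config = {
    "model_name_or_path": "bigscience/bloom-560m",
    "training_data_path": "data/train.json",
    "output_dir": "out",
    "peft_method": "lora",  # picks LoraConfig and requests an adapter merge
}
model_args, data_args, training_args, tune_config, merge_model = (
    process_launch_training_args(job_config)
)
assert merge_model is True  # "lora" is the only method that sets merge_model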
@@ -0,0 +1,4 @@ | ||
[pytest] | ||
# Register tests from `build` dir, removing `build` from norecursedirs default list, | ||
# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs | ||
norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch} |
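This value is pytest's default norecursedirs list with `build` removed, so the launch-utility tests that live under build/ (shown further down) are collected. A quick programmatic check, using pytest's public entry point:

# Run only the tests under build/ and exit with pytest's status code.
import pytest
raise SystemExit(pytest.main(["build"]))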
@@ -0,0 +1,13 @@ | ||
# Copyright The FMS HF Tuning Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
@@ -0,0 +1,34 @@ | ||
{ | ||
"accelerate_launch_args": { | ||
"use_fsdp": true, | ||
"env": ["env1", "env2"], | ||
"dynamo_use_dynamic": true, | ||
"num_machines": 1, | ||
"num_processes":2, | ||
"main_process_port": 1234, | ||
"fsdp_backward_prefetch_policy": "TRANSFORMER_BASED_WRAP", | ||
"fsdp_sharding_strategy": 1, | ||
"fsdp_state_dict_type": "FULL_STATE_DICT", | ||
"fsdp_cpu_ram_efficient_loading": true, | ||
"fsdp_sync_module_states": true, | ||
"config_file": "fixtures/accelerate_fsdp_defaults.yaml" | ||
}, | ||
"multi_gpu": true, | ||
"model_name_or_path": "bigscience/bloom-560m", | ||
"training_data_path": "data/twitter_complaints_small.json", | ||
"output_dir": "bloom-twitter", | ||
"num_train_epochs": 5.0, | ||
"per_device_train_batch_size": 4, | ||
"per_device_eval_batch_size": 4, | ||
"gradient_accumulation_steps": 4, | ||
"learning_rate": 0.03, | ||
"weight_decay": 0.000001, | ||
"lr_scheduler_type": "cosine", | ||
"logging_steps": 1.0, | ||
"packing": false, | ||
"include_tokens_per_second": true, | ||
"response_template": "### Label:", | ||
"dataset_text_field": "output", | ||
"use_flash_attn": false, | ||
"tokenizer_name_or_path": "bigscience/bloom-560m" | ||
} |
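A sketch of feeding this fixture through the parser (the relative path is an assumption about the working directory). Keys such as "accelerate_launch_args" and "multi_gpu" are not fields on any of the parsed dataclasses; allow_extra_keys=True lets parse_dict ignore them:

# Sketch: load the dummy config and parse it.
import json
from build.utils import process_launch_training_args

with open("dummy_job_config.json", "r", encoding="utf-8") as f:
    job_config = json.load(f)

_, data_args, training_args, _, _ = process_launch_training_args(job_config)
print(data_args.dataset_text_field)  # "output"
print(training_args.output_dir)      # "bloom-twitter"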
@@ -0,0 +1,80 @@ | ||
# Copyright The FMS HF Tuning Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Standard | ||
import copy | ||
import json | ||
import os | ||
|
||
# Third Party | ||
import pytest | ||
|
||
# Local | ||
from tuning.config.peft_config import LoraConfig, PromptTuningConfig | ||
from build.utils import process_launch_training_args | ||
|
||
HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( | ||
os.path.dirname(__file__), "dummy_job_config.json" | ||
) | ||
|
||
|
||
# Note: job_config dict gets modified during process_launch_training_args | ||
@pytest.fixture(scope="session") | ||
def job_config(): | ||
with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: | ||
dummy_job_config_dict = json.load(f) | ||
return dummy_job_config_dict | ||
|
||
|
||
def test_process_launch_training_args(job_config): | ||
job_config_copy = copy.deepcopy(job_config) | ||
( | ||
model_args, | ||
data_args, | ||
training_args, | ||
tune_config, | ||
merge_model, | ||
) = process_launch_training_args(job_config_copy) | ||
assert str(model_args.torch_dtype) == "torch.bfloat16" | ||
assert data_args.dataset_text_field == "output" | ||
assert training_args.output_dir == "bloom-twitter" | ||
assert tune_config == None | ||
assert merge_model == False | ||
|
||
|
||
def test_process_launch_training_args_defaults(job_config): | ||
job_config_defaults = copy.deepcopy(job_config) | ||
assert "torch_dtype" not in job_config_defaults | ||
assert job_config_defaults["use_flash_attn"] == False | ||
assert "save_strategy" not in job_config_defaults | ||
model_args, _, training_args, _, _ = process_launch_training_args( | ||
job_config_defaults | ||
) | ||
assert str(model_args.torch_dtype) == "torch.bfloat16" | ||
assert model_args.use_flash_attn == False | ||
assert training_args.save_strategy.value == "epoch" | ||
|
||
|
||
def test_process_launch_training_args_peft_method(job_config): | ||
job_config_pt = copy.deepcopy(job_config) | ||
job_config_pt["peft_method"] = "pt" | ||
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_pt) | ||
assert type(tune_config) == PromptTuningConfig | ||
assert merge_model == False | ||
|
||
job_config_lora = copy.deepcopy(job_config) | ||
job_config_lora["peft_method"] = "lora" | ||
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora) | ||
assert type(tune_config) == LoraConfig | ||
assert merge_model == True |
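Because the job_config fixture is session-scoped and process_launch_training_args modifies the dict it receives, each test deep-copies the fixture first. A hypothetical extension of this module (not part of this commit) covering the fall-through branch, where an unrecognized peft_method means full fine-tuning:

def test_process_launch_training_args_unknown_peft_method(job_config):
    # Hypothetical test: anything other than "lora"/"pt" leaves tune_config unset.
    job_config_ft = copy.deepcopy(job_config)
    job_config_ft["peft_method"] = "unknown"
    _, _, _, tune_config, merge_model = process_launch_training_args(job_config_ft)
    assert tune_config is None
    assert merge_model is False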
@@ -0,0 +1,13 @@ | ||
# Copyright The FMS HF Tuning Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |