Move accelerate launch args parsing (#107)
* Move job config parsing to utils

Signed-off-by: Thara Palanivel <[email protected]>

* Add tests

Signed-off-by: Thara Palanivel <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sukriti Sharma <[email protected]>
Signed-off-by: tharapalanivel <[email protected]>

---------

Signed-off-by: Thara Palanivel <[email protected]>
Signed-off-by: tharapalanivel <[email protected]>
Co-authored-by: Sukriti Sharma <[email protected]>
tharapalanivel and Ssukriti authored Apr 5, 2024
1 parent d6aed07 commit 2a0c4d3
Showing 4 changed files with 140 additions and 109 deletions.
81 changes: 4 additions & 77 deletions build/accelerate_launch.py
@@ -18,94 +18,21 @@
"""

# Standard
import json
import os
import base64
import pickle
import logging

# Third Party
from accelerate.commands.launch import launch_command_parser, launch_command


def txt_to_obj(txt):
base64_bytes = txt.encode("ascii")
message_bytes = base64.b64decode(base64_bytes)
try:
# If the bytes represent JSON string
return json.loads(message_bytes)
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)
from accelerate.commands.launch import launch_command
from build.utils import process_accelerate_launch_args, get_job_config


def main():
LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(level=LOGLEVEL)

json_configs = {}
json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

if json_path:
with open(json_path, "r", encoding="utf-8") as f:
json_configs = json.load(f)

elif json_env_var:
json_configs = txt_to_obj(json_env_var)

parser = launch_command_parser()
# Map to determine which flags don't require a value to be set
actions_type_map = {
action.dest: type(action).__name__ for action in parser._actions
}

# Parse accelerate_launch_args
accelerate_launch_args = []
accelerate_config = json_configs.get("accelerate_launch_args", {})
if accelerate_config:
logging.info("Using accelerate_launch_args configs: %s", accelerate_config)
for key, val in accelerate_config.items():
if actions_type_map.get(key) == "_AppendAction":
for param_val in val:
accelerate_launch_args.extend([f"--{key}", str(param_val)])
elif (actions_type_map.get(key) == "_StoreTrueAction" and val) or (
actions_type_map.get(key) == "_StoreFalseAction" and not val
):
accelerate_launch_args.append(f"--{key}")
else:
accelerate_launch_args.append(f"--{key}")
# Only append a value for params that take one; bare flags (e.g. --quiet) only need the key
if actions_type_map.get(key) == "_StoreAction":
accelerate_launch_args.append(str(val))

num_processes = accelerate_config.get("num_processes")
if num_processes:
# if multi GPU setting and accelerate config_file not passed by user,
# use the default config for default set of parameters
if num_processes > 1 and not accelerate_config.get("config_file"):
# Add default FSDP config
fsdp_filepath = os.getenv(
"FSDP_DEFAULTS_FILE_PATH", "/app/accelerate_fsdp_defaults.yaml"
)
if os.path.exists(fsdp_filepath):
logging.info("Using accelerate config file: %s", fsdp_filepath)
accelerate_launch_args.extend(["--config_file", fsdp_filepath])

elif num_processes == 1:
logging.info("num_processes=1 so setting env var CUDA_VISIBLE_DEVICES=0")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
logging.warning(
"num_processes param was not passed in. Value from config file (if available) will \
be used or accelerate launch will determine number of processes automatically"
)

# Add training_script
accelerate_launch_args.append("/app/launch_training.py")
job_config = get_job_config()

logging.debug("accelerate_launch_args: %s", accelerate_launch_args)
args = parser.parse_args(args=accelerate_launch_args)
args = process_accelerate_launch_args(job_config)
logging.debug("accelerate launch parsed args: %s", args)
launch_command(args)

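After this refactor, build/accelerate_launch.py only wires the shared helpers together. A minimal sketch (not part of this diff) of driving it through the JSON-file path, assuming the container layout referenced above; the config contents and num_processes value are illustrative only:

# Sketch only: write a tiny job config and point the launcher at it via
# SFT_TRAINER_CONFIG_JSON_PATH, which get_job_config() reads first.
import json
import os
import tempfile

job_config = {"accelerate_launch_args": {"num_processes": 1}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(job_config, f)
os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = f.name

# build.accelerate_launch.main() would then read this config with get_job_config(),
# build the accelerate arguments with process_accelerate_launch_args(), and hand them
# to accelerate's launch_command, which runs /app/launch_training.py inside the image.
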
35 changes: 4 additions & 31 deletions build/launch_training.py
@@ -18,10 +18,7 @@
"""

# Standard
import base64
import os
import pickle
import json
import tempfile
import shutil
import glob
@@ -32,18 +29,7 @@
# Local
from tuning import sft_trainer
from tuning.utils.merge_model_utils import create_merged_model
from build.utils import process_launch_training_args


def txt_to_obj(txt):
base64_bytes = txt.encode("ascii")
message_bytes = base64.b64decode(base64_bytes)
try:
# If the bytes represent JSON string
return json.loads(message_bytes)
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)
from build.utils import process_launch_training_args, get_job_config


def get_highest_checkpoint(dir_path):
Expand All @@ -67,30 +53,17 @@ def main():

logging.info("Initializing launch training script")

json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

# accepts either path to JSON file or encoded string config
if json_path:
with open(json_path, "r", encoding="utf-8") as f:
job_config_dict = json.load(f)
elif json_env_var:
job_config_dict = txt_to_obj(json_env_var)
else:
raise ValueError(
"Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
)
job_config = get_job_config()

logging.debug("Input params parsed: %s", job_config_dict)
logging.debug("Input params parsed: %s", job_config)

(
model_args,
data_args,
training_args,
tune_config,
merge_model,
) = process_launch_training_args(job_config_dict)
) = process_launch_training_args(job_config)

original_output_dir = training_args.output_dir
with tempfile.TemporaryDirectory() as tempdir:
96 changes: 96 additions & 0 deletions build/utils.py
@@ -13,15 +13,49 @@
# limitations under the License.

# Standard
import os
import json
import logging
import base64
import pickle

# Third Party
import transformers
from accelerate.commands.launch import launch_command_parser

# Local
from tuning.config import configs, peft_config


def txt_to_obj(txt):
base64_bytes = txt.encode("ascii")
message_bytes = base64.b64decode(base64_bytes)
try:
# If the bytes represent JSON string
return json.loads(message_bytes)
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)


def get_job_config():
json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")

# accepts either path to JSON file or encoded string config
if json_path:
with open(json_path, "r", encoding="utf-8") as f:
job_config_dict = json.load(f)
elif json_env_var:
job_config_dict = txt_to_obj(json_env_var)
else:
raise ValueError(
"Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
)
return job_config_dict


def process_launch_training_args(job_config_dict):
"""Return parsed config for tuning to pass to SFT Trainer
Args:
@@ -71,3 +105,65 @@ def process_launch_training_args(job_config_dict):
)

return model_args, data_args, training_args, tune_config, merge_model


def process_accelerate_launch_args(job_config_dict):
"""Return parsed config for tuning to pass to SFT Trainer
Args:
job_config_dict: dict
Return:
args to pass to `accelerate launch`
"""
parser = launch_command_parser()
# Map to determine which flags don't require a value to be set
actions_type_map = {
action.dest: type(action).__name__ for action in parser._actions
}

# Parse accelerate_launch_args
accelerate_launch_args = []
accelerate_config = job_config_dict.get("accelerate_launch_args", {})
if accelerate_config:
logging.info("Using accelerate_launch_args configs: %s", accelerate_config)
for key, val in accelerate_config.items():
if actions_type_map.get(key) == "_AppendAction":
for param_val in val:
accelerate_launch_args.extend([f"--{key}", str(param_val)])
elif (actions_type_map.get(key) == "_StoreTrueAction" and val) or (
actions_type_map.get(key) == "_StoreFalseAction" and not val
):
accelerate_launch_args.append(f"--{key}")
else:
accelerate_launch_args.append(f"--{key}")
# Only append a value for params that take one; bare flags (e.g. --quiet) only need the key
if actions_type_map.get(key) == "_StoreAction":
accelerate_launch_args.append(str(val))

num_processes = accelerate_config.get("num_processes")
if num_processes:
# if multi GPU setting and accelerate config_file not passed by user,
# use the default config for default set of parameters
if num_processes > 1 and not accelerate_config.get("config_file"):
# Add default FSDP config
fsdp_filepath = os.getenv(
"FSDP_DEFAULTS_FILE_PATH", "/app/accelerate_fsdp_defaults.yaml"
)
if os.path.exists(fsdp_filepath):
logging.info("Using accelerate config file: %s", fsdp_filepath)
accelerate_launch_args.extend(["--config_file", fsdp_filepath])

elif num_processes == 1:
logging.info("num_processes=1 so setting env var CUDA_VISIBLE_DEVICES=0")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
logging.warning(
"num_processes param was not passed in. Value from config file (if available) will \
be used or accelerate launch will determine number of processes automatically"
)

# Add training_script
accelerate_launch_args.append("/app/launch_training.py")

logging.debug("accelerate_launch_args: %s", accelerate_launch_args)
args = parser.parse_args(args=accelerate_launch_args)
return args
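The new helpers can also be exercised directly, which is what the tests below do. A hedged sketch of the base64 environment-variable path; the encoded payload and the num_processes value are illustrative, not taken from this diff:

import base64
import json
import os

from build.utils import get_job_config, process_accelerate_launch_args

os.environ.pop("SFT_TRAINER_CONFIG_JSON_PATH", None)  # the path variable takes precedence if set

# Encode a job config the way txt_to_obj expects it: base64 over a JSON string.
payload = json.dumps({"accelerate_launch_args": {"num_processes": 2}})
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = base64.b64encode(
    payload.encode("ascii")
).decode("ascii")

job_config = get_job_config()                      # decoded via txt_to_obj
args = process_accelerate_launch_args(job_config)  # Namespace for accelerate launch
print(args.num_processes)                          # 2
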
37 changes: 36 additions & 1 deletion tests/build/test_utils.py
@@ -16,13 +16,14 @@
import copy
import json
import os
from unittest.mock import patch

# Third Party
import pytest

# Local
from tuning.config.peft_config import LoraConfig, PromptTuningConfig
from build.utils import process_launch_training_args
from build.utils import process_launch_training_args, process_accelerate_launch_args

HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
os.path.dirname(__file__), "dummy_job_config.json"
@@ -78,3 +79,37 @@ def test_process_launch_training_args_peft_method(job_config):
_, _, _, tune_config, merge_model = process_launch_training_args(job_config_lora)
assert type(tune_config) == LoraConfig
assert merge_model == True


def test_process_accelerate_launch_args(job_config):
job_config_copy = copy.deepcopy(job_config)
args = process_accelerate_launch_args(job_config_copy)
assert args.config_file == "fixtures/accelerate_fsdp_defaults.yaml"
assert args.use_fsdp == True
assert args.tpu_use_cluster == False


@patch("os.path.exists")
def test_process_accelerate_launch_custom_fsdp(patch_path_exists):
patch_path_exists.return_value = True

dummy_fsdp_path = "dummy_fsdp_config.yaml"

# When user passes custom fsdp config file, use custom config and accelerate
# launch will use `num_processes` from config
temp_job_config = {"accelerate_launch_args": {"config_file": dummy_fsdp_path}}
args = process_accelerate_launch_args(temp_job_config)
assert args.config_file == dummy_fsdp_path
assert args.num_processes == None

# When user passes custom fsdp config file and also `num_processes` as a param, use custom config and
# overwrite num_processes from config with param
temp_job_config = {
"accelerate_launch_args": {
"config_file": dummy_fsdp_path,
"num_processes": 3,
}
}
args = process_accelerate_launch_args(temp_job_config)
assert args.config_file == dummy_fsdp_path
assert args.num_processes == 3
