feat: Refactor post-processing of adapters #345

Merged 6 commits on Sep 23, 2024
2 changes: 0 additions & 2 deletions README.md
@@ -665,8 +665,6 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug
## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`.

### Running a single example
If you want to run a single example through a model, you can pass it with the `--text` flag.

53 changes: 53 additions & 0 deletions scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,53 @@
# Standard
import argparse
import json
import os

# Local
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens


### Main & arg parsing
def main():
parser = argparse.ArgumentParser(
description="Post processes adapters due to addition of new tokens, as needed by vLLM"
)
parser.add_argument(
"--model_path",
help="Path to tuned model containing either one or multiple checkpoints \
Path should have file added_tokens_info.json produced by tuning \
Hint: This will be either output_dir or save_model_dir arguments while tuning \
If multiple checkpoints are present, each checkpoint folder name \
should begin with 'checkpoint-'",
required=True,
)
parser.add_argument(
"--output_model_path",
help="Output directory where post-processed artifacts will be stored. \
If not provided, artifacts will be modified in place",
default=None,
)
args = parser.parse_args()

if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
with open(
os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
) as json_data:
            added_tokens_info = json.load(json_data)
            num_added_tokens = added_tokens_info["num_new_tokens"]
    else:
        print("file added_tokens_info.json not found in model_path. Cannot post-process.")
        return

if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
post_process_vLLM_adapters_new_tokens(
args.model_path, args.output_model_path, num_added_tokens
)
    # if multiple checkpoints are present in the directory, process each checkpoint
    for _, dirs, _ in os.walk(args.model_path, topdown=False):
        for name in dirs:
            if "checkpoint-" in name.lower():
                # preserve in-place behaviour when no output path is given
                modified_path = (
                    os.path.join(args.output_model_path, name)
                    if args.output_model_path
                    else None
                )
                post_process_vLLM_adapters_new_tokens(
                    os.path.join(args.model_path, name),
                    modified_path,
                    num_added_tokens,
                )


if __name__ == "__main__":
    main()
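
For orientation, a minimal sketch of the directory layout this script expects and of an equivalent single-checkpoint call that bypasses the CLI. The paths, checkpoint names, and token count below are hypothetical; the real count comes from added_tokens_info.json written at the end of tuning.

# Sketch only -- hypothetical paths and values.
#
# Expected layout of --model_path:
#   /tmp/lora_output/
#       added_tokens_info.json      # written by sft_trainer.py at the end of tuning
#       checkpoint-35/
#           adapter_model.safetensors
#       checkpoint-70/
#           adapter_model.safetensors
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint="/tmp/lora_output/checkpoint-70",
    modified_checkpoint_path="/tmp/lora_output_vllm/checkpoint-70",
    num_added_tokens=1,  # value taken from added_tokens_info.json
)
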
2 changes: 1 addition & 1 deletion tests/build/test_launch_script.py
@@ -155,7 +155,7 @@ def test_lora_save_model_dir_separate_dirs():
_validate_termination_files_when_tuning_succeeds(output_dir)
_validate_training_output(save_model_dir, "lora")

assert len(os.listdir(output_dir)) == 3
# check that only one checkpoint is saved
checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*"))
assert len(checkpoints) == 1

6 changes: 2 additions & 4 deletions tests/test_sft_trainer.py
@@ -335,7 +335,6 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -361,7 +360,6 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -372,14 +370,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
14 changes: 11 additions & 3 deletions tuning/data/tokenizer_data_utils.py
@@ -25,8 +25,16 @@ def tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
multiple_of: int = 1,
):
"""Resize tokenizer and embedding."""
) -> dict:
    """Resize tokenizer and embedding.

    Args:
        special_tokens_dict: Dict containing special tokens to be added.
        tokenizer: transformers.PreTrainedTokenizer.
        model: transformers.PreTrainedModel.
        multiple_of: int. Embeddings are resized to a multiple of this value.

    Returns:
        dict: Metadata on the number of added tokens and the new embedding size.
    """
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
@@ -44,4 +52,4 @@ def tokenizer_and_embedding_resize(

input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
return num_new_tokens
return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size}
51 changes: 26 additions & 25 deletions tuning/sft_trainer.py
@@ -64,7 +64,6 @@
write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
@@ -291,7 +290,7 @@ def train(

    # TODO: lower priority but understand if resizing impacts inference quality and why it's needed.
    # It makes sense that if we manipulate the tokenizer we also save it and provide it to inference.
num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
@@ -411,8 +410,9 @@ def train(
)

trainer.train(resume_from_checkpoint)

return trainer, num_added_tokens
additional_metadata = {}
additional_metadata["added_tokens_info"] = added_tokens_dict
return trainer, additional_metadata


def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
@@ -463,13 +463,7 @@ def get_parser():
choices=["pt", "lora", None, "none"],
default="none",
)
parser.add_argument(
"--post_process_vllm",
type=bool,
default=False,
help="Bool to indicate if post processing of LoRA adapters for vLLM \
is required.",
)

parser.add_argument(
"--exp_metadata",
type=str,
@@ -529,7 +523,6 @@ def parse_arguments(parser, json_config=None):
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
post_process_vllm = json_config.get("post_process_vllm")
else:
(
model_args,
@@ -549,7 +542,6 @@

peft_method = additional.peft_method
exp_metadata = additional.exp_metadata
post_process_vllm = additional.post_process_vllm

if peft_method == "lora":
tune_config = lora_config
@@ -570,7 +562,6 @@
fusedops_kernels_config,
attention_and_distributed_packing_config,
exp_metadata,
post_process_vllm,
)


@@ -592,7 +583,6 @@ def main():
fusedops_kernels_config,
attention_and_distributed_packing_config,
exp_metadata,
post_process_vllm,
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
@@ -644,7 +634,7 @@ def main():
combined_tracker_configs.aim_config = aim_config

try:
trainer, num_added_tokens = train(
trainer, additional_train_info = train(
model_args=model_args,
data_args=data_args,
train_args=training_args,
@@ -697,19 +687,30 @@ def main():
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# post process lora
if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
if isinstance(tune_config, peft_config.LoraConfig):
try:
checkpoint_dir = training_args.save_model_dir
if checkpoint_dir:
print(f"Post processing LoRA adapters in {checkpoint_dir}")
post_process_vLLM_adapters_new_tokens(
path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens
)
if training_args.save_model_dir:
# Write number of added tokens to artifacts
with open(
os.path.join(
training_args.save_model_dir, "added_tokens_info.json"
),
"w",
encoding="utf-8",
) as f:
json.dump(additional_train_info["added_tokens_info"], f)
if training_args.output_dir:
# Write number of added tokens to artifacts
with open(
os.path.join(training_args.output_dir, "added_tokens_info.json"),
"w",
encoding="utf-8",
) as f:
json.dump(additional_train_info["added_tokens_info"], f)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
f"Exception encountered while lora post-processing model: {e}"
f"Exception encountered when saving metadata with model artifacts: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)
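
The file written here simply serializes the dict returned by tokenizer_and_embedding_resize; a minimal sketch of reading it back later, with a hypothetical output path:

# Sketch only -- hypothetical path; the keys match tokenizer_and_embedding_resize's return value.
import json
import os

with open(
    os.path.join("/tmp/lora_output", "added_tokens_info.json"), encoding="utf-8"
) as f:
    added_tokens_info = json.load(f)
print(added_tokens_info["num_new_tokens"], added_tokens_info["new_embedding_size"])
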

14 changes: 13 additions & 1 deletion tuning/utils/merge_model_utils.py
@@ -124,7 +124,19 @@ def post_process_vLLM_adapters_new_tokens(
modified_checkpoint_path: str = None,
num_added_tokens: int = 0,
):
# if not set, original checkpoint will be modified
"""Post process adapters to allow inferencing on vLLM.
vLLM needs new token embedding weights added during tuning to be moved \
to a new file new_embeddings.safetensors . \
This function copies the embeddings weights for the added tokens from \
adapters.safetnsors to new_embeddings.safetensors.
Args:
path_to_checkpoint: Path to folder containing adapters.safetensors.
modified_checkpoint_path: Output path where to save modified artifacts \
after post-processing. If not provided, artifacts will be processed \
in place in same folder.
num_added_tokens: int. Number of tokens that were added during tuning.
"""
# if not set, original checkpoint will be modified in place
if not modified_checkpoint_path:
modified_checkpoint_path = path_to_checkpoint
