feat: Refactor post-processing of adapters (#345)
* refactor saving tokens metadata

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* remove extra check

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* post processing script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* post processing script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix: unit test args

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* undo post_process_vLLm flag

Signed-off-by: Sukriti-Sharma4 <[email protected]>

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Ssukriti authored Sep 23, 2024
1 parent c8d8f98 commit 2151225
Showing 7 changed files with 106 additions and 36 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -665,8 +665,6 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug
## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`.

### Running a single example
If you want to run a single example through a model, you can pass it with the `--text` flag.

53 changes: 53 additions & 0 deletions scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,53 @@
# Standard
import argparse
import json
import os

# Local
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens


### Main & arg parsing
def main():
    parser = argparse.ArgumentParser(
        description="Post-processes adapters due to addition of new tokens, as needed by vLLM"
    )
    parser.add_argument(
        "--model_path",
        help="Path to tuned model containing either one or multiple checkpoints. \
              Path should have file added_tokens_info.json produced by tuning. \
              Hint: This will be either the output_dir or save_model_dir argument used while tuning. \
              If multiple checkpoints are present, each checkpoint folder name \
              should begin with 'checkpoint-'",
        required=True,
    )
    parser.add_argument(
        "--output_model_path",
        help="Output directory where post-processed artifacts will be stored. \
              If not provided, artifacts will be modified in place",
        default=None,
    )
    args = parser.parse_args()

    if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
        with open(
            os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
        ) as json_data:
            added_tokens_info = json.load(json_data)  # json.load reads the file object
        num_added_tokens = added_tokens_info["num_added_tokens"]
    else:
        print("File added_tokens_info.json not found in model_path. Cannot post-process.")
        return

    if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
        post_process_vLLM_adapters_new_tokens(
            args.model_path, args.output_model_path, num_added_tokens
        )
    # if multiple checkpoints are present in the directory, process each checkpoint
    for _, dirs, _ in os.walk(args.model_path, topdown=False):
        for name in dirs:
            if "checkpoint-" in name.lower():
                post_process_vLLM_adapters_new_tokens(
                    os.path.join(args.model_path, name),
                    # mirror the checkpoint folder under the output path, or modify in place
                    os.path.join(args.output_model_path, name)
                    if args.output_model_path
                    else None,
                    num_added_tokens,
                )


if __name__ == "__main__":
    main()
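For reference, a minimal sketch of driving the same post-processing from Python for a single checkpoint folder; the path and token count below are illustrative placeholders, not real artifacts:

```python
# Hedged sketch: programmatic equivalent of the script above for one checkpoint.
# The path and num_added_tokens value are placeholders.
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

post_process_vLLM_adapters_new_tokens(
    path_to_checkpoint="/tmp/tuned-model/checkpoint-100",  # hypothetical checkpoint dir
    modified_checkpoint_path=None,  # None means the adapters are modified in place
    num_added_tokens=7,  # the value recorded in added_tokens_info.json during tuning
)
```

From the command line, the equivalent is `python scripts/post_process_adapters_vLLM.py --model_path <tuning output dir> --output_model_path <optional destination>`.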
2 changes: 1 addition & 1 deletion tests/build/test_launch_script.py
@@ -155,7 +155,7 @@ def test_lora_save_model_dir_separate_dirs():
    _validate_termination_files_when_tuning_succeeds(output_dir)
    _validate_training_output(save_model_dir, "lora")

    assert len(os.listdir(output_dir)) == 3
    # purpose here is to see if only one checkpoint is saved
    checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    assert len(checkpoints) == 1

6 changes: 2 additions & 4 deletions tests/test_sft_trainer.py
@@ -335,7 +335,6 @@ def test_parse_arguments(job_config):
        _,
        _,
        _,
        _,
    ) = sft_trainer.parse_arguments(parser, job_config_copy)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert data_args.dataset_text_field == "output"
@@ -361,7 +360,6 @@ def test_parse_arguments_defaults(job_config):
        _,
        _,
        _,
        _,
    ) = sft_trainer.parse_arguments(parser, job_config_defaults)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.use_flash_attn is False
@@ -372,14 +370,14 @@ def test_parse_arguments_peft_method(job_config):
    parser = sft_trainer.get_parser()
    job_config_pt = copy.deepcopy(job_config)
    job_config_pt["peft_method"] = "pt"
    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
        parser, job_config_pt
    )
    assert isinstance(tune_config, peft_config.PromptTuningConfig)

    job_config_lora = copy.deepcopy(job_config)
    job_config_lora["peft_method"] = "lora"
    _, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
    _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
        parser, job_config_lora
    )
    assert isinstance(tune_config, peft_config.LoraConfig)
14 changes: 11 additions & 3 deletions tuning/data/tokenizer_data_utils.py
@@ -25,8 +25,16 @@ def tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    multiple_of: int = 1,
):
    """Resize tokenizer and embedding."""
) -> dict:
    """Resize tokenizer and embedding.
    Args:
        special_tokens_dict: Dict containing special tokens to be added.
        tokenizer: transformers.PreTrainedTokenizer.
        model: transformers.PreTrainedModel.
        multiple_of: int, embeddings are resized to a multiple of this value.
    Returns:
        dict: Metadata on the number of added tokens.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
    num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
@@ -44,4 +52,4 @@ def tokenizer_and_embedding_resize(

    input_embeddings[-num_new_tokens:] = input_embeddings_avg
    output_embeddings[-num_new_tokens:] = output_embeddings_avg
    return num_new_tokens
    return {"num_new_tokens": num_new_tokens, "new_embedding_size": embedding_size}
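To illustrate the new return value, here is a hedged sketch of a call to this helper; the model name, special token, and printed numbers are illustrative assumptions, not anything the library mandates:

```python
# Hedged sketch of calling the embedding-resize helper; "gpt2" and "<pad>" are
# illustrative choices only.
from transformers import AutoModelForCausalLM, AutoTokenizer

from tuning.data.tokenizer_data_utils import tokenizer_and_embedding_resize

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

metadata = tokenizer_and_embedding_resize(
    special_tokens_dict={"pad_token": "<pad>"},
    tokenizer=tokenizer,
    model=model,
    multiple_of=8,  # pad the embedding matrix up to a multiple of 8 rows
)
# The helper now returns metadata instead of a bare count,
# e.g. {"num_new_tokens": 7, "new_embedding_size": 50264}
print(metadata)
```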
51 changes: 26 additions & 25 deletions tuning/sft_trainer.py
Expand Up @@ -64,7 +64,6 @@
    write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens
from tuning.utils.preprocessing_utils import (
    format_dataset,
    get_data_collator,
@@ -291,7 +290,7 @@ def train(

    # TODO: lower priority but understand if resizing impacts inference quality and why it's needed.
    # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
    num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
    added_tokens_dict = tokenizer_data_utils.tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
@@ -411,8 +410,9 @@ def train(
    )

    trainer.train(resume_from_checkpoint)

    return trainer, num_added_tokens
    additional_metadata = {}
    additional_metadata["added_tokens_info"] = added_tokens_dict
    return trainer, additional_metadata


def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
@@ -463,13 +463,7 @@ def get_parser():
        choices=["pt", "lora", None, "none"],
        default="none",
    )
    parser.add_argument(
        "--post_process_vllm",
        type=bool,
        default=False,
        help="Bool to indicate if post processing of LoRA adapters for vLLM \
            is required.",
    )

    parser.add_argument(
        "--exp_metadata",
        type=str,
@@ -529,7 +523,6 @@ def parse_arguments(parser, json_config=None):
        ) = parser.parse_dict(json_config, allow_extra_keys=True)
        peft_method = json_config.get("peft_method")
        exp_metadata = json_config.get("exp_metadata")
        post_process_vllm = json_config.get("post_process_vllm")
    else:
        (
            model_args,
@@ -549,7 +542,6 @@

        peft_method = additional.peft_method
        exp_metadata = additional.exp_metadata
        post_process_vllm = additional.post_process_vllm

    if peft_method == "lora":
        tune_config = lora_config
@@ -570,7 +562,6 @@
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
        exp_metadata,
        post_process_vllm,
    )


@@ -592,7 +583,6 @@ def main():
        fusedops_kernels_config,
        attention_and_distributed_packing_config,
        exp_metadata,
        post_process_vllm,
    ) = parse_arguments(parser, job_config)

    # Function to set log level for python native logger and transformers training logger
@@ -644,7 +634,7 @@ def main():
    combined_tracker_configs.aim_config = aim_config

    try:
        trainer, num_added_tokens = train(
        trainer, additional_train_info = train(
            model_args=model_args,
            data_args=data_args,
            train_args=training_args,
@@ -697,19 +687,30 @@ def main():
        )
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

    # post process lora
    if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
    if isinstance(tune_config, peft_config.LoraConfig):
        try:
            checkpoint_dir = training_args.save_model_dir
            if checkpoint_dir:
                print(f"Post processing LoRA adapters in {checkpoint_dir}")
                post_process_vLLM_adapters_new_tokens(
                    path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens
                )
            if training_args.save_model_dir:
                # Write number of added tokens to artifacts
                with open(
                    os.path.join(
                        training_args.save_model_dir, "added_tokens_info.json"
                    ),
                    "w",
                    encoding="utf-8",
                ) as f:
                    json.dump(additional_train_info["added_tokens_info"], f)
            if training_args.output_dir:
                # Write number of added tokens to artifacts
                with open(
                    os.path.join(training_args.output_dir, "added_tokens_info.json"),
                    "w",
                    encoding="utf-8",
                ) as f:
                    json.dump(additional_train_info["added_tokens_info"], f)
        except Exception as e:  # pylint: disable=broad-except
            logging.error(traceback.format_exc())
            write_termination_log(
                f"Exception encountered while lora post-processing model: {e}"
                f"Exception encountered when saving metadata with model artifacts: {e}"
            )
            sys.exit(INTERNAL_ERROR_EXIT_CODE)

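Taken together with the new script, the flow is: tuning writes `added_tokens_info.json` next to the artifacts, and post-processing reads it back. A hedged sketch of that hand-off, with a placeholder path and keys mirroring the dict returned by `tokenizer_and_embedding_resize`:

```python
# Hedged sketch of reading the metadata written alongside the tuned artifacts.
# The directory is a placeholder.
import json
import os

save_model_dir = "/tmp/tuned-model"  # hypothetical save_model_dir used during tuning
info_file = os.path.join(save_model_dir, "added_tokens_info.json")

if os.path.exists(info_file):
    with open(info_file, encoding="utf-8") as f:
        added_tokens_info = json.load(f)
    print(added_tokens_info)  # e.g. {"num_new_tokens": 7, "new_embedding_size": 50264}
```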
14 changes: 13 additions & 1 deletion tuning/utils/merge_model_utils.py
@@ -124,7 +124,19 @@ def post_process_vLLM_adapters_new_tokens(
    modified_checkpoint_path: str = None,
    num_added_tokens: int = 0,
):
    # if not set, original checkpoint will be modified
    """Post process adapters to allow inferencing on vLLM.
    vLLM needs new token embedding weights added during tuning to be moved \
    to a new file new_embeddings.safetensors. \
    This function copies the embedding weights for the added tokens from \
    adapter_model.safetensors to new_embeddings.safetensors.
    Args:
        path_to_checkpoint: Path to folder containing adapter_model.safetensors.
        modified_checkpoint_path: Output path where to save modified artifacts \
            after post-processing. If not provided, artifacts will be processed \
            in place in the same folder.
        num_added_tokens: int. Number of tokens that were added during tuning.
    """
    # if not set, original checkpoint will be modified in place
    if not modified_checkpoint_path:
        modified_checkpoint_path = path_to_checkpoint

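As a rough illustration of the row-copy described in the docstring above (not the library's actual implementation; the tensor key matching below is an assumption and varies by model):

```python
# Rough sketch of the idea only: copy the rows for newly added tokens out of the
# adapter checkpoint into new_embeddings.safetensors, which vLLM expects.
# Matching on "embed_tokens" / "lm_head" is an assumption, not the real logic.
import os

from safetensors.torch import load_file, save_file


def copy_new_token_rows(path_to_checkpoint, num_added_tokens, out_dir=None):
    out_dir = out_dir or path_to_checkpoint  # default: write next to the original adapter
    adapters = load_file(os.path.join(path_to_checkpoint, "adapter_model.safetensors"))

    new_embeddings = {}
    for key, tensor in adapters.items():
        if "embed_tokens" in key or "lm_head" in key:
            # Newly added tokens occupy the last rows of the resized matrices.
            new_embeddings[key] = tensor[-num_added_tokens:].clone()

    save_file(new_embeddings, os.path.join(out_dir, "new_embeddings.safetensors"))
```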
