fix: utilities to post process checkpoint for LoRA (#338)
* utilities to post process checkpoint for LoRA

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* improve code comments

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* Add unit test and fix some lint errors

Signed-off-by: Angel Luu <[email protected]>

* lint: fix more fmt errors

Signed-off-by: Angel Luu <[email protected]>

* feat: Add post_process_vLLM_adapters_new_tokens function to main

Signed-off-by: Will Johnson <[email protected]>

* fmt

Signed-off-by: Will Johnson <[email protected]>

* fix: Add post processing flag so post processing is only done for vLLM

Signed-off-by: Will Johnson <[email protected]>

* fix: get num_added_tokens from resize function (#344)

* get num_added_tokens

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* remove extra code

Signed-off-by: Sukriti-Sharma4 <[email protected]>

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* Ran fmt and also removed unnecessary files from test artifact

Signed-off-by: Angel Luu <[email protected]>

* fix: unit tests

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix: Adding tokens in special_tokens_dict

Signed-off-by: Abhishek <[email protected]>

* fix: Add additional arg to tests to reflect new flag post_process_vllm

Signed-off-by: Will Johnson <[email protected]>

* fmt

Signed-off-by: Will Johnson <[email protected]>

* feat: Refactor post-processing of adapters (#345)

* refactor saving tokens metadata

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* remove extra check

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* post processing script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* post processing script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix: unit test args

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* undo post_process_vllm flag

Signed-off-by: Sukriti-Sharma4 <[email protected]>

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* add test for LoRA tuning from main

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix formatting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* correcting post processing script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:post-process in place

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* update documentation for post-processing

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:formatting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:linting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* more warnings/exceptions in script

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* check for no tokens added

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:linting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* additional unit test

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* add more tests

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:tokenizer test

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:linting and docstrings

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:return type of trainer

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* test: enable tests and fix copytree

Signed-off-by: Anh Uong <[email protected]>

* use copy function from build

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* fix:linting and formatting

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* make build a module

Signed-off-by: Sukriti-Sharma4 <[email protected]>

* add back old copy function

Signed-off-by: Sukriti-Sharma4 <[email protected]>

---------

Signed-off-by: Sukriti-Sharma4 <[email protected]>
Signed-off-by: Angel Luu <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Abhishek <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Angel Luu <[email protected]>
Co-authored-by: Will Johnson <[email protected]>
Co-authored-by: Abhishek <[email protected]>
Co-authored-by: Anh Uong <[email protected]>
5 people authored Sep 25, 2024
1 parent c0c4355 commit 7714dfc
Showing 16 changed files with 97,521 additions and 27 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -436,6 +436,35 @@ Example 3:

</details>

#### Post-processing needed for inference on vLLM

To run inference with LoRA adapters on vLLM, any new token embeddings added during tuning need to be moved out of `adapter_model.safetensors` into a new file, `new_embeddings.safetensors`. The `adapter_model.safetensors` file should contain only LoRA weights and no modified embedding vectors; this is a requirement of vLLM's paradigm that one base model can serve multiple adapters. vLLM appends the new token embedding vectors to the embedding matrix it reads from the base model.
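Conceptually, the split looks like the following sketch. This is simplified and the tensor key names are assumptions; the actual implementation is `post_process_vLLM_adapters_new_tokens` in `tuning/utils/merge_model_utils.py`:

```python
# Simplified sketch only -- key names are assumptions; see
# tuning/utils/merge_model_utils.py for the real implementation.
import os

from safetensors.torch import load_file, save_file


def split_new_token_embeddings(checkpoint_dir: str, num_added_tokens: int):
    tensors = load_file(os.path.join(checkpoint_dir, "adapter_model.safetensors"))
    lora_weights, new_embeddings = {}, {}
    for name, tensor in tensors.items():
        if "embed_tokens" in name or "lm_head" in name:
            # Keep only the rows appended for the newly added tokens.
            new_embeddings[name] = tensor[-num_added_tokens:].clone()
        else:
            lora_weights[name] = tensor
    # Rewrite the adapter file with LoRA weights only, and write the new
    # token embeddings to the separate file vLLM expects.
    save_file(lora_weights, os.path.join(checkpoint_dir, "adapter_model.safetensors"))
    save_file(new_embeddings, os.path.join(checkpoint_dir, "new_embeddings.safetensors"))
```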

To support this post-processing, the tuning script `sft_trainer.py` writes a file `added_tokens_info.json` alongside the model artifacts. After tuning, you can run the script `post_process_adapters_vLLM.py`:

```bash
# model_path: path to saved model artifacts containing the file 'added_tokens_info.json'
# output_model_path: optional; directory in which to store the modified \
#   artifacts instead of modifying them in place
python scripts/post_process_adapters_vLLM.py \
--model_path "/testing/tuning/output/post-process-LoRA-saved" \
--output_model_path "/testing/tuning/output/post-process-LoRA-modified"
```

<details>
<summary>Alternatively, if using the SDK:</summary>

```python
# function in tuning/utils/merge_model_utils.py
post_process_vLLM_adapters_new_tokens(
path_to_checkpoint="/testing/tuning/output/post-process-LoRA-saved",
modified_checkpoint_path=None,
num_added_tokens=1,
)
# where num_added_tokens is returned by sft_trainer.train()
```
</details>
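
For example, a minimal sketch that reads the recorded token count and post-processes a checkpoint in place (the paths are illustrative):

```python
# Minimal sketch: read the token count recorded at tuning time,
# then split the embeddings out of the adapter file.
import json
import os

from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

ckpt = "/testing/tuning/output/post-process-LoRA-saved"  # illustrative path
with open(os.path.join(ckpt, "added_tokens_info.json"), encoding="utf-8") as f:
    num_added_tokens = json.load(f)["num_new_tokens"]

if num_added_tokens > 0:
    post_process_vLLM_adapters_new_tokens(
        path_to_checkpoint=ckpt,
        modified_checkpoint_path=None,  # None modifies the checkpoint in place
        num_added_tokens=num_added_tokens,
    )
```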

_________________________


6 changes: 5 additions & 1 deletion build/utils.py
```diff
@@ -24,12 +24,16 @@
 import shutil
 
 
-def copy_checkpoint(source, destination):
+def copy_checkpoint(source, destination, exclude_files: list[str] = None):
     if not os.path.exists(destination):
         os.makedirs(destination)
         shutil.copystat(source, destination)
     # Have a list of directory objects, now iterate over them.
+    if exclude_files is None:
+        exclude_files = []
     for item in os.listdir(source):
+        if item in exclude_files:
+            continue
         source_file = os.path.join(source, item)
         destination_file = os.path.join(destination, item)
         if os.path.isdir(source_file):
```
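
A quick usage sketch of the extended helper (the paths are hypothetical, and importing from `build.utils` assumes the "make build a module" change in this PR):

```python
# Copy a checkpoint directory but skip the adapter weights, as the
# post-processing script does when writing to a separate output directory.
from build.utils import copy_checkpoint

copy_checkpoint(
    source="/testing/tuning/output/run",  # hypothetical paths
    destination="/testing/tuning/output/run-copy",
    exclude_files=["adapter_model.safetensors"],
)
```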
94 changes: 94 additions & 0 deletions scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,94 @@
""" Script to post-process tuned LoRA adapters for inference on vLLM.
vLLM requires that any token embeddings added while tuning be moved to a new file \
called new_embeddings.safetensors. \
See the description in utility function \
/tuning/utils/merge_model_utils/post_process_vLLM_adapters_new_tokens for more details.
This script takes a path to tuned model artifacts containing adapters \
(or checkpoints with adapters) and the file 'added_tokens_info.json' produced while tuning. \
It will perform the post-processing as needed for inferencing on vLLM.
"""
# Standard
import argparse
import json
import logging
import os
import sys

# Local
from tuning.utils.merge_model_utils import (
copy_files_to_directory,
post_process_vLLM_adapters_new_tokens,
)


### Main & arg parsing
def main():
parser = argparse.ArgumentParser(
description="Post processes LoRA adapters due to addition of new tokens, as needed by vLLM"
)
parser.add_argument(
"--model_path",
help="Path to tuned model containing either one or multiple checkpoints. \
Path should have file added_tokens_info.json produced by tuning. \
Hint: This will be either output_dir or save_model_dir arguments while tuning. \
If multiple checkpoints are present, each checkpoint folder name \
should begin with 'checkpoint-'",
required=True,
)
parser.add_argument(
"--output_model_path",
help="Output directory where post-processed artifacts will be stored. \
If not provided, artifacts will be modified in place",
default=None,
)
args = parser.parse_args()

if args.output_model_path is None:
output_model_path = args.model_path
else:
output_model_path = args.output_model_path
if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
with open(
os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
) as json_data:
added_tokens_info = json.load(json_data)
num_added_tokens = added_tokens_info["num_new_tokens"]
else:
raise ValueError(
"file added_tokens_info.json not in model_path. \
Cannot post-processes"
)
if num_added_tokens == 0:
logging.info("No new tokens added, hence post-processing not needed")
sys.exit(0)

found_adapters = 0
if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
found_adapters = 1
post_process_vLLM_adapters_new_tokens(
args.model_path, output_model_path, num_added_tokens
)
# if multiple checkpoints in directory, process each checkpoint
found_checkpoints = 0
for _, dirs, _ in os.walk(args.model_path, topdown=False):
for name in dirs:
if "checkpoint-" in name.lower():
post_process_vLLM_adapters_new_tokens(
os.path.join(args.model_path, name),
os.path.join(output_model_path, name),
num_added_tokens,
)
found_checkpoints = 1
if found_checkpoints and output_model_path != args.model_path:
copy_files_to_directory(
args.model_path,
output_model_path,
exclude_files=["adapter_model.safetensors"],
)
if not found_adapters and not found_checkpoints:
logging.warning("No adapters were found to process in model path provided")


if __name__ == "__main__":
main()
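
For reference, the multi-checkpoint layout the script expects looks roughly like this (illustrative):

```text
output_dir/
├── added_tokens_info.json
├── checkpoint-100/
│   └── adapter_model.safetensors
└── checkpoint-200/
    └── adapter_model.safetensors
```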
29 changes: 29 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/adapter_config.json
@@ -0,0 +1,29 @@
```json
{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "Maykeye/TinyLLama-v0",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 32,
  "lora_dropout": 0.05,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "v_proj",
    "q_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_dora": false,
  "use_rslora": false
}
```
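
For context, a PEFT config like the following sketch would produce an adapter_config.json consistent with this artifact (fields such as `base_model_name_or_path` are filled in when the adapter is saved with its base model; the artifact itself was generated by the repo's tuning flow):

```python
# Sketch: a PEFT LoRA config matching the test artifact above.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["v_proj", "q_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
lora_config.save_pretrained("tests/artifacts/tuned_llama_with_added_tokens")
```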
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/added_tokens.json
@@ -0,0 +1,3 @@
```json
{
  "<pad>": 32000
}
```
30 changes: 30 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/special_tokens_map.json
@@ -0,0 +1,30 @@
```json
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
```
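
These artifacts correspond to a `<pad>` token added at tuning time. A minimal sketch of how such a token is added and counted with the standard `transformers` API (the exact call site inside `sft_trainer.py` may differ):

```python
# Sketch using the standard transformers API; sft_trainer.py's exact
# logic may differ.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")

# add_special_tokens returns how many tokens were actually added; this
# count is what post-processing later needs as num_added_tokens.
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<pad>"})
model.resize_token_embeddings(len(tokenizer))
```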
(Remaining changed files not shown.)
