fix: Add post processing flag so post processing is only done for vLLM
Signed-off-by: Will Johnson <[email protected]>
willmj committed Sep 18, 2024
1 parent bcc17b1 commit 57cadc3
Showing 3 changed files with 22 additions and 16 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -654,6 +654,8 @@ The `fms_acceleration.cli` can do more to search for all available configs, plug
## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

+ If you are trying to run LoRA inference on vLLM, set the `--post_process_vllm` flag to `True`.

### Running a single example
If you want to run a single example through a model, you can pass it with the `--text` flag.

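Context note (not part of the diff): the `--post_process_vllm` flag referenced here is registered in `tuning/sft_trainer.py` further down, where it gates post-processing of tuned LoRA adapters so they can be served with vLLM; see the sketches after those hunks for how the flag is parsed and used.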
7 changes: 4 additions & 3 deletions tests/test_sft_trainer.py
@@ -334,6 +334,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
+ _,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -347,7 +348,7 @@ def test_parse_arguments_defaults(job_config):
assert "torch_dtype" not in job_config_defaults
assert job_config_defaults["use_flash_attn"] is False
assert "save_strategy" not in job_config_defaults
- model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
+ model_args, _, training_args, _, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_defaults
)
assert str(model_args.torch_dtype) == "torch.bfloat16"
@@ -359,14 +360,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
- _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
+ _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
- _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments(
+ _, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
29 changes: 16 additions & 13 deletions tuning/sft_trainer.py
@@ -40,9 +40,6 @@
from trl import SFTConfig, SFTTrainer
import transformers

- # First Party
- from build.utils import get_highest_checkpoint

# Local
from tuning.config import configs, peft_config
from tuning.config.acceleration_configs import (
@@ -440,6 +437,13 @@ def get_parser():
choices=["pt", "lora", None, "none"],
default="none",
)
+ parser.add_argument(
+ "--post_process_vllm",
+ type=bool,
+ default=False,
+ help="Bool to indicate if post processing of LoRA adapters for vLLM \
+ is required.",
+ )
parser.add_argument(
"--exp_metadata",
type=str,
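A minimal, self-contained sketch (not part of the diff) of how argparse treats the `type=bool` option registered above. Note that argparse applies `bool()` to the raw string, so any non-empty value (including "False") parses as `True`; only omitting the flag yields the default `False`:

```python
import argparse

# Same option as in get_parser() above, reproduced standalone for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--post_process_vllm",
    type=bool,
    default=False,
    help="Bool to indicate if post processing of LoRA adapters for vLLM is required.",
)

print(parser.parse_args([]).post_process_vllm)                                 # False (default)
print(parser.parse_args(["--post_process_vllm", "True"]).post_process_vllm)   # True
print(parser.parse_args(["--post_process_vllm", "False"]).post_process_vllm)  # True, since bool("False") is truthy
```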
@@ -496,6 +500,7 @@ def parse_arguments(parser, json_config=None):
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
+ post_process_vllm = json_config.get("post_process_vllm")
else:
(
model_args,
@@ -514,6 +519,7 @@

peft_method = additional.peft_method
exp_metadata = additional.exp_metadata
+ post_process_vllm = additional.post_process_vllm

if peft_method == "lora":
tune_config = lora_config
@@ -533,6 +539,7 @@
quantized_lora_config,
fusedops_kernels_config,
exp_metadata,
+ post_process_vllm,
)
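A sketch (not part of the diff) of the JSON job-config path through `parse_arguments`. The config below is abbreviated and hypothetical (a real job config also carries the model, data, and training fields), and it assumes the package is importable as `tuning`:

```python
from tuning import sft_trainer

# Abbreviated, hypothetical job config; real configs also need model/data/training arguments.
job_config = {
    # ... model, data, and training arguments for the run ...
    "peft_method": "lora",
    "post_process_vllm": True,  # picked up via json_config.get("post_process_vllm")
}

parser = sft_trainer.get_parser()
# parse_arguments now returns 11 values; post_process_vllm is the last one.
(model_args, _, training_args, _, tune_config,
 _, _, _, _, _, post_process_vllm) = sft_trainer.parse_arguments(parser, job_config)
print(post_process_vllm)  # True
```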


@@ -553,6 +560,7 @@ def main():
quantized_lora_config,
fusedops_kernels_config,
exp_metadata,
+ post_process_vllm,
) = parse_arguments(parser, job_config)

# Function to set log level for python native logger and transformers training logger
@@ -656,17 +664,12 @@ def main():
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# post process lora
- if isinstance(tune_config, peft_config.LoraConfig):
+ if post_process_vllm and isinstance(tune_config, peft_config.LoraConfig):
try:
- checkpoint_dir = job_config.get("save_model_dir")
- if not checkpoint_dir:
- checkpoint_dir = os.path.join(
- training_args.output_dir,
- get_highest_checkpoint(training_args.output_dir),
- )
- print(training_args)
- print(f"Post processing LoRA adapters in {checkpoint_dir}")
- post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
+ checkpoint_dir = training_args.save_model_dir
+ if checkpoint_dir:
+ print(f"Post processing LoRA adapters in {checkpoint_dir}")
+ post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
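For adapters that were already tuned without the flag, the same post-processing could in principle be applied after the fact by calling the helper directly. A hypothetical sketch: only the function name and its `path_to_checkpoint` keyword come from the diff above; the import path and the checkpoint location are assumptions:

```python
# Hypothetical usage; the import path below is assumed, not shown in this commit.
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

# Placeholder: the directory the tuned LoRA adapters were saved to (save_model_dir above).
checkpoint_dir = "output/tuned_lora_adapter"
post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
```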
