Commit

Added support for running official HF baseline FSDP-QLoRA benchmark (#16)

* new baseline scenario

* rename variables

* added warning when plugin allows SFTTrainer to handle PEFT on single device
achew010 authored May 21, 2024
1 parent 1c790ed commit d510ceb
Showing 8 changed files with 93 additions and 2 deletions.
4 changes: 4 additions & 0 deletions plugins/accelerated-peft/configs/bnb.yaml
@@ -14,3 +14,7 @@ peft:
    # bitsandbytes:
    bitsandbytes:
      quant_type: nf4

      # If True, then get_peft_model and prepare_model_for_kbit_training
      # will not be called.
      no_peft_model: False
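
The new `no_peft_model` flag is what turns this plugin config into a baseline: when it is True, the plugin leaves get_peft_model and prepare_model_for_kbit_training to SFTTrainer instead of performing them itself. A minimal sketch of how a caller might read the flag (the branching and the direct file read here are illustrative assumptions, not the framework's actual config loader):

import yaml

# Path taken from this diff; the framework's own loader may differ.
with open("plugins/accelerated-peft/configs/bnb.yaml") as f:
    cfg = yaml.safe_load(f)

bnb = cfg["peft"]["quantization"]["bitsandbytes"]
if bnb.get("no_peft_model", False):
    # Baseline: hand peft_config to SFTTrainer and let it wrap the model.
    print("SFTTrainer will call get_peft_model / prepare_model_for_kbit_training")
else:
    # Accelerated: the plugin wraps the model during augmentation.
    print("plugin handles PEFT preparation")
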
@@ -96,6 +96,9 @@ def __init__(self, configurations: Dict[str, Dict]):
        self._quant_type = self._check_config_and_maybe_check_values(
            key="peft.quantization.bitsandbytes.quant_type", values=["fp4", "nf4"]
        )
        self._no_peft_model = self._check_config_and_maybe_check_values(
            key="peft.quantization.bitsandbytes.no_peft_model", values=[True, False]
        )

    def model_loader(self, model_name: str, **kwargs):

@@ -121,6 +124,16 @@ def model_loader(self, model_name: str, **kwargs):
"If running in FSDP, this is probably because accelerate is not used. "
"This will most probably result in error."
)
elif (
world_size == 1
and self._no_peft_model == True
):
warnings.warn(
"""Running on single device and setting plugin config `no_peft_model` as `True`
PEFT preparation will be managed by SFTTrainer and will cause a slowdown in training speed
due to extraneous dtype casting when SFTTrainer prepares the model using
https://github.com/huggingface/trl/blob/e90e8d91d2265e484f229c45a5eb8982f94a2936/trl/trainer/sft_trainer.py#L210"""
)

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -147,7 +160,8 @@ def requires_custom_loading(self):

    @property
    def requires_agumentation(self):
-       return True
+       # will skip the augmentation if _no_peft_model == True
+       return not self._no_peft_model

    def augmentation(
        self,
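
Taken together, the changes above mean that `no_peft_model: True` flips `requires_agumentation` to False (spelling as in the plugin code), so the plugin skips its own PEFT wrapping and the peft_config flows through to SFTTrainer; on a single device this is exactly the slower path the new warning describes. A hypothetical caller-side sketch, with the helper name and argument handling assumed for illustration:

def resolve_peft_handling(plugin, peft_config):
    # If the plugin still requires augmentation, it attaches the LoRA
    # adapters itself, so SFTTrainer must not receive a peft_config.
    if plugin.requires_agumentation:
        return None  # plugin-managed PEFT (accelerated path)
    # Otherwise (no_peft_model=True) SFTTrainer receives the peft_config and
    # performs get_peft_model / prepare_model_for_kbit_training itself.
    return peft_config  # trainer-managed PEFT (baseline path)
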
14 changes: 14 additions & 0 deletions plugins/accelerated-peft/tests/test_peft_plugins.py
@@ -122,6 +122,20 @@ def test_configure_bnb_plugin():
        assert framework.requires_agumentation
        assert len(framework.get_callbacks_and_ready_for_train()) == 0

    # test that no_peft_model=True makes the plugin skip augmentation
    for key, correct_value in [
        ("peft.quantization.bitsandbytes.no_peft_model", True),
        ("peft.quantization.bitsandbytes.no_peft_model", False),
    ]:
        with instantiate_framework(
            update_configuration_contents(
                read_configuration(CONFIG_PATH_BNB), key, correct_value
            ),
            require_packages_check=False,
        ):
            # check flags and callbacks
            assert (not correct_value) == framework.requires_agumentation

    # attempt to activate plugin with configuration pointing to wrong path
    # - raise with message that no plugins can be configured
    with pytest.raises(ValueError) as e:
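
The parametrized check above can be exercised on its own; one way to do so from the repository root, assuming pytest and the plugin's test dependencies are installed:

import pytest

# Run only the bnb plugin configuration test.
pytest.main(["plugins/accelerated-peft/tests/test_peft_plugins.py", "-k", "configure_bnb_plugin"])
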
7 changes: 6 additions & 1 deletion sample-configurations/CONTENTS.yaml
@@ -14,4 +14,9 @@ framework_configs:
  - shortname: accelerated-peft-bnb
    plugins:
      - accelerated-peft
    filename: accelerated-peft-bnb-nf4-sample-configuration.yaml

  - shortname: baseline-peft-bnb
    plugins:
      - accelerated-peft
    filename: baseline-peft-bnb-nf4-sample-configuration.yaml
@@ -18,3 +18,7 @@ plugins:
      # bitsandbytes:
      bitsandbytes:
        quant_type: nf4

        # If True, then get_peft_model and prepare_model_for_kbit_training
        # will not be called.
        no_peft_model: false
@@ -0,0 +1,24 @@
# FMS Acceleration Plugin Configuration.
#
# Each stanza incorporates various configurations for
# different fine-tuning / training tasks.
plugins:
  # PEFT-related acceleration
  peft:

    # quantization-related acceleration
    # e.g., kernels for quantized base weights
    quantization:

      # For loading BitsAndBytes quantized layers
      # to serve as 4bit base-weights for LoRA PEFT-tuning.
      # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
      # bitsandbytes, thus the recommended quant_type is either "nf4"
      # or "fp4".
      # bitsandbytes:
      bitsandbytes:
        quant_type: nf4

        # If True, then get_peft_model and prepare_model_for_kbit_training
        # will not be called.
        no_peft_model: true
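
The baseline sample configuration is intended to match the accelerated one except for `no_peft_model`. A quick check of that invariant, assuming the filenames listed in CONTENTS.yaml are relative to the sample-configurations/ directory:

import yaml

def load(name):
    with open(f"sample-configurations/{name}") as f:
        return yaml.safe_load(f)

accel = load("accelerated-peft-bnb-nf4-sample-configuration.yaml")
base = load("baseline-peft-bnb-nf4-sample-configuration.yaml")

a = accel["plugins"]["peft"]["quantization"]["bitsandbytes"]
b = base["plugins"]["peft"]["quantization"]["bitsandbytes"]
assert a["quant_type"] == b["quant_type"] == "nf4"
assert a["no_peft_model"] is False and b["no_peft_model"] is True
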
17 changes: 17 additions & 0 deletions scripts/benchmarks/scenarios.yaml
@@ -32,6 +32,23 @@ scenarios:
        - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
        - 'NousResearch/Llama-2-70b-hf'

  - name: baseline-peft-bnb
    framework_config:
      - baseline-peft-bnb
    arguments:
      fp16: True
      learning_rate: 2e-4
      torch_dtype: float16
      peft_method: lora
      r: 16
      lora_alpha: 16
      lora_dropout: 0.0
      target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
      model_name_or_path:
        - 'mistralai/Mistral-7B-v0.1'
        - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
        - 'NousResearch/Llama-2-70b-hf'

  - name: accelerated-peft-bnb
    framework_config:
      - accelerated-peft-bnb
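
The baseline scenario reuses the same LoRA hyperparameters as accelerated-peft-bnb, so the only variable under benchmark is who prepares the PEFT model. For reference, these scenario arguments correspond roughly to the following peft LoraConfig; the mapping is illustrative and the benchmark script's own argument plumbing may differ:

from peft import LoraConfig

# LoRA settings taken from the scenario arguments above.
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
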
9 changes: 9 additions & 0 deletions scripts/generate_sample_configurations.py
@@ -141,13 +141,21 @@ def read_configuration(path: str) -> Dict:
# specified key path, with the value.
KEY_AUTO_GPTQ = "auto_gptq"
KEY_BNB_NF4 = "bnb-nf4"
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"

CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
KEY_BNB_NF4: (
"plugins/accelerated-peft/configs/bnb.yaml",
[("peft.quantization.bitsandbytes.quant_type", "nf4")],
),
KEY_BNB_NF4_BASELINE: (
"plugins/accelerated-peft/configs/bnb.yaml",
[
("peft.quantization.bitsandbytes.quant_type", "nf4"),
("peft.quantization.bitsandbytes.no_peft_model", True),
],
),
}

# list of (tag, combi) tuples
@@ -158,6 +166,7 @@ def read_configuration(path: str) -> Dict:
COMBINATIONS = [
("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)),
("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)),
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
]


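
Each CONFIGURATIONS value is either a plain plugin-config path or a (path, overrides) tuple, where the overrides are dotted key paths to set before the sample configuration is written out. A minimal sketch of how such overrides could be applied; the helper below is an assumption, not the script's actual implementation:

import yaml

def apply_override(cfg: dict, dotted_key: str, value):
    # Walk the dotted key path, creating intermediate dicts as needed.
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

with open("plugins/accelerated-peft/configs/bnb.yaml") as f:
    cfg = yaml.safe_load(f)

for key, value in [
    ("peft.quantization.bitsandbytes.quant_type", "nf4"),
    ("peft.quantization.bitsandbytes.no_peft_model", True),
]:
    apply_override(cfg, key, value)
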
