diff --git a/plugins/accelerated-peft/configs/bnb.yaml b/plugins/accelerated-peft/configs/bnb.yaml index a29eef5e..ec5c3cfa 100644 --- a/plugins/accelerated-peft/configs/bnb.yaml +++ b/plugins/accelerated-peft/configs/bnb.yaml @@ -14,3 +14,7 @@ peft: # bitsandbytes: bitsandbytes: quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: False \ No newline at end of file diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py index fa11fe3a..dfd5fbc8 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py @@ -96,6 +96,9 @@ def __init__(self, configurations: Dict[str, Dict]): self._quant_type = self._check_config_and_maybe_check_values( key="peft.quantization.bitsandbytes.quant_type", values=["fp4", "nf4"] ) + self._no_peft_model = self._check_config_and_maybe_check_values( + key="peft.quantization.bitsandbytes.no_peft_model", values=[True, False] + ) def model_loader(self, model_name: str, **kwargs): @@ -121,6 +124,16 @@ def model_loader(self, model_name: str, **kwargs): "If running in FSDP, this is probably because accelerate is not used. " "This will most probably result in error." ) + elif ( + world_size == 1 + and self._no_peft_model == True + ): + warnings.warn( + """Running on single device and setting plugin config `no_peft_model` as `True` + PEFT preparation will be managed by SFTTrainer and will cause a slowdown in training speed + due to extraneous dtype casting when SFTTrainer prepares the model using + https://github.com/huggingface/trl/blob/e90e8d91d2265e484f229c45a5eb8982f94a2936/trl/trainer/sft_trainer.py#L210""" + ) bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -147,7 +160,8 @@ def requires_custom_loading(self): @property def requires_agumentation(self): - return True + # will skip the augmentation if _no_peft_model == True + return not self._no_peft_model def augmentation( self, diff --git a/plugins/accelerated-peft/tests/test_peft_plugins.py b/plugins/accelerated-peft/tests/test_peft_plugins.py index 894e1ca6..42404ddc 100644 --- a/plugins/accelerated-peft/tests/test_peft_plugins.py +++ b/plugins/accelerated-peft/tests/test_peft_plugins.py @@ -122,6 +122,20 @@ def test_configure_bnb_plugin(): assert framework.requires_agumentation assert len(framework.get_callbacks_and_ready_for_train()) == 0 + # test no_peft_model is true skips plugin.augmentation + for key, correct_value in [ + ("peft.quantization.bitsandbytes.no_peft_model", True), + ("peft.quantization.bitsandbytes.no_peft_model", False), + ]: + with instantiate_framework( + update_configuration_contents( + read_configuration(CONFIG_PATH_BNB), key, correct_value + ), + require_packages_check=False, + ): + # check flags and callbacks + assert (not correct_value)==framework.requires_agumentation + # attempt to activate plugin with configuration pointing to wrong path # - raise with message that no plugins can be configured with pytest.raises(ValueError) as e: diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 33c24253..8d45bedf 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -14,4 +14,9 @@ framework_configs: - shortname: accelerated-peft-bnb plugins: - accelerated-peft - filename: accelerated-peft-bnb-nf4-sample-configuration.yaml \ No newline at end of file + filename: accelerated-peft-bnb-nf4-sample-configuration.yaml + + - shortname: baseline-peft-bnb + plugins: + - accelerated-peft + filename: baseline-peft-bnb-nf4-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml index e920931c..19fb71fb 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml @@ -18,3 +18,7 @@ plugins: # bitsandbytes: bitsandbytes: quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: false diff --git a/sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml b/sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml new file mode 100644 index 00000000..244de5e7 --- /dev/null +++ b/sample-configurations/baseline-peft-bnb-nf4-sample-configuration.yaml @@ -0,0 +1,24 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-releated acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus recommended quant_type to be either "nf4" + # or "fp4". + # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: true diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml index e79a74e6..21b3f98a 100644 --- a/scripts/benchmarks/scenarios.yaml +++ b/scripts/benchmarks/scenarios.yaml @@ -32,6 +32,23 @@ scenarios: - 'mistralai/Mixtral-8x7B-Instruct-v0.1' - 'NousResearch/Llama-2-70b-hf' + - name: baseline-peft-bnb + framework_config: + - baseline-peft-bnb + arguments: + fp16: True + learning_rate: 2e-4 + torch_dtype: float16 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.0 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mixtral-8x7B-Instruct-v0.1' + - 'NousResearch/Llama-2-70b-hf' + - name: accelerated-peft-bnb framework_config: - accelerated-peft-bnb diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index 9f239041..67ad4058 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -141,6 +141,7 @@ def read_configuration(path: str) -> Dict: # specified key path, with the value. KEY_AUTO_GPTQ = "auto_gptq" KEY_BNB_NF4 = "bnb-nf4" +KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -148,6 +149,13 @@ def read_configuration(path: str) -> Dict: "plugins/accelerated-peft/configs/bnb.yaml", [("peft.quantization.bitsandbytes.quant_type", "nf4")], ), + KEY_BNB_NF4_BASELINE: ( + "plugins/accelerated-peft/configs/bnb.yaml", + [ + ("peft.quantization.bitsandbytes.quant_type", "nf4"), + ("peft.quantization.bitsandbytes.no_peft_model", True), + ], + ), } # list of (tag, combi) tuples @@ -158,6 +166,7 @@ def read_configuration(path: str) -> Dict: COMBINATIONS = [ ("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)), ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)), + ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), ]