# Commit cacea88: improve docs for tuning techniques

Ssukriti committed Apr 8, 2024 (1 parent: 2a0c4d3)

Showing 2 changed files with 126 additions and 13 deletions.
**README.md** (96 additions, 13 deletions)
```
pip install -e .
```

> Note: After installing, if you wish to use [FlashAttention](https://github.com/Dao-AILab/flash-attention), then you need to install these requirements:
```
pip install -e ".[dev]"
pip install -e ".[flash-attn]"
```
The recommendation is to use [huggingface accelerate](https://huggingface.co/docs/accelerate) to launch multi-GPU tuning jobs:
```bash
# TRAIN_DATA_PATH=twitter_complaints.json # Path to the training dataset
# OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

accelerate launch \
--main_process_port $MASTER_PORT \
--config_file fixtures/accelerate_fsdp_defaults.yaml \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--dataset_text_field "output"
```

To summarize, you can use plain `python` for single-GPU jobs or `accelerate launch` for multi-GPU jobs. The following tuning techniques can be applied with either launcher:

## Tuning Techniques

### LoRA Tuning Example

Set `peft_method` to `"lora"`. You can additionally pass any arguments accepted by [LoraConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L7); its defaults are:
```py
r: int = 8
lora_alpha: int = 32
target_modules: List[str] = field(
    default_factory=lambda: ["q_proj", "v_proj"],
    metadata={
        "help": "The names of the modules to apply LORA to. LORA selects modules which either "
        "completely match or end with one of the strings. If the value is [\"all-linear\"], "
        "then LORA selects all linear and Conv1D modules except for the output layer."
    },
)
bias = "none"
lora_dropout: float = 0.05
```

```bash
python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--peft_method "lora" \
--lora_alpha 16
```

Notice the `target_modules` that are set are the default values. `target_modules` are the names of the modules to apply the adapter to. If this is specified, only the modules with the specified names will be replaced. When passing a list of strings, either an exact match will be performed or it is checked if the name of the module ends with any of the passed strings. If this is specified as `all-linear`, then all linear/Conv1D modules are chosen, excluding the output layer. If this is not specified, modules will be chosen according to the model architecture. If the architecture is not known, an error will be raised — in this case, you should specify the target modules manually. See [HuggingFace docs](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for more details.

For each model, the `target_modules` will depend on the model architecture. You can specify linear or attention layers in `target_modules`. One way to obtain the list of candidate `target_modules` for a model (a minimal sketch, assuming a HuggingFace `transformers` model; the path below is a placeholder) is to load it and print its named modules:
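```py
from transformers import AutoModelForCausalLM

# Placeholder path for illustration; substitute your own checkpoint or model id.
model = AutoModelForCausalLM.from_pretrained("path/to/model")

# Print every submodule name; attention projections (e.g. q_proj, v_proj)
# and other linear layers appear in this listing.
for name, _ in model.named_modules():
    print(name)
```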
For example, for a LLaMA model the modules look like:
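An abbreviated listing (module names as defined in the HuggingFace `LlamaForCausalLM` implementation; `...` marks the repeated decoder layers):

```
model.embed_tokens
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
model.layers.0.mlp.gate_proj
model.layers.0.mlp.up_proj
model.layers.0.mlp.down_proj
...
lm_head
```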

With the CLI, you can specify attention or linear layers with `--target_modules "q_proj" "v_proj" "k_proj" "o_proj"` or `--target_modules "all-linear"`.

### Prompt Tuning

Set `peft_method` to `"pt"`. You can additionally pass any arguments accepted by [PromptTuningConfig](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L39); its defaults are:
```py
# prompt_tuning_init can be either "TEXT" or "RANDOM"
prompt_tuning_init: str = "TEXT"
num_virtual_tokens: int = 8
# prompt_tuning_init_text is only used if prompt_tuning_init == "TEXT"
prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
tokenizer_name_or_path: str = "llama-7b-hf"
```

```bash
accelerate launch \
--main_process_port $MASTER_PORT \
--config_file fixtures/accelerate_fsdp_defaults.yaml \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--peft_method pt \
--torch_dtype bfloat16 \
--tokenizer_name_or_path $MODEL_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "epoch" \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_tokens_per_second \
--packing False \
--response_template "\n### Label:" \
--dataset_text_field "output"
```

### Fine Tuning

Set `peft_method` to `"None"` to run full fine tuning, which updates all model parameters.

```bash
accelerate launch \
--main_process_port $MASTER_PORT \
--config_file fixtures/accelerate_fsdp_defaults.yaml \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--training_data_path $TRAIN_DATA_PATH \
--output_dir $OUTPUT_PATH \
--peft_method "None" \
--torch_dtype bfloat16 \
--tokenizer_name_or_path $MODEL_PATH \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "epoch" \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_tokens_per_second \
--packing False \
--response_template "\n### Label:" \
--dataset_text_field "output"
```

## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.
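As a quick illustration, an invocation might look like the sketch below. The flag names here are assumptions for illustration only, not the script's documented interface; consult `python scripts/run_inference.py --help` for the actual options:

```bash
# Illustrative sketch only: the flags below are assumed, not confirmed; see --help.
python scripts/run_inference.py \
--model $OUTPUT_PATH \
--text "This flight was late again" \
--max_new_tokens 20
```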

**tuning/config/peft_config.py** (30 additions, 0 deletions)

@dataclass
class LoraConfig:
"""
This is the configuration class to store the configuration of a [`LoraModel`].
Args:
r (`int`):
Lora attention dimension (the "rank").
target_modules (`List[str]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. Please specify modules as per model architecture. If the value is ["all-linear"],
then LoRA selects all linear and Conv1D modules as per model architecture, except for the output layer.
lora_alpha (`int`):
The alpha parameter for Lora scaling.
lora_dropout (`float`):
The dropout probability for Lora layers.
bias (`str`):
Bias type for LoRA. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the corresponding biases
will be updated during training. Be aware that this means that, even when disabling the adapters, the model
will not produce the same output as the base model would have without adaptation.
"""
r: int = 8
lora_alpha: int = 32
target_modules: List[str] = field(
    default_factory=lambda: ["q_proj", "v_proj"],
    metadata={
        "help": "The names of the modules to apply LORA to. LORA selects modules which either "
        "completely match or end with one of the strings. If the value is [\"all-linear\"], "
        "then LORA selects all linear and Conv1D modules except for the output layer."
    },
)
bias = "none"
lora_dropout: float = 0.05

@dataclass
class PromptTuningConfig:
"""
This is the configuration class for Prompt Tuning.
Args:
prompt_tuning_init (`str`):
The initialization method for the prompt embedding. Allowed values: "TEXT" or "RANDOM".
prompt_tuning_init_text (`str`, *optional*):
The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_name_or_path (`str`, *optional*):
The name or path of the tokenizer. Only used if `prompt_tuning_init` is `TEXT`.
num_virtual_tokens (`int`): The number of virtual tokens to use.
"""
prompt_tuning_init: str = "TEXT"
num_virtual_tokens: int = 8
prompt_tuning_init_text: str = "Classify if the tweet is a complaint or not:"
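As a quick usage sketch (assuming the package is importable as `tuning`, per the repository layout), the dataclass can be constructed directly; the values below are the defaults documented above:

```py
from tuning.config.peft_config import PromptTuningConfig

# Defaults from the docstring above; pass prompt_tuning_init="RANDOM"
# to skip text-based initialization of the prompt embedding.
config = PromptTuningConfig(
    prompt_tuning_init="TEXT",
    num_virtual_tokens=8,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path="llama-7b-hf",
)
```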
