docs: lora and getting modules list (#46)
* add docs for lora and getting modules

Signed-off-by: Anh-Uong <[email protected]>

* Apply suggestions from code review

Co-authored-by: Sukriti Sharma <[email protected]>
Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Co-authored-by: Sukriti Sharma <[email protected]>
anhuong and Ssukriti authored Mar 7, 2024
1 parent 0d07ee2 commit 9c5a3bf
Showing 1 changed file with 96 additions and 0 deletions.
README.md
@@ -115,6 +115,102 @@ tuning/sft_trainer.py \

For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply replace the `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `tuning/config/fsdp_config.json` for proper sharding of the model.
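
If you want to script that edit, a plain string replacement over the config file avoids assuming anything about its key names (a minimal sketch; editing the file by hand works just as well):

```py
from pathlib import Path

# Swap the transformer block class used for FSDP sharding when tuning GPTBigCode models.
# A plain string replacement is used so no assumption is made about the exact key
# name inside fsdp_config.json.
config_path = Path("tuning/config/fsdp_config.json")
config_path.write_text(
    config_path.read_text().replace("LlamaDecoderLayer", "GPTBigCodeBlock")
)
```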

### LoRA Tuning Example

```bash
python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
--output_dir $OUTPUT_PATH \
--num_train_epochs 40 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--save_strategy "epoch" \
--learning_rate 1e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--include_tokens_per_second \
--packing False \
--response_template "\n### Label:" \
--dataset_text_field "output" \
--use_flash_attn False \
--tokenizer_name_or_path $MODEL_PATH \
--torch_dtype float32 \
--peft_method "lora" \
--logging_strategy "epoch" \
--r 8 \
--lora_dropout 0.05 \
--lora_alpha 16
```

where the [`LoraConfig`](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L7) being set looks like:
```py
LoraConfig(
r=8,
lora_alpha=16,
target_modules=['q_proj', 'v_proj'],
lora_dropout=0.05
)
```

Note that the `target_modules` shown above are the default values. `target_modules` are the names of the modules the adapter is applied to. If this is specified, only modules with the specified names are replaced. When passing a list of strings, either an exact match is performed or the module name is checked for ending with any of the passed strings. If this is specified as `all-linear`, then all linear/Conv1D modules are chosen, excluding the output layer. If this is not specified, modules are chosen according to the model architecture; if the architecture is not known, an error is raised and you should specify the target modules manually. See the [HuggingFace docs](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for more details.
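
For illustration, here is how those options map onto the Hugging Face `peft` library's `LoraConfig` (a minimal sketch; the tuning script itself builds its configuration from the CLI flags shown above, and the `all-linear` shortcut requires a `peft` version that supports it):

```py
from peft import LoraConfig

# Target specific modules by name (exact match, or match on the module-name suffix):
config_named = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
)

# Or target every linear/Conv1D module except the output layer:
config_all_linear = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.05, target_modules="all-linear"
)
```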

For each model, the appropriate `target_modules` depend on the model architecture. You can specify linear or attention layers in `target_modules`. To obtain the list of `target_modules` for a model:

```py
import re

from transformers import AutoModelForCausalLM

# load the model
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# inspect the module list (in an interactive session this prints the full module tree)
model.modules

# to get just the linear layers, parse the module repr
model_modules = str(model.modules)
pattern = r'\((\w+)\): Linear'
linear_layer_names = re.findall(pattern, model_modules)

# deduplicate the layer names to get the candidate target_modules
target_modules = list(set(linear_layer_names))
```

For example, for a LLaMA model the modules look like:
```
<bound method Module.modules of LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096, padding_idx=0)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)>
```

You can specify attention or linear layers. With the CLI, pass the layers as `--target_modules "q_proj" "v_proj" "k_proj" "o_proj"` or use `--target_modules "all-linear"`.
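
As a quick sanity check outside the CLI, the discovered names can be handed to Hugging Face `peft` directly to confirm where the adapters attach. This is only an illustrative sketch, not part of the tuning script; it assumes the `model` and `target_modules` variables from the snippet above:

```py
from peft import LoraConfig, get_peft_model

# Build a LoRA config from the discovered module names and wrap the model.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=target_modules,  # e.g. ["q_proj", "k_proj", "v_proj", "o_proj", ...]
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
# Only the injected LoRA adapter weights should be trainable.
peft_model.print_trainable_parameters()
```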

## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.
