feat: move to accelerate for distributed training launch
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Mar 13, 2024
1 parent f4d9441 commit 338130a
Showing 4 changed files with 21 additions and 25 deletions.
21 changes: 14 additions & 7 deletions README.md
@@ -60,6 +60,10 @@ Current supported and tested models are `Llama2` (7 and 13B configurations have
# if you want to use one GPU on multi-gpu machine
export CUDA_VISIBLE_DEVICES=0

MODEL_PATH=llama-7b-hf # Huggingface model id or path to a checkpoint
DATA_PATH=twitter_complaints.json # Path to the dataset
OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved

python tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
@@ -83,11 +87,16 @@ python tuning/sft_trainer.py \
```

### Multiple GPUs with FSDP

```bash
torchrun \
--nnodes=1 \
--nproc_per_node=8 \
--master_port=1234 \
MODEL_PATH=llama-7b-hf # Huggingface model id or path to a checkpoint
DATA_PATH=twitter_complaints.json # Path to the dataset
OUTPUT_PATH=out # Path to the output folder where the checkpoints are saved
MASTER_PORT=1234 # The port on which the process with rank 0 listens
MASTER_ADDR=x.x.x.x # The IP address of the node with rank 0

accelerate launch --main_process_ip $MASTER_ADDR --main_process_port $MASTER_PORT \
--config_file config/accelerate_fsdp_llama_2_procs.yaml \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
@@ -104,16 +113,14 @@ tuning/sft_trainer.py \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_config tuning/config/fsdp_config.json \
--include_tokens_per_second \
--packing False \
--response_template "\n### Response:" \
--dataset_text_field "output"
```


For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply replace the `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `tuning/config/fsdp_config.json` for proper sharding of the model.
Typically, the transformer module is passed to form an FSDP unit. For `GPTBigCode` models, Hugging Face has enabled Flash v2, and one can simply replace `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `config/accelerate_fsdp_llama_2_procs.yaml` for proper sharding of the model.
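
In that file, the class to wrap is set under the `fsdp_config` section. A minimal sketch of the relevant keys, assuming the file follows the standard `accelerate config` layout (the full contents of `config/accelerate_fsdp_llama_2_procs.yaml` are not shown in this diff):

```yaml
# Hypothetical excerpt; key names follow the standard `accelerate config` output.
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  # Each Llama decoder layer becomes one FSDP unit:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  # For GPTBigCode models, use this instead:
  # fsdp_transformer_layer_cls_to_wrap: GPTBigCodeBlock
```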

### LoRA Tuning Example

3 changes: 3 additions & 0 deletions config/accelerate_fsdp_llama_2_procs.yaml
@@ -1,3 +1,6 @@
# options that can be used with accelerate config are neatly documented here -
# https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/docs/source/package_reference/cli.md

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
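
The remainder of the config file is truncated in this diff. For orientation only, a representative accelerate FSDP config of this shape might look like the sketch below; the key names follow the output of `accelerate config` for FSDP, but the values are assumptions rather than the verbatim contents of the repository's file:

```yaml
# Representative sketch only; not the verbatim contents of
# config/accelerate_fsdp_llama_2_procs.yaml. Values are assumptions.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1            # 1 corresponds to FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2                       # matches the "2 procs" in the file name
rdzv_backend: static
same_network: true
use_cpu: false
```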
10 changes: 4 additions & 6 deletions examples/prompt_tuning_twitter_complaints/README.md
@@ -34,11 +34,11 @@ We will switch our PEFT method from LORA to Prompt Tuning (pt)
MODEL_PATH=llama-7b-hf
DATA_PATH=twitter_complaints.json
OUTPUT_PATH=out
MASTER_PORT=1234 # The port on which the process with rank 0 listens
MASTER_ADDR=x.x.x.x # The IP address of the node with rank 0

torchrun \
--nnodes=1 \
--nproc_per_node=8 \
--master_port=1234 \
accelerate launch --main_process_ip $MASTER_ADDR --main_process_port $MASTER_PORT \
--config_file config/accelerate_fsdp_llama_2_procs.yaml \
tuning/sft_trainer.py \
--model_name_or_path $MODEL_PATH \
--data_path $DATA_PATH \
@@ -56,8 +56,6 @@ tuning/sft_trainer.py \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_config tuning/config/fsdp_config.json \
--include_tokens_per_second \
--packing False \
--response_template "\n### Label:" \
12 changes: 0 additions & 12 deletions tuning/config/fsdp_config.json

This file was deleted.
