Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example: use skypilot for finetune #132

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/skypilot/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Skypilot

[SkyPilot](https://github.com/skypilot-org/skypilot) is a framework for easily running machine learning workloads on any cloud through a unified interface which makes it perfect for qlora finetunes.

## Usage

# use pip install "skypilot[gcp,aws]" for whatever cloud you want to support
pip install "skypilot"

# make sure that sky check returns green for some providers
./skypilot.sh

This should give you something like this, depending on your cloud and settings and parameters:

./skypilot.sh --cloud lambda --gpu H100:1
Task from YAML spec: qlora.yaml
== Optimizer ==
Target: minimizing cost
Estimated cost: $2.4 / hour
Considered resources (1 node):
------------------------------------------------------------------------------------------------
CLOUD INSTANCE vCPUs Mem(GB) ACCELERATORS REGION/ZONE COST ($) CHOSEN
------------------------------------------------------------------------------------------------
Lambda gpu_1x_h100_pcie 26 200 H100:1 us-east-1 2.40 ✔
------------------------------------------------------------------------------------------------
Launching a new cluster 'qlora'. Proceed? [Y/n]: y


Other, very sensible things to do are to pass --idle-minutes-to-autostop 60 so that the cluster shuts down after it's done. If your cloud provider supports spot instances than --use-spot can be ideal.

Make sure that you either mount a /outputs directory or setup an automated upload to a cloud bucket after the training is done.
141 changes: 141 additions & 0 deletions examples/skypilot/qlora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
name: qlora

resources:
# any ampere+ works well, but since this is an example,
accelerators: A100:4

#add this on gcp:
disk_size: 100
disk_tier: 'high'

num_nodes: 1

#file_mounts:
# uplaod the latest training dataset if you have your own
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

upload

This comment was marked as a violation of GitHub Acceptable Use Policies
# and then specifiy DATASET and DATASET_FORMAT below to match
# /data/train.jsonl: ./train.jsonl

# mount a bucket for saving results to
# /outputs:
# name: outputs
# mode: MOUNT

setup: |

# Setup the environment
conda create -n qlora python=$PYTHON -y
conda activate qlora

pip install -U torch

git clone https://github.com/tobi/qlora.git
cd qlora
pip install -U -r requirements.txt

# periodic checkpoints go here
mkdir -p ~/local-checkpoints

run: |

# Activate the environment
conda activate qlora
cd qlora

# let's double check that the output bucket exists,
# otherwise this trainig run will be for nothing
NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
LOCAL_CHECKPOINTS=~/local-checkpoints

echo "batch side: $PER_DEVICE_BATCH_SIZE"
echo "gradient steps: $GRADIENT_ACCUMULATION_STEPS"

# Turn off wandb if no api key is provided,
# add it with --env WANDB=xxx parameter to sky launch
if [ $WANDB_API_KEY == "" ]; then
WANDB_MODE="offline"
fi

# Run the training through torchrun for
torchrun \
--nnodes=$NUM_NODES \
--nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
--master_port=12375 \
--master_addr=$HOST_ADDR \
--node_rank=$SKYPILOT_NODE_RANK \
qlora.py \
--model_name_or_path $MODEL_NAME \
--output_dir $LOCAL_CHECKPOINTS \
--logging_steps 10 \
--save_strategy steps \
--data_seed 42 \
--save_steps 500 \
--save_total_limit 3 \
--evaluation_strategy steps \
--eval_dataset_size 1024 \
--max_eval_samples 1000 \
--max_new_tokens 32 \
--dataloader_num_workers 3 \
--group_by_length \
--logging_strategy steps \
--remove_unused_columns False \
--do_train \
--do_eval \
--do_mmlu_eval \
--lora_r 64 \
--lora_alpha 16 \
--lora_modules all \
--double_quant \
--quant_type nf4 \
--bf16 \
--bits $BITS \
--warmup_ratio 0.03 \
--lr_scheduler_type constant \
--gradient_checkpointing False \
--dataset $DATASET \
--dataset-format $DATASET_FORMAT \
--source_max_len 16 \
--target_max_len 512 \
--per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
--per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
--gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
--max_steps 1875 \
--eval_steps 187 \
--learning_rate 0.0002 \
--adam_beta2 0.999 \
--max_grad_norm 0.3 \
--lora_dropout 0.05 \
--weight_decay 0.0 \
--run_name $SKYPILOT_JOB_ID \
--ddp_find_unused_parameters False \
--report_to wandb \
--seed 0


returncode=$?
RSYNC=rsync
# Sync any files not in the checkpoint-* folders, if we are on gcp use
# gsutil so that we can sync to a gs bucket. You can replace this
# code with anything that puts the model somewhere useful and permanent.
if command -v gsutil &> /dev/null; then
RSYNC="gsutil -m rsync"
fi
$RSYNC -r $LOCAL_CHECKPOINTS/ $OUTPUT_RSYNC_TARGET/
exit $returncode


envs:
PYTHON: "3.10"
CUDA_MAJOR: "12"
CUDA_MINOR: "1"
MODEL_NAME: huggyllama/llama-7B

DATASET: alpaca
DATASET_FORMAT: alpaca

OUTPUT_RSYNC_TARGET: /outputs/qlora

BITS: 4
PER_DEVICE_BATCH_SIZE: 1
GRADIENT_ACCUMULATION_STEPS: 16 # apparently best to be 16x batch size, reduce for lower memory requirement

3 changes: 3 additions & 0 deletions examples/skypilot/skypilot.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sky launch -c qlora qlora.yaml \
--env WANDB_API_KEY=$WANDB_API_KEY \
$@