Port 24.04 changes into main #851

Merged: 12 commits into main on Jun 7, 2024
8 changes: 4 additions & 4 deletions .github/container/test-pax.sh
@@ -17,7 +17,7 @@ usage() {
echo " --dtype Batch size, defaults to bfloat16."
echo " --enable-te If set, will run with env var ENABLE_TE=1."
echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1."
echo " --enable-fused-attn Whether to test fused attention through TE."
echo " --disable-fused-attn Whether disable TE fused attention."
echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M"
echo " --evaluate Whether to test evaluation rather than training."
echo " -s, --steps Number of steps to run, defaults to 500."
@@ -51,7 +51,7 @@ PP=1
NODES=1
ENABLE_TE=0
MODEL_TYPE=126M
NVTE_FUSED_ATTN=0
NVTE_FUSED_ATTN=1
DROPOUT=0
EVALUATE=0
ADDITIONAL_ARGS=""
@@ -79,8 +79,8 @@ while [ : ]; do
DROPOUT='0.1'
shift 1
;;
--enable-fused-attn)
NVTE_FUSED_ATTN=1
--disable-fused-attn)
NVTE_FUSED_ATTN=0
shift 1
;;
--model-type)
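Taken together, these changes flip the default: TE fused attention is now enabled unless the caller explicitly turns it off. A minimal sketch of the resulting flag handling, modelling only the fused-attention branch (the rest of the option parsing is assumed to match the script above):

```bash
#!/bin/bash
# Sketch of the new default-on behaviour in test-pax.sh.
# Only the fused-attention flag is modelled; other options are omitted.
NVTE_FUSED_ATTN=1   # fused attention now defaults to ON

while [ $# -gt 0 ]; do
  case "$1" in
    --disable-fused-attn)
      NVTE_FUSED_ATTN=0   # callers opt out instead of opting in
      shift
      ;;
    *)
      shift
      ;;
  esac
done

echo "Running with NVTE_FUSED_ATTN=${NVTE_FUSED_ATTN}"
```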
4 changes: 2 additions & 2 deletions .github/workflows/_test_pax_rosetta.yaml
@@ -248,11 +248,11 @@ jobs:
- TEST_NAME: 5B_fused_attn_1
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 2
ADDITIONAL_ARGS: "--model-type 5B --enable-fused-attn"
ADDITIONAL_ARGS: "--model-type 5B"
- TEST_NAME: 5B_fused_attn_0
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 2
ADDITIONAL_ARGS: "--model-type 5B"
ADDITIONAL_ARGS: "--model-type 5B --disable-fused-attn"
- TEST_NAME: LLaMA_eval_TE
PARALLEL_CONFIG: [1, 1, 8, 1]
BATCH_SIZE: 4
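With the default flipped, the two 5B matrix entries above still cover the same pair of configurations, just expressed through the new flag. Roughly, they correspond to invocations like the following (the exact harness wiring inside the workflow is assumed):

```bash
# 5B_fused_attn_1: fused attention ON (now the default, so no flag is needed)
bash .github/container/test-pax.sh --model-type 5B

# 5B_fused_attn_0: fused attention OFF (explicit opt-out via the new flag)
bash .github/container/test-pax.sh --model-type 5B --disable-fused-attn
```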
20 changes: 11 additions & 9 deletions README.md
@@ -279,16 +279,18 @@ nightly build of the container for `XXX`. These containers are also tagged as
## Note
This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon.

## Supported Models
We currently enable training and evaluation for the following models:
| Model Name | Pretraining | Fine-tuning | Evaluation |
## Frameworks and Supported Models
We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs.

| Framework | Supported Models | Use-cases | Container |
| :--- | :---: | :---: | :---: |
| [GPT-3(paxml)](./rosetta/rosetta/projects/pax) | ✔️ | | ✔️ |
| [LLaMA2(paxml)](./rosetta/rosetta/projects/pax#llama) | | | ✔️ |
| [t5(t5x)](./rosetta/rosetta/projects/t5x) | ✔️ | ✔️ | ✔️ |
| [ViT](./rosetta/rosetta/projects/vit) | ✔️ | ✔️ | ✔️ |
| [Imagen](./rosetta/rosetta/projects/imagen) | ✔️ | | ✔️ |
| [PaliGemma](./rosetta/rosetta/projects/paligemma) | | ✔️ | ✔️ |
| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `nvcr.io/nvidia/jax:24.04-paxml-py3` |
| [T5X](./rosetta/rosetta/projects/t5x) | T5 | pre-training, fine-tuning | `nvcr.io/nvidia/jax:23.10-paxml-py3` |
| [T5X](./rosetta/rosetta/projects/vit) | ViT | pre-training, fine-tuning | `ghcr.io/nvidia/t5x:vit-2023-07-21` |
| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` |
| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter-2024-04-16` |
| maxtext | LLaMA, Gemma | pretraining | `nvcr.io/nvidia/jax:24.04-maxtext-py3` |

We will update this table as new models become available, so stay tuned.

1 change: 1 addition & 0 deletions rosetta/Dockerfile.pax
@@ -24,6 +24,7 @@ ARG UPDATE_PATCHES
ARG UPDATED_TE_REF

ENV ENABLE_TE=1
ENV NVTE_FUSED_ATTN=1

RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu
MANIFEST_DIR=$(dirname ${MANIFEST_FILE})
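The Paxml container therefore ships with both `ENABLE_TE=1` and `NVTE_FUSED_ATTN=1` baked in. Either can still be overridden per run; a sketch, assuming a local Docker launch with GPU support:

```bash
# Run the 24.04 Paxml container with fused attention disabled for one session.
docker run --gpus all --rm \
  -e NVTE_FUSED_ATTN=0 \
  nvcr.io/nvidia/jax:24.04-paxml-py3 \
  bash -c 'echo "ENABLE_TE=${ENABLE_TE} NVTE_FUSED_ATTN=${NVTE_FUSED_ATTN}"'
```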
240 changes: 163 additions & 77 deletions rosetta/rosetta/projects/pax/README.md

Large diffs are not rendered by default.

28 changes: 19 additions & 9 deletions rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub
@@ -1,8 +1,8 @@
#!/bin/bash
#SBATCH -A example # slurm account
#SBATCH -p partition # slurm partition name
#SBATCH -N 1 # number of nodes
#SBATCH -t 00:20:00 # wall time
#SBATCH -N 2 # number of nodes
#SBATCH -t 01:00:00 # wall time
#SBATCH -J "paxml:test" # job name
#SBATCH --exclusive # exclusive node access
#SBATCH --mem=0 # all mem avail
@@ -28,13 +28,13 @@ set -eux

# File system and volume glue code
#-------------------------------------------------------------------------------
CONTAINER="${CONTAINER:-ghcr.io/nvidia/jax:pax-llama2-2024-02-07}"
CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}"

# << CHANGE ! >>
BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to
BASE_TFDS_DATA_DIR=${BASE_TFDS_DATA_DIR}
BASE_VOCAB_PATH=${BASE_VOCAB_PATH}
BASE_CHECKPOINT_DIR=${BASE_CHECKPOINT_DIR}
BASE_CHECKPOINT_RESTORE_PATH=${BASE_CHECKPOINT_RESTORE_PATH}
PAXML_DIR=${PAXML_DIR:-/opt/paxml}

# Default env variables for paths required by pax training scripts
@@ -44,7 +44,7 @@ GPT_VOCAB_PATH=/mnt/vocab
CHECKPOINT_DIR=/opt/paxml/workspace/llama-checkpoint

# Add the pax/JAX specific mounts
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_DIR:$CHECKPOINT_DIR"
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_RESTORE_PATH:$CHECKPOINT_DIR"

# Make directories that may not exist
mkdir -p $BASE_WORKSPACE_DIR
@@ -53,16 +53,26 @@ EXPORTS="--export=ALL"
#-------------------------------------------------------------------------------

OUTPUT_DIR=${OUTPUT_DIR:-"output"}
PREC=${PREC}
GPUS_PER_NODE=${GPUS_PER_NODE}
PERCORE_BATCH_SIZE=${PERCORE_BATCH_SIZE}
CONFIG=${CONFIG:-LLaMA7B}
LOG_DIR_LOCAL=${LOG_DIR_LOCAL}
USE_LORA=${USE_LORA:-False}
ENABLE_TE=${ENABLE_TE:-1}
EVAL_ONLY=${EVAL_ONLY:-0}

cmd="$(cat <<EOF
echo "*******STARTING********"
cd ${PAXML_DIR}
nvidia-smi
bash paxml/contrib/gpu/scripts_gpu/run_llama_boolq_multiprocess.sh $TFDS_DATA_DIR $GPT_VOCAB_PATH $PREC $GPUS_PER_NODE $PERCORE_BATCH_SIZE $CHECKPOINT_DIR $CONFIG
VOCAB_PATH=${VOCAB_PATH} \
TFDS_DATA_DIR=${TFDS_DATA_DIR} \
EVAL_ONLY=${EVAL_ONLY} \
CONFIG=${CONFIG} \
LOG_DIR=${WORKSPACE_DIR}/${LOG_DIR_LOCAL} \
CHECKPOINT_RESTORE_PATH=${CHECKPOINT_DIR} \
USE_MULTIPROCESS=1 \
USE_LORA=${USE_LORA} \
ENABLE_TE=${ENABLE_TE} \
bash paxml/contrib/gpu/scripts_gpu/run_llama_boolq.sh
EOF
)"

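The reworked submission script now drives `run_llama_boolq.sh` entirely through environment variables rather than positional arguments. An illustrative launch might look like the following; every path below is a placeholder, and only the variable names come from the script itself:

```bash
# Illustrative submission of example_slurm_llama.sub; adjust paths to your cluster.
export BASE_WORKSPACE_DIR=/lustre/$USER/paxml-workspace
export BASE_TFDS_DATA_DIR=/lustre/datasets/boolq
export BASE_VOCAB_PATH=/lustre/vocab/llama2
export BASE_CHECKPOINT_RESTORE_PATH=/lustre/ckpts/llama2-7b
export CONFIG=LLaMA7B LOG_DIR_LOCAL=llama7b-boolq
export EVAL_ONLY=1 USE_LORA=False ENABLE_TE=1

sbatch rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub
```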
@@ -29,7 +29,7 @@ set -eux
# File system and volume glue code
#-------------------------------------------------------------------------------
# << CHANGE ! >>
CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:23.10-paxml-py3}"
CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}"

# << CHANGE ! >>
BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to
@@ -1,13 +1,13 @@
#!/bin/bash
#SBATCH -A example # slurm account
#SBATCH -p partition # slurm partition name
#SBATCH -N 8 # number of nodes
#SBATCH -t 00:30:00 # wall time
#SBATCH -N 2 # number of nodes
#SBATCH -t 01:00:00 # wall time
#SBATCH -J "paxml:test" # job name
#SBATCH --exclusive # exclusive node access
#SBATCH --mem=0 # all mem avail
#SBATCH --ntasks-per-node=8 # n tasks per machine (one task per gpu)
#SBATCH --overcommit
#SBATCH --overcommit
#SBATCH --dependency=singleton # only run one instance at a time
set -eux

@@ -28,47 +28,51 @@ set -eux

# File system and volume glue code
#-------------------------------------------------------------------------------
# << CHANGE ! >>
CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:23.10-paxml-py3}"
CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}"

# << CHANGE ! >>
BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to
BASE_TFDS_DATA_DIR=${BASE_TFDS_DATA_DIR}
BASE_VOCAB_PATH=${BASE_VOCAB_PATH}
BASE_CHECKPOINT_RESTORE_PATH=${BASE_CHECKPOINT_RESTORE_PATH}
PAXML_DIR=${PAXML_DIR:-/opt/paxml}
ENABLE_TE=${ENABLE_TE:-1}
ENABLE_FP8=${ENABLE_FP8:-0}

# Default env variables for paths required by pax training scripts
WORKSPACE_DIR=/opt/paxml/workspace
TFDS_DATA_DIR=/mnt/datasets
GPT_VOCAB_PATH=/mnt/vocab
CHECKPOINT_DIR=/opt/paxml/workspace/llama-checkpoint

# Add the pax/JAX specific mounts
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR"
MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_RESTORE_PATH:$CHECKPOINT_DIR"

# Make directories that may not exist
mkdir -p $BASE_WORKSPACE_DIR

EXPORTS="--export=ALL"
#-------------------------------------------------------------------------------

CONFIG=${CONFIG:-"Synthetic126M"}
OUTPUT_DIR=${OUTPUT_DIR:-"output"}
PREC=${PREC}
GPUS_PER_NODE=${GPUS_PER_NODE}
PERCORE_BATCH_SIZE=${PERCORE_BATCH_SIZE}
CONFIG=${CONFIG:-LLaMA7B}
LOG_DIR_LOCAL=${LOG_DIR_LOCAL}

if [[ -z "${BASE_SCRIPT:-}" ]]; then
export BASE_SCRIPT="${PAXML_DIR}/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh"
echo "Using default BASE_SCRIPT=$BASE_SCRIPT"
else
export BASE_SCRIPT="${WORKSPACE_DIR}/${BASE_SCRIPT}"
echo "Using custom BASE_SCRIPT=$BASE_SCRIPT"
fi
USE_LORA=${USE_LORA:-False}
ENABLE_TE=${ENABLE_TE:-1}
EVAL_ONLY=${EVAL_ONLY:-0}

cmd="$(cat <<EOF
echo "*******STARTING********"
cd ${PAXML_DIR}
nvidia-smi
ENABLE_TE=$ENABLE_TE ENABLE_FP8=$ENABLE_FP8 bash $BASE_SCRIPT $CONFIG $PREC $GPUS_PER_NODE $PERCORE_BATCH_SIZE ${WORKSPACE_DIR}/${LOG_DIR_LOCAL}
VOCAB_PATH=${VOCAB_PATH} \
TFDS_DATA_DIR=${TFDS_DATA_DIR} \
EVAL_ONLY=${EVAL_ONLY} \
CONFIG=${CONFIG} \
LOG_DIR=${WORKSPACE_DIR}/${LOG_DIR_LOCAL} \
CHECKPOINT_RESTORE_PATH=${CHECKPOINT_DIR} \
USE_MULTIPROCESS=1 \
USE_LORA=${USE_LORA} \
ENABLE_TE=${ENABLE_TE} \
bash paxml/contrib/gpu/scripts_gpu/run_llama_boolq.sh
EOF
)"

55 changes: 45 additions & 10 deletions rosetta/utils/te_pax_t5x_ckpt_converter/README.md
@@ -2,29 +2,44 @@

### Arguments
```bash
-h, --help show this help message and exit
-h, --help
show this help message and exit
--input-path INPUT_PATH
the path to load a source checkpoint for this conversion. (Required)
the path to load a source checkpoint for this conversion. (Required)
--output-path OUTPUT_PATH
the path to store the converted checkpoint. (Required)
--fw {pax,t5x} the framework that stored the given source checkpoint. (Required)
--fw {pax,t5x}
the framework that stored the given source checkpoint. (Required)
--direction {fw2te,te2fw}
the direction of this conversion. (Required)
--num-of-layer NUM_OF_LAYER
the number of Transformer layer of the given source checkpoint. (Required)
--num-of-head NUM_OF_HEAD
the number of head of multi-head attention of the given source checkpoint. (Required)
--head-dim HEAD_DIM the head dimension of multi-head attention of the given source checkpoint. (Required)
--num-gqa-groups NUM_GQA_GROUPS
the number of GQA groups (key-value heads) of the given source checkpoint. Required for --te-qkv-layout=kv_packed.
--head-dim HEAD_DIM
the head dimension of multi-head attention of the given source checkpoint. (Required)
--mlp-intermediate-dim MLP_INTERMEDIATE_DIM
the intermediate dimension of MLP block (FFN) of the given source checkpoint. (Required)
--embed-dim EMBED_DIM
the embedded dimension of the given source checkpoint; must be given if --fw=t5x. (default: None)
--kernel-chunk-size KERNEL_CHUNK_SIZE
the size to chunk the kernel (weights) into before storing; only supported with --fw=pax. Setting None means no chunking. (default: None)
--weight-only indicate if the source checkpoint only includes weights. (default: False)
--skip-ln indicate whether to skip the conversion for LayerNorm. (default: False)
--pax-repeat indicate if the source Pax checkpoint enables Repeat. (default: False)
--t5x-fuse-qkv indicate if the source T5X checkpoint enables fused_qkv_params of TE. (default: False)
--weight-only
indicate if the source checkpoint only includes weights. (default: False)
--skip-ln
indicate whether to skip the conversion for LayerNorm. (default: False)
--use-gated-activations
indicate if the checkpointed model uses a gated activation function. (default: False)
--pax-repeat
indicate if the source Pax checkpoint enables Repeat. (default: False)
--pax-split-qkv
indicate if the source Pax checkpoint has split QKV parameters. Required for --te-qkv-layout=kv_packed. (default: False)
--t5x-fuse-qkv
indicate if the source T5X checkpoint enables fused_qkv_params of TE. (default: False)
--te-qkv-layout {qkv_packed,kv_packed}
indicate the QKV layout of the converted TE checkpoint. Only supported with --pax-split-qkv and requires --num-gqa-groups <N> to be set. (default: 'qkv_packed')
```

### Usage Examples
@@ -71,7 +86,7 @@ python converter/main.py \
--mlp-intermediate-dim=1024
```

2. Pax -> TE (Not Repeat):
4. Pax -> TE (Not Repeat):
```bash
python converter/main.py \
--input-path=/your_path_to_src_ckpt \
@@ -84,6 +99,26 @@ python converter/main.py \
--mlp-intermediate-dim=1024
```

5. Pax -> TE (LLaMA2-70b):
```bash
python converter/main.py \
--input-path=/your_path_to_src_ckpt \
--output-path=/your_path_to_output_ckpt \
--fw=pax \
--direction=fw2te \
--num-of-layer=80 \
--num-of-head=64 \
--num-gqa-groups=8 \
--head-dim=128 \
--mlp-intermediate-dim=28672 \
--kernel-chunk-size=512 \
--skip-bias \
--weight-only \
--use-gated-activations \
--pax-split-qkv \
--te-qkv-layout=kv_packed
```

#### T5X
1. TE/FusedQKV -> T5X:
@@ -154,7 +189,7 @@ restoring it to keep training.

#### The folder structure of CKPTs by Pax and T5X
If you would like to run the converted CKPTs with the frameworks, you may expect them to have the same folder
structure as the CKPTs stored by the frameworks. In that case, you can set `--output-path` to follow the same structure as the
CKPTs from the frameworks; there is no need to pre-generate folders, since they are created when needed.
For Pax, you could set `--output-path` to something like `/${your_path_to_output}/checkpoints/checkpoint_${step}`.
For T5X, you could set `--output-path` to something like `/${your_path_to_output}/checkpoint_${step}`.
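In other words, the converter writes into whatever directory you point it at, so matching a framework's expected layout is only a matter of choosing the path. A small illustration, with the step number and base directory made up:

```bash
# Illustrative --output-path layouts; the step and base directory are placeholders.
STEP=100000
PAX_OUTPUT=/your_path_to_output/checkpoints/checkpoint_${STEP}   # Pax layout
T5X_OUTPUT=/your_path_to_output/checkpoint_${STEP}               # T5X layout
echo "Pax: ${PAX_OUTPUT}"
echo "T5X: ${T5X_OUTPUT}"
```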