diff --git a/.github/container/test-pax.sh b/.github/container/test-pax.sh index 91cb926b4..2b33f53f7 100755 --- a/.github/container/test-pax.sh +++ b/.github/container/test-pax.sh @@ -17,7 +17,7 @@ usage() { echo " --dtype Batch size, defaults to bfloat16." echo " --enable-te If set, will run with env var ENABLE_TE=1." echo " --enable-dropout If set, will set DROPOUT_PROB to 0.1." - echo " --enable-fused-attn Whether to test fused attention through TE." + echo " --disable-fused-attn Whether to disable TE fused attention." echo " --model-type One of 126M, 5B, LLaMA70BProxy. Defaults to 126M" echo " --evaluate Whether to test evaluation rather than training." echo " -s, --steps Number of steps to run, defaults to 500." @@ -32,7 +32,7 @@ usage() { exit $1 } -args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,enable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") +args=$(getopt -o a:b:s:o:n:h --long additional-args:,batch-per-gpu:,dtype:,enable-te,enable-dropout,disable-fused-attn,model-type:,evaluate,steps:,help,multiprocess,output:,data-parallel:,fsdp:,tensor-parallel:,pipeline-parallel:,nodes: -- "$@") if [[ $? -ne 0 ]]; then exit $1 fi @@ -51,7 +51,7 @@ PP=1 NODES=1 ENABLE_TE=0 MODEL_TYPE=126M -NVTE_FUSED_ATTN=0 +NVTE_FUSED_ATTN=1 DROPOUT=0 EVALUATE=0 ADDITIONAL_ARGS="" @@ -79,8 +79,8 @@ while [ : ]; do DROPOUT='0.1' shift 1 ;; - --enable-fused-attn) - NVTE_FUSED_ATTN=1 + --disable-fused-attn) + NVTE_FUSED_ATTN=0 shift 1 ;; --model-type) diff --git a/.github/workflows/_test_pax_rosetta.yaml b/.github/workflows/_test_pax_rosetta.yaml index 7b80bfa60..264777e15 100644 --- a/.github/workflows/_test_pax_rosetta.yaml +++ b/.github/workflows/_test_pax_rosetta.yaml @@ -248,11 +248,11 @@ jobs: - TEST_NAME: 5B_fused_attn_1 PARALLEL_CONFIG: [1, 1, 8, 1] BATCH_SIZE: 2 - ADDITIONAL_ARGS: "--model-type 5B --enable-fused-attn" + ADDITIONAL_ARGS: "--model-type 5B" - TEST_NAME: 5B_fused_attn_0 PARALLEL_CONFIG: [1, 1, 8, 1] BATCH_SIZE: 2 - ADDITIONAL_ARGS: "--model-type 5B" + ADDITIONAL_ARGS: "--model-type 5B --disable-fused-attn" - TEST_NAME: LLaMA_eval_TE PARALLEL_CONFIG: [1, 1, 8, 1] BATCH_SIZE: 4 diff --git a/README.md b/README.md index 48063b51c..6b380399f 100644 --- a/README.md +++ b/README.md @@ -279,16 +279,17 @@ nightly build of the container for `XXX`. These containers are also tagged as ## Note This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon. -## Supported Models -We currently enable training and evaluation for the following models: -| Model Name | Pretraining | Fine-tuning | Evaluation | +## Frameworks and Supported Models +We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs.
+ +| Framework | Supported Models | Use-cases | Container | | :--- | :---: | :---: | :---: | -| [GPT-3(paxml)](./rosetta/rosetta/projects/pax) | ✔️ | | ✔️ | -| [LLaMA2(paxml)](./rosetta/rosetta/projects/pax#llama) | | | ✔️ | -| [t5(t5x)](./rosetta/rosetta/projects/t5x) | ✔️ | ✔️ | ✔️ | -| [ViT](./rosetta/rosetta/projects/vit) | ✔️ | ✔️ | ✔️ | -| [Imagen](./rosetta/rosetta/projects/imagen) | ✔️ | | ✔️ | -| [PaliGemma](./rosetta/rosetta/projects/paligemma) | | ✔️ | ✔️ | +| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` | +| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` | +| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02` | +| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` | +| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` | +| maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` | We will update this table as new models become available, so stay tuned. diff --git a/rosetta/Dockerfile.pax b/rosetta/Dockerfile.pax index 140ab2b5f..c702aa530 100644 --- a/rosetta/Dockerfile.pax +++ b/rosetta/Dockerfile.pax @@ -24,6 +24,7 @@ ARG UPDATE_PATCHES ARG UPDATED_TE_REF ENV ENABLE_TE=1 +ENV NVTE_FUSED_ATTN=1 RUN --mount=target=/mnt/jax-toolbox,from=jax-toolbox <<"EOF" bash -exu MANIFEST_DIR=$(dirname ${MANIFEST_FILE}) diff --git a/rosetta/rosetta/projects/pax/README.md b/rosetta/rosetta/projects/pax/README.md index 658484f4c..988ca107d 100644 --- a/rosetta/rosetta/projects/pax/README.md +++ b/rosetta/rosetta/projects/pax/README.md @@ -1,22 +1,22 @@ # Pax -[Pax](https://github.com/google/paxml/tree/main) is a framework developed by Google optimized for running machine learning experiments using JAX. Pax consists of the Paxml and [Praxis](https://github.com/google/praxis/tree/main) repositories and is maintained as a [distribution](../../../docs/DEVELOPMENT.md) within Rosetta. This means that we cherry-pick the necessary changes to optimize Pax for GPUs on top of upstream Paxml and Praxis' `main` branches. We also provide support for FP8 training via both [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and native [XLA-FP8](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/NATIVE_FP8.md). +[Pax](https://github.com/google/paxml/tree/main) is a framework developed by Google optimized for running machine learning experiments using JAX. Pax consists of the Paxml and [Praxis](https://github.com/google/praxis/tree/main) repositories and is maintained as a [distribution](../../../docs/DEVELOPMENT.md) within Rosetta. This means that we cherry-pick the necessary changes to optimize Pax for GPUs on top of upstream Paxml and Praxis' `main` branches. We also provide support for FP8 training via both [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) and native [XLA-FP8](../../../docs/NATIVE_FP8.md). Any `paxml/*` or `praxis/*` relative directory/file can be found in [google/paxml](https://github.com/google/paxml/tree/main) or [google/praxis](https://github.com/google/praxis/tree/main), respectively, but to view the most up-to-date version of that directory/file with any GPU-specific patches, please see [Inspecting the Source Code](#inspecting-the-source-code). 
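Because the container now turns on both Transformer Engine and TE fused attention by default (`ENABLE_TE=1` and `NVTE_FUSED_ATTN=1`, per the `Dockerfile.pax` change above), it can be useful to confirm what a given image will actually do before launching a run. The following is a minimal sketch, assuming the `ghcr.io/nvidia/pax:latest` image mentioned below; substitute whichever tag you actually use:
```
docker run --rm ghcr.io/nvidia/pax:latest bash -c 'echo "ENABLE_TE=${ENABLE_TE:-unset} NVTE_FUSED_ATTN=${NVTE_FUSED_ATTN:-unset}"'
```
Either variable can still be overridden per run; for example, setting `NVTE_FUSED_ATTN=0` falls back to unfused attention, as described in the Flash Attention section below.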
-## Hardware and Software Specifications +# Hardware and Software Specifications Convergence and performance has been validated on NVIDIA DGX H100 (8x H100 80G) and A100 (8x A100 80G) nodes; for details, please refer to the [Configs](#configs) section below. We provide both singlenode and multinode pre-training support. If running on a machine with less than 80G memory, some of the default configurations may run out of memory; if you run out of memory and have more GPUs available, increase your GPU count and decrease your batch size per GPU. The [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) is required to run the subsequent commands with GPU support. Ensure the NVIDIA Container Toolkit is installed before proceeding. -## Containers -We provide a fully built and ready-to-use multi-arch container which includes the latest optimizations, experimental features, and examples benchmarked for multi-node, multi-GPU training: `nvcr.io/nvidia/jax:23.10-paxml-py3` (amd64 and arm64 support). This container also provides FP8 support via [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). Verified containers will be updated periodically, but if you wish to use the bleeding edge (which may come with unexpected behavior), please use `ghcr.io/nvidia/pax:latest`. We also provide nightly dated images with the naming pattern `ghcr.io/nvidia/pax:nightly-YYYY-MM-DD`, but we encourage you to use the latest ones for the best performance. +# Containers +We provide a fully built and ready-to-use multi-arch container which includes the latest optimizations, experimental features, and examples benchmarked for multi-node, multi-GPU training: `nvcr.io/nvidia/jax:24.04-paxml-py3` (amd64 and arm64 support). This container also provides FP8 support via [Transformer Engine](https://github.com/NVIDIA/TransformerEngine). Verified containers will be updated periodically, but if you wish to use the bleeding edge (which may come with unexpected behavior), please use `ghcr.io/nvidia/pax:latest`. We also provide nightly dated images with the naming pattern `ghcr.io/nvidia/pax:nightly-YYYY-MM-DD`, but we encourage you to use the latest ones for the best performance. -For more information on the Pax build and for details on how to manually build the Pax distribution, please refer to [DEVELOPMENT.md](../../../docs/DEVELOPMENT.md). +For more information on the Pax build and for details on how to manually build the Pax distribution, please refer to [DEVELOPMENT.md](../../../docs/DEVELOPMENT.md). *Note*: All paths mentioned in subsequent sections are relative to the top-level directory of the Paxml repository. When working interactively with containers, make sure you navigate to `/opt/paxml` before running any commmands. -## Downloading the SentencePiece Model +# Downloading the SentencePiece Model Pax models require a pretrained SentencePiece model to tokenize the datasets. The SentencePiece model used in the following experiments is `gs://mlperf-llm-public2/vocab/c4_en_301_5Mexp2_spm.model`. This model was trained using [these instructions](https://github.com/sgpyc/training/blob/paxml-llm-draft/large_language_model/paxml/utils/generate_spm.md). Use the following commands to download the tokenizer locally. This should be done _prior_ to launching the container. 
``` wget -P c4_sentencepiece https://github.com/nvjax-svc-0/assets/raw/main/sentencepiece_c4/c4_en_301_5Mexp2_spm.model @@ -26,14 +26,14 @@ You can then use the following mount to attach the tokenizer to your container: docker run -v ${PWD}/c4_sentencepiece/c4_en_301_5Mexp2_spm.model:/opt/paxml/vocab ... ``` -## Launching a container +# Launching a container Use the following command to launch a container: ``` docker run -ti --gpus=all --net=host --ipc=host -v :/opt/paxml/datasets -v :/opt/paxml/workspace -v :/opt/paxml/vocab -w /opt/paxml /bin/bash ``` -where `DATASET_PATH` is the path to the Pile or Lambada dataset. If these datasets have not yet been downloaded, they can be downloaded from inside of the container (see [Downloading The Pile and Lambada Datasets](#Downloading-the-pile-and-lambada-datasets) for more). `WORKSPACE_PATH` is the path to the directory where you would like to store any persistent files, and `VOCAB_PATH` is the path to the pretrained SentencePiece model to use during tokenization (see [Downloading the SentencePiece Model](#Downloading-the-sentencepiece-model) for more). +where `DATASET_PATH` is the path to the Pile or Lambada dataset. If these datasets have not yet been downloaded, they can be downloaded from inside of the container (see [Downloading The Pile and Lambada Datasets](#Downloading-the-pile-and-lambada-datasets) for more). `WORKSPACE_PATH` is the path to the directory where you would like to store any persistent files, and `VOCAB_PATH` is the path to the pretrained SentencePiece model to use during tokenization (see [Downloading the SentencePiece Model](#Downloading-the-sentencepiece-model) for more). -## Downloading The Pile and Lambada Datasets +# Downloading The Pile and Lambada Datasets __IMPORTANT UPDATE__: Please be aware that as of October 2023, 'the_pile' dataset is no longer accessible. The team is actively updating our instructions and configurations to incorporate a more recent large language model (LLM) dataset. Additionally, we have provided updated instructions that include methods for using synthetic data, ensuring that our users can continue their work without interruption. Please see the [synthetic dataset](#Synthetic-dataset) section below for more information. The GPT model configs we provide are trained using The Pile dataset and evaluated using the Lambada dataset. The scripts [download_the_pile.py](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/download_the_pile.py) and [download_lambada.py](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/download_lambada.py) will download The Pile and Lambada datasets to the `TFDS_DATA_DIR` enviroment variable. To control the location of the downloaded datasets, use the following command prior to running the download scripts: `export TFDS_DATA_DIR=`. For example, the following commands download the Pile dataset to `/opt/paxml/datasets/`: @@ -44,7 +44,7 @@ python3 paxml/contrib/gpu/scripts_gpu/download_the_pile.py After the data has been successfully downloaded, use the same `TFDS_DATA_DIR` when running experiments. -## Inspecting the Source Code +# Inspecting the Source Code If you would like to inspect Pax's source code (`paxml/*` and `praxis/*`) to learn more about what is being run, you can do so by inspecting the source within the container. 
Here are some examples: @@ -57,16 +57,16 @@ FILE=paxml/contrib/gpu/scripts_gpu/configs.py docker run --entrypoint="" --rm sh -c 'cat $(python -c "import paxml; print(*paxml.__path__)" 2>/dev/null)/../'$FILE ``` -## Running a Job -Note that when training with The Pile dataset, you must provide the `TFDS_DATA_DIR` as a command-line argument and a `VOCAB_PATH` (the path to a pretrained SentencePiece model) as an environment variable. See the bash scripts below for examples. +# Running a Job +Note that when training with The Pile dataset, you must provide the `TFDS_DATA_DIR` as a command-line argument and a `VOCAB_PATH` (the path to a pretrained SentencePiece model) as an environment variable. See the bash scripts below for examples. -### Quick Runs -#### Interactive: Single Node +## Quick Runs +### Interactive: Single Node See [run_pile_singlenode.sh](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh) for an example of training a 126M parameter model on a single node using The Pile. Once inside of your container, this script can be run interactively using the following command: -``` +``` bash paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh ``` -where `TFDS_DATA_DIR` is the path to The Pile dataset, `VOCAB_PATH` is the path to the pretrained SentencePiece `.model` file, and `LOGDIR` is the relative path of the directory to which to write checkpoints and logging information. `PERCORE_BATCH_SIZE` is the batch size per GPU _prior_ to sharding according to the parallel strategy. See [Customized Runs](#Customized-runs) for more information about this hyperparameter. +where `TFDS_DATA_DIR` is the path to The Pile dataset, `VOCAB_PATH` is the path to the pretrained SentencePiece `.model` file, and `LOGDIR` is the relative path of the directory to which to write checkpoints and logging information. `PERCORE_BATCH_SIZE` is the batch size per GPU _prior_ to sharding according to the parallel strategy. See [Customized Runs](#Customized-runs) for more information about this hyperparameter. For example, to train the 126M model using a percore batch size of 4 on 8 H100 gpus, you can use the following command: ``` @@ -74,21 +74,23 @@ ENABLE_FP8=1 bash paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh /opt/paxm ``` See [run_lambada_singlenode.sh](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/run_lambada_singlenode.sh) for an example of running zero-shot evaluation on the 126M model using the Lambada dataset. Use the following command to run this script: -``` +``` bash paxml/contrib/gpu/scripts_gpu/run_lambada_singlenode.sh ``` `TFDS_DATA_DIR` should contain the path to the Lambada dataset and `LOGDIR` should match the `LOGDIR` from the pretraining run. Note that a pre-trained checkpoint is required in order for this script to run successfully. -#### Multi Node -See [example_slurm_pile.sub](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/scripts/example_slurm_pile.sub) for an example slurm submit file that launches an 8-node training run with a 126 million parameter GPT model. +### Multi Node +The [scripts](scripts) directory provides a number of example submit files for launching the provided models on SLURM+pyxis cluster. For example, [example_slurm_pile.sub](scripts/example_slurm_pile.sub) launches an 8-node training run with a 126 million parameter GPT model. 
To launch `example_slurm_pile.sub`, run the following command: ``` CONTAINER= BASE_WORKSPACE_DIR= BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= LOG_DIR_LOCAL= OUTPUT_DIR= PREC=bfloat16 GPUS_PER_NODE=8 PERCORE_BATCH_SIZE=4 ENABLE_FP8= sbatch -N 8 -A -p -J scripts/example_slurm_pile.sub ``` where `BASE_WORKSPACE_DIR`, `BASE_TFDS_DATA_DIR`, and `BASE_VOCAB_PATH` are absolute paths and `LOG_DIR` and `OUTPUT_DIR` are relative to `BASE_WORKSPACE_DIR`. - -### Customized Runs + +Details on the other `.sub` files are provided in the [Configs](#configs) section. + +## Customized Runs Paxml's [main.py](https://github.com/google/paxml/blob/main/paxml/main.py) takes an experiment config as a command-line argument via the `--fdl_config` flag. To control which model to run, swap out the experiment config passed to `main.py`. For example, in [run_pile_multinode.sh](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh), we run the experiment [Pile126M](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/configs.py#L177-L181): ``` ... @@ -96,11 +98,11 @@ Paxml's [main.py](https://github.com/google/paxml/blob/main/paxml/main.py) takes ... ``` -Paxml uses [Fiddle](https://github.com/google/fiddle/tree/main) for configuring hyperparameters. To overwrite an existing hyperparameter from the command line, use the following syntax: +Paxml uses [Fiddle](https://github.com/google/fiddle/tree/main) for configuring hyperparameters. To overwrite an existing hyperparameter from the command line, use the following syntax: ``` --fdl.= ``` -For example, in our `*.sh` scripts, we override the default values of `FPROP_DTYPE`, `ICI_MESH_SHAPE`, and `PERCORE_BATCH_SIZE`. +For example, in our `*.sh` scripts, we override the default values of `FPROP_DTYPE`, `ICI_MESH_SHAPE`, and `PERCORE_BATCH_SIZE`. We provide a list of some of the frequently overridden hyperparameters, and an explanation of each, below: - `ICI_MESH_SHAPE`: This refers to the parallelism strategy used on chips connected by a fast network (e.g. NVLink). `ICI_MESH_SHAPE` typically has 3 dimensions, `[data, fsdp, tensor]`, corresponding to data parallelism (DP), fully-sharded data parallelism (FSDP/ZeRO-3), and tensor parallelism (TP), respectively. For example,to use pure data parallelism, you should set `ICI_MESH_SHAPE` to `[NUM_GPUS, 1, 1]`. @@ -110,7 +112,7 @@ We provide a list of some of the frequently overridden hyperparameters, and an e We provide three "base" configurations in `paxml/contrib/gpu/scripts_gpu/configs.py`. For more information about these configurations and how to run experiments using them, please refer to the [Configs](#Configs) section below. -### Transformer Engine +## Transformer Engine Training using Transformer Engine (TE) with bfloat16 precision is controlled via the environment variable `ENABLE_TE`. TE is enabled by default in the prebuilt container, but if you would like to disable TE, you can do so by flipping the value of `ENABLE_TE` in the container: ``` export ENABLE_TE=0 @@ -120,14 +122,15 @@ FP8 training is controlled via the `ENABLE_FP8` environment variable. To enable ``` ENABLE_FP8=1 bash paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh /opt/paxml/datasets /opt/paxml/vocab bfloat16 8 4 log_dir ``` +Note that transformer engine must be enabled (`ENABLE_TE=1`) in order to train with FP8 using TE). Also, note that packing is currently not supported when using TE. 
All configs disable packing by default, but beware that if packing is manually enabled, training with TE will error. -Note that packing is currently not supported when using TE. All configs disable packing by default, but beware that if packing is manually enabled, training with TE will error. +## Native FP8 +Rosetta Pax containers also provide support for native FP8 through XLA. Enabling FP8 can be done by adding the following command-line flag to your bash script: `--fdl.USE_FP8=True`. When using native FP8, TE must be disabled. For a detailed explanation of native FP8 support in Pax, as well as a comparison between native FP8 and TE FP8, please refer to the [NATIVE_FP8](../../../docs/NATIVE_FP8.md) documentation. -### Native FP8 -Rosetta Pax containers also provide support for native FP8 through XLA. Enabling FP8 can be done by adding the following command-line flag to `paxml/contrib/gpu/scripts_gpu/run_pile_singlenode.sh`: `--fdl.USE_FP8=True`. When using native FP8, TE must be disabled. For a detailed explanation of native FP8 support in Pax, as well as a comparison between native FP8 and TE FP8, please refer to the [NATIVE_FP8](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/docs/NATIVE_FP8.md) documentation. +## Flash Attention +Flash attention is enabled by default in the given container. Divergence has been observed with the GPT 126M model with flash attention enabled. If you observe divergence when running GPT 126M, it is recommended to disable flash attention. If training with Transformer Engine, you can disable FA using the following environment variable: `NVTE_FUSED_ATTN=0`. If not using TE, FA can be disabled using the following XLA flag: `--set_xla_gpu_enable_cudnn_fmha=False`. -### Flash Attention -As of 2/6/2024, CuDNN flash attention was enabled by default via XLA. Divergence has been observed with the GPT 126M model with flash attention enabled. If you observe divergence when running GPT 126M, you can disable flash attention using the following XLA flag: `--set_xla_gpu_enable_cudnn_fmha=False`. +In addition to improving throughput, enabling flash attention provides a memory savings. Some of the given configurations may run out of memory if flash attention is disabled; if this is the case, try reducing your microbatch size and, if possible, increasing your GPU count. ## XLA Flags The [GPU Performance document](../../../docs/GPU_performance.md) provides a detailed description of the XLA flags that can be set to optimize performance. Additionally, the scripts in `paxml/contrib/gpu/scripts_gpu` automatically set the suggested flags for each model. Please refer to these scripts to find the XLA flags used to reproduce the results documented below. @@ -136,45 +139,64 @@ For the the 126M model, we recommend setting `--xla_gpu_all_reduce_combine_thres ``` BASE_XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false - --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_highest_priority_async_stream=true + --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true + --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=33554432 - --xla_gpu_graph_level=0" bash run_pile_multinode.sh ... + --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true" bash run_pile_multinode.sh ... 
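# Setting BASE_XLA_FLAGS on the command line overrides the default flag list the run scripts
# would otherwise choose, which is why the standard flags are repeated here and only the
# all-reduce combine threshold is changed to the 126M-specific value recommended above.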
``` -## Configs -### GPT +# Configs +## GPT We provide three "base" GPT model configurations in `paxml/contrib/gpu/scripts_gpu/configs.py`. The first is a 126 million parameter GPT model. Convergence using The Pile dataset has been verified with this model. The remaining configs are 5 billion and 175 billion parameter models. Both 5B and 175B are provided primarily for benchmarking purposes and been less thoroughly tested for convergence. The tables below describe current performance of the given configs. Experiments were run using NVIDIA DGX A100 80G and H100 80G nodes. Note that Lambada accuracy reported corresponds to the best accuracy seen across the run. Estimated walltime denotes the aproximate time to train each model to completion (i.e. number of days to reach `MAX_STEPS` number of steps as described in `configs.py`). ### A100 Results -| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequences/Sec | Est. Walltime (days) | Lambada Accuracy (± standard deviation) | Convergence Log | -| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ------------------------- | ---------------- |---------------- | -| 126M | A100 80G SXM | BF16 | 64 |64 |1 |1 | 4 | 1877.20 | 0.95 | 0.397 (± 0.012) | [log](https://tensorboard.dev/experiment/RCroDLAUQzGUoudzqD1NmQ/) | -| 5B | A100 80G SXM | BF16 | 256 | 1 |256 |1 | 8 | 465.45 | 3.82 | N/A | | -| 175B | A100 80G SXM | BF16 | 256 |1 |256 |1 | 6 | 18.29 | 72.92 | N/A | | -| 126M | A100 80G SXM | TE BF16 | 64 |64 |1 |1 | 4 | 2512.2 | 0.71 | N/A | | -| 5B | A100 80G SXM | TE BF16 | 256 | 1 |256 |1 | 8 | 586.82 | 3.02 | N/A | | -| 175B | A100 80G SXM | TE BF16 | 256 |1 |256 |1 | 6 | 19.47 | 68.49 | N/A | | +| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequences/Sec | Est. Walltime (days) | Lambada Accuracy (± standard deviation) | +| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ------------------------- | ---------------- | +| 126M | A100 80G SXM | BF16 | 64 |64 |1 |1 | 4 | 2098.16 | 0.85 | 39.7% (± 1.2%) | +| 5B | A100 80G SXM | BF16 | 256 | 1 |256 |1 | 8 | 594.13 | 2.99 | N/A | +| 175B | A100 80G SXM | BF16 | 256 |1 |256 |1 | 6 | * | * | N/A | +| 126M | A100 80G SXM | TE BF16 | 64 |64 |1 |1 | 4 | 2526.72 | 0.70 | N/A | +| 5B | A100 80G SXM | TE BF16 | 256 | 1 |256 |1 | 8 | 718.19 | 2.48 | N/A | +| 175B | A100 80G SXM | TE BF16 | 256 |1 |256 |1 | 6 | 20.44 | 65.24 | N/A | -## H100 Results +\* will be updated once final results have been gathered -| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequences/Sec | Est. Walltime (days) | Lambada Accuracy (± standard deviation) | Convergence Log | -| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ------------------------- | ---------------- |---------------- | -| 126M | H100 80G SXM | TE BF16 | 64 |64 |1 |1 | 4 | 4143.21 | 0.43 | 0.425 (± 0.018) | [log](https://tensorboard.dev/experiment/GgDMwODzQjm9kVc9H6259A/) | -| 5B | H100 80G SXM | TE BF16 | 256 | 1 |256 |1 | 8 | 1066.67 | 1.67 | N/A | | -| 175B | H100 80G SXM | TE BF16 | 256 |1 |256 |1 | 6 | 44.01 | 30.35 | N/A | | -| 5B | H100 80G SXM | TE FP8 | 256 | 1 |256 |1 | 8 | 1288.05 | 1.38 | N/A | [log](https://tensorboard.dev/experiment/i5kiGeQpRRapswa68RkYHQ/) | -| 175B | H100 80G SXM | TE FP8 | 256 |1 |256 |1 | 6 | 65.64 | 20.33 | N/A | [log](https://tensorboard.dev/experiment/HvpU324wQYarwgvd9P3Uew/) | +### H100 Results +| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequences/Sec | Est. 
Walltime (days) | Lambada Accuracy (± standard deviation) | +| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ------------------------- | ---------------- | +| 126M | H100 80G SXM | TE BF16 | 64 |64 |1 |1 | 4 | 4709.12 | 0.38 | 42.5% (± 1.8%) | +| 5B | H100 80G SXM | TE BF16 | 256 | 1 |256 |1 | 8 | 1657.24 | 1.07 | N/A | +| 175B | H100 80G SXM | TE BF16 | 256 |1 |256 |1 | 6 | 51.00 | 26.15 | N/A | +| 5B | H100 80G SXM | TE FP8 | 256 | 1 |256 |1 | 8 | 2374.66 | 0.749 | N/A | +| 175B | H100 80G SXM | TE FP8 | 256 |1 |256 |1 | 6 | 84.45 | 15.79 | N/A | -*Note*: Estimated walltime is computed assuming full throughput continuously. In practice, true walltime may be greater due to compilation overheads, interleaved evaluation, and checkpointing. A number of the linked convergence runs were completed using older software; thus, throughput reported in the linked logs may not match current results. The most up-to-date throughput numbers are reported in the table. + +*Note*: Estimated walltime is computed assuming full throughput continuously. In practice, true walltime may be greater due to compilation overheads, interleaved evaluation, and checkpointing. 126M performance numbers were gathered _without_ flash attention (due to known convergence issues with flash attention, see [Known Issues](#Known-issues) for more). The other model sizes enable flash attention. 5B FP8 was trained for 75,000 steps at a global batch size of 2048 and a sequence length of 2048, amounting to around 300 billion consumed tokens. 175B FP8 was trained for a total of around 1,000 steps at a global batch size of 1536 and a sequence length of 2048, amounting to around 3.14 billion consumed tokens. 175B was trained using the [C4 dataset](https://github.com/mlcommons/training/tree/master/large_language_model/paxml#2-dataset) and restores from an [initial MLPerf checkpoint](https://github.com/mlcommons/training/tree/master/large_language_model/paxml#initial-checkpoint). 126M and 5B were both trained using the Pile. -### LLaMA -We also provide LLaMA-2 7B, 13B and 70B configs. These configs are variants of the [LLaMA configs](https://github.com/google/saxml/blob/main/saxml/server/pax/lm/params/lm_cloud.py) provided by Saxml and have been validated on the [BoolQ](https://github.com/google-research-datasets/boolean-questions) dataset. The table below reports BoolQ accuracy for each model. +## Mixture of Experts +We provide configs for two GLaM models. GLaM is a class of mixture of experts models with every other transformer layer replaced with a MoE layer with top-2 routing. The model sizes we provide are 126M/64E (126M base dense model, 64 experts, ~1.9B parameters) and 64B/64E (~1.14T parameters). Convergence has been validated on 126M/64E. Convergence results are outlined below. + +| Model | Num. params | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequence length | Lambada Accuracy (fixed compute) | Lambada Accuracy (fixed steps) | +| --------- |------------ | --------- | ----- | -- | ---- | -- | -------- | -------------- | ------ | ------ | +| 126M/64E | 1.9B | BF16 | 64 | 1 | 64 | 1 | 8 | 2048 | 46.15% | 49.21% | +| 64B/64E | 1.14T | BF16 | 512 | 1 | 64 | 8 | 4 | 2048 | N/A | N/A | + +"Fixed compute" refers to the lambada accuracy given the same compute budget as GPT 126M dense (measured on H100), and "fixed steps" refers to the lambada accuracy given the same number of training steps as 126M dense. 
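These GLaM configs are launched through the same base-config flow described in the [Running an Experiment with Base Configs](#Running-an-experiment-with-base-configs) section below. As a rough sketch only (the experiment name must be taken from `paxml/contrib/gpu/scripts_gpu/configs.py`; the placeholder below is not a real config name), a 126M/64E run on 8 nodes (64 GPUs, matching the table above, with TE disabled per the note below) would look like:
```
CONTAINER=<container> CONFIG=<GLaM 126M/64E experiment from configs.py> BASE_WORKSPACE_DIR=<workspace dir> BASE_TFDS_DATA_DIR=<tfds data dir> BASE_VOCAB_PATH=<vocab path> LOG_DIR_LOCAL=<log dir> OUTPUT_DIR=<output dir> PREC=bfloat16 GPUS_PER_NODE=8 ENABLE_TE=0 ENABLE_FP8=0 sbatch -N 8 -A <account> -p <partition> -J <job name> scripts/launch_base_script.sub
```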
+ +The script `paxml/contrib/gpu/scripts_gpu/run_base_config_multinode.sh` can be used to run these GLaM configurations. See the [Running an Experiment with Base Configs](#Running-an-experiment-with-base-configs) section for more information about how to lauch a slurm job using this script. + +_Note_: The GLaM configs provided currently do not have support for Transformer Engine. We are actively working on this and will update the configs as TE support becomes available. + +## LLaMA +We also provide LLaMA-2 7B, 13B and 70B configs. These configs are variants of the [LLaMA configs](https://github.com/google/saxml/blob/main/saxml/server/pax/lm/params/lm_cloud.py) provided by Saxml and have been validated on the [BoolQ](https://github.com/google-research-datasets/boolean-questions) dataset. The table below reports BoolQ zero-shot accuracy for each model. + +### Zero-shot Accuracy | Size | Precision | #GPUs | DP | FSDP | TP | BS / GPU | BoolQ Accuracy | | ---- |---------- | ----- | -- | ---- | -- | -------- | -------------- | @@ -182,61 +204,128 @@ We also provide LLaMA-2 7B, 13B and 70B configs. These configs are variants of t | 13B | BF16 | 8 | 1 | 8 | 1 | 8 | 82.99% | | 70B | BF16 | 16 | 1 | 16 | 1 | 4 | 85.08% | -Saxml provides a [script](https://github.com/google/saxml/blob/f3efdafed400d03be22efdb39a006f1420460d9f/saxml/tools/convert_llama_ckpt.py) to convert Meta's LLaMA checkpoints to Paxml format. This script can be run inside of any JAX-Toolbox pax container. First, apply for access and download the Meta checkpoints and LLaMA tokenizer using [this link](https://llama.meta.com/llama-downloads/). Then, mount the Meta checkpoints to the container and run the following commands to convert the checkpoint: +### Fine-tuning +LLaMA fine-tuning is supported via full supervised fine-tuning (SFT) and LoRA parameter-efficient fine-tuning. Performance and convergence has been tested on LLaMA-2 7B, and results are reported below. + +#### SFT Results +| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequence Length | Sequences/Sec | BoolQ Accuracy (± standard deviation) | +| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ------------------------- | ---------------- | +| 7B | H100 80G SXM | BF16 | 16 | 1 |16 |1 | 2 | 4096 | 43.24 | 88.7% (± 0.12%) | +| 7B | H100 80G SXM | TE BF16 | 16 |1 |16 |1 | 2 | 4096 | 53.69 | 88.2% (± 0.17%) | + +#### LoRA Results + +Default LoRA parameters for all runs: +- LORA_RANK = 32 +- LORA_TARGET_LAYERS = all +- TRAIN_STEPS = 600 + +| Size | GPU | Precision | #GPUs | DP | FSDP | TP | BS / GPU | Sequence Length | Total Sequences | Sequences/Sec | BoolQ Accuracy (± standard deviation) | +| ---- | ----- |----- |----- | -- | ---- | -- | ---------| ---------------| ---------------| ------------------------- | ---------------- | +| 7B | H100 80G SXM | TE BF16 | 16 |1 |16 |1 | 2 | 4096 | 19,200 | 63.2 | 88.8933 (± 0.146) % | +| 7B | H100 80G SXM | TE BF16 | 16 |1 |16 |1 | 1 | 4096 | 9,600 | 56 | 88.52 (± 0.198) % | +| 7B | H100 80G SXM | BF16 | 16 |1 |16 |1 | 2 | 4096 | 19,200 | 43.8 | 88.57 (± 0.2275) % | + +### Running LLaMA Evaluation/Fine-tuning + +Saxml provides a [script](https://github.com/google/saxml/blob/f3efdafed400d03be22efdb39a006f1420460d9f/saxml/tools/convert_llama_ckpt.py) to convert Meta's LLaMA checkpoints to Paxml format for zero-shot evaluation and fine-tuning. This script can be run inside of any JAX-Toolbox pax container. 
First, apply for access and download the Meta checkpoints and LLaMA tokenizer using [this link](https://llama.meta.com/llama-downloads/). Then, mount the Meta checkpoints to the container and run the following commands to convert the checkpoint: ``` pip install pytorch ## loading meta checkpoints requires pytorch wget https://raw.githubusercontent.com/google/saxml/f3efdafed400d03be22efdb39a006f1420460d9f/saxml/tools/convert_llama_ckpt.py python3 -m convert_llama_ckpt --base-model-path --pax-model-path --model-size <7b, 13b, or 70b> ``` -The script [download_boolq.py](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/download_boolq.py) downloads the BoolQ dataset to the `TFDS_DATA_DIR` (see [Downloading the Pile and Lambada Datasets](#Downloading-the-pile-and-lambada-datasets) for more). Once BoolQ has been downloaded, the script [example_slurm_llama.sub](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub) can be used to reproduce the results reported in the table. Launch the script using the following command: +If you'd like to run LLaMA with transformer engine, the [Pax <--> TE checkpoint converter](../../../utils/te_pax_t5x_ckpt_converter) can be used to produce a TE-compatible checkpoint using the following command: +``` +python converter/main.py \ + --input-path=/your_path_to_src_ckpt \ + --output-path=/your_path_to_output_ckpt \ + --fw=pax \ + --direction=fw2te \ + --num-of-layer= \ + --num-of-head= \ + --head-dim= \ + --mlp-intermediate-dim= \ + --skip-bias \ + --weight-only \ + --use-gated-activations +``` +if converting the 70B checkpoint, the following additional arguments should be passed to the converter: +``` + --num-gqa-groups=8 \ + --pax-split-qkv \ + --te-qkv-layout=kv_packed +``` +Please refer to the checkpoint converter [readme](../../../utils/te_pax_t5x_ckpt_converter#readme) for more detailed instructions. + +The script [download_boolq.py](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/download_boolq.py) downloads the BoolQ dataset to the `TFDS_DATA_DIR` (see [Downloading the Pile and Lambada Datasets](#Downloading-the-pile-and-lambada-datasets) for more). Once BoolQ has been downloaded, the script [example_slurm_llama.sub](scripts/example_slurm_llama.sub) can be used to reproduce the results reported in the tables. The script calls `paxml/contrib/gpu/scripts_gpu/run_llama_boolq.sh`, which is configured to run the 7B model by default. Please inspect `run_llama_boolq.sh` in your container to see the arguments that can be overwritten if interested in running other model sizes. Launch `example_slurm_llama.sub` using the following command: + ``` -CONTAINER= BASE_WORKSPACE_DIR= BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= BASE_CHECKPOINT_DIR= OUTPUT_DIR= PREC=bfloat16 GPUS_PER_NODE=8 PERCORE_BATCH_SIZE= CONFIG= sbatch -N -A -p -J scripts/example_slurm_llama.sub +CONTAINER= BASE_WORKSPACE_DIR= BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= OUTPUT_DIR= EVAL_ONLY= USE_LORA= BASE_CHECKPOINT_RESTORE_PATH= LOG_DIR_LOCAL= CONFIG= ENABLE_TE= sbatch -N -A -p -J scripts/example_slurm_llama.sub ``` -`CONFIG` should be one of `LLaMA7B`, `LLaMA13B`, or `LLaMA70B` and `PERCORE_BATCH_SIZE` and `NUM_NODES` should match with the table above. +`CONFIG` should be one of `LLaMA7B`, `LLaMA13B`, or `LLaMA70B`. `EVAL_ONLY` is a boolean indicating whether to run zero-shot evaluation (`EVAL_ONLY=1`) or fine-tuning. `CHECKPOINT_RESTORE_PATH` refers to the path to the pretrained checkpoint to restore from. 
The pretrained checkpoint is expected to have the following directory structure: `/checkpoints/checkpoint_`. In order for the checkpoint restore to work correctly, `CHECKPOINT_RESTORE_PATH` should be ``. + +The same script can also be used to fine tune LLaMA models using [LoRA](https://arxiv.org/abs/2106.09685). The environment variables that configure LoRA are specified below: +- USE_LORA: Specifies whether LoRA will be used for finetuning. Default value is 0. Set to 1 if you want to enable LoRA. +- LORA_RANK: Rank used for the LoRA weight matrices. Default value is 32. +- LORA_TARGET_LAYERS: Specifies which layers to target for LoRA. Default value is 'all' which targets all linear layers. Acceptable values are "all", "attention", "mlp" where "all" targets all linear layers; "attention" targets q, k, v and out projection; "mlp" targets all MLP layers. -_Note_: The given LLaMA configs currently do not have support for Transformer Engine. We are actively working on this and will update the configs as TE support becomes available. +For example, the following command will run LoRA fine-tuning on the LLaMA-2 7B model: -### Running an Experiment with Base Configs -To run an experiment with any base model configuration with the default parallel strategy reported in the table, copy [run_pile_multinode.sh](https://github.com/google/paxml/blob/main/paxml/contrib/gpu/scripts_gpu/run_pile_multinode.sh) to your workspace and make the following modifications: replace `--fdl_config=paxml.contrib.gpu.scripts_gpu.configs.Pile126M` with the experiment you are interested in running (e.g. `paxml.contrib.gpu.scripts_gpu.configs.GPT5B` or `paxml.contrib.gpu.scripts_gpu.configs.GPT175B`) and remove `--fdl.ICI_MESH_SHAPE="[${NUM_GPUS}, 1, 1]"` and `--fdl.DCN_MESH_SHAPE="[${SLURM_JOB_NUM_NODES}, 1, 1]"`. The resulting bash script (call it `run_my_model_multinode.sh`) can be passed into `example_slurm_pile.sub` using the following command. This command presumes that `run_my_model_multinode.sh` lives in `BASE_WORKSPACE_DIR`. ``` -BASE_SCRIPT=run_my_model_multinode.sh CONTAINER= BASE_WORKSPACE_DIR= BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= LOG_DIR_LOCAL= OUTPUT_DIR= PREC= GPUS_PER_NODE= PERCORE_BATCH_SIZE= ENABLE_FP8= sbatch -N -A -p -J scripts/example_slurm_pile.sub +CONTAINER= BASE_WORKSPACE_DIR=$PWD BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= OUTPUT_DIR=lora_stdout EVAL_ONLY=0 USE_LORA=1 BASE_CHECKPOINT_RESTORE_PATH= LOG_DIR_LOCAL=7b_log_dir CONFIG=LLaMA7B ENABLE_TE=1 sbatch -N 2 -A -p -J scripts/example_slurm_llama.sub ``` -Here, it is assumed that you are running with the number of nodes reported in the table. If using a different node count, scale `DCN_MESH_SHAPE` accordingly. For example, the default value of `DCN_MESH_SHAPE` for `paxml.contrib.gpu.scripts_gpu.configs.GPT5B` is `[1,32,1]`. If running on 16 nodes, adjust `DCN_MESH_SHAPE` as follows: + +_Note_: The given LLaMA configs currently do not support FP8 training via Transformer Engine. We are actively working on this and will update the configs as TE support becomes available. + +## Running an Experiment with Base Configs +The `run_base_config_multinode.sh` script is provided to run any of the base configs provided in `paxml/contrib/gpu/scripts_gpu/configs.py` out of the box. [scripts/launch_base_script.sub](scripts/launch_base_script.sub) uses this script to train a model on a slurm cluster. 
Launch this script using the following command: ``` CONTAINER= CONFIG= BASE_WORKSPACE_DIR= BASE_TFDS_DATA_DIR= BASE_VOCAB_PATH= LOG_DIR_LOCAL= OUTPUT_DIR= PREC= GPUS_PER_NODE= ENABLE_TE= ENABLE_FP8= sbatch -N -A -p -J scripts/launch_base_script.sub ``` where `CONFIG` is the name of the config from `paxml/contrib/gpu/scripts_gpu/configs.py`. Here, it is assumed that you are running with the number of nodes reported in the table. If using a different node count, scale `DCN_MESH_SHAPE` accordingly. For example, the default value of `DCN_MESH_SHAPE` for `GPT5B` is `[1,32,1]`. If running on 16 nodes, adjust `DCN_MESH_SHAPE` in your bash script as follows: ``` --fdl.DCN_MESH_SHAPE=[1,16,1] ``` -#### Synthetic Dataset -We also provide 126M, 5B and 175B configurations with a dummy dataset for quick benchmarking. The script `paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh` benchmarks any of the given base models using the synthetic dataset. [scripts/example_slurm_synthetic.sub](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax/scripts/example_slurm_synthetic.sub) can be used to launch this script on a slurm cluster. This script can be launched using the following command: +### Synthetic Dataset +We also provide GPT 126M, 5B and 175B configurations with a dummy dataset for quick benchmarking. The script `run_base_config_multinode.sh` can also be used to benchmark any of the given base models using the synthetic dataset. [scripts/launch_base_script.sub](scripts/launch_base_script.sub) can be used to launch this script on a slurm cluster. When training using a dummy dataset, it is not required to pass in a `BASE_VOCAB_PATH` or `TFDS_DATA_DIR`: ``` -BASE_WORKSPACE_DIR= CONFIG=Synthetic<126M, 5B, 175B> OUTPUT_DIR= PREC=bfloat16 ENABLE_TE= ENABLE_FP8= GPUS_PER_NODE=8 PERCORE_BATCH_SIZE= LOG_DIR_LOCAL= sbatch -N -A -p -J scripts/example_slurm_synthetic.sub +BASE_WORKSPACE_DIR= CONFIG=Synthetic<126M, 5B, 175B> OUTPUT_DIR= PREC=bfloat16 ENABLE_TE= ENABLE_FP8= LOG_DIR_LOCAL= sbatch -N -A -p -J -t scripts/launch_base_script.sub ``` For example, the following command benchmarks the 5B model on 32 nodes with TE BF16 using the synthetic dataset: ``` -BASE_WORKSPACE_DIR= CONFIG=Synthetic5B OUTPUT_DIR=output_synthetic_5b PREC=bfloat16 ENABLE_TE=1 ENABLE_FP8=0 GPUS_PER_NODE=8 PERCORE_BATCH_SIZE=8 LOG_DIR_LOCAL=log_dir_synthetic_5b sbatch -N 32 -A -p -J scripts/example_slurm_synthetic.sub +BASE_WORKSPACE_DIR= CONFIG=Synthetic5B OUTPUT_DIR=output_synthetic_5b PREC=bfloat16 ENABLE_TE=1 ENABLE_FP8=0 LOG_DIR_LOCAL=log_dir_synthetic_5b sbatch -N 32 -A -p -J scripts/launch_base_script.sub ``` Note that with models that are particularly dataloading-bottlenecked (e.g. smaller models, such as 126M), the throughput observed using the synthetic dataset may be higher than the throughput observed when training on a real dataset. -## Known Issues -* Pipeline parallelism is not supported with NVIDIA Transformer Engine enabled in the Paxml container. -* The Paxml nightlies disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. -* The release container has a known XLA bug which affects single-process training in some cases. This bug has been fixed in newer XLA versions. If running into issues with single-process training, try using a Pax nightly container after 10/3.
You can also try cherry-picking [this commit](https://github.com/openxla/xla/commit/aa8e7340cb319b9419a097155874bf105da05e1d) in the tested container. -* Infrequent hangs have been observed in multinode settings. Setting `CUDA_MODULE_LOADING=EAGER` helps with these hangs. This environment variable is set by default in `nvcr.io/nvidia/jax:23.10-paxml-py3`. -* We currently see unexpected convergence behavior when dropout is used with Transformer Engine. Default configs do not enable dropout within transformer layers and thus should be unaffected by this bug, but users may encounter this bug if manually enabling dropout in their models. +# Known Issues +* Divergence has been observed with the GPT 126M model with flash attention enabled. If you observe divergence when running GPT 126M, it is recommended to disable flash attention. +* There is a known bug with cudnn flash attention that can cause divergence when using flash attention _without_ TE. We recommend running all models with TE enabled, but if you would like to disable TE, and you observe unexpected divergence, try disabling flash attention using the following XLA flag: `--xla_gpu_enable_cudnn_fmha=false` +* TE is currently not supported with GLaM models. Future releases will include TE support with GLaM. +* The provided LLaMA configs do not support TE FP8 for fine-tuning. Future releases will add FP8 support. +* The Paxml containers disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. +* LoRA without TE is currently not supported for models using `CombinedQKVProjection` where `input_dim != num_heads * dims_per_head`. Fix for this issue will be available in the nightlies soon. +* Setting `NVTE_FUSED_ATTN=1` when using single-processing (one process per GPU) results in a hang. This is currently being investigated. It is recommended to disable TE flash attention (`NVTE_FUSED_ATTN=0`) when using single-processing. +* A bug was introduced to TE on 4/24/2024 that affects LLaMA convergence. This bug is currently being investigated and will be fixed in an upcoming nightly. + +# Changelog +## 4/26/2024 +- Added support for LLaMA SFT and LoRA fine-tuning (BF16 and TE BF16) +- Added support for MoE models: GLaM 126M and GLaM 64B (BF16) +- Enabled TE flash attention by default -## Changelog -### 10/26/2023 +## 10/26/2023 - Enabled BF16 Transformer Engine by default - Added FP8 Transformer Engine support - Updated 5B config to disable dropout in transformer layers - bfloat16 performance - 126M performance is 6% higher than 8/29, bringing the overall regression with respect to 7/11 to around 10%. We will continue to improve 126M performance in future releases. -### 8/29/2023 +## 8/29/2023 - Added bfloat16 Transformer Engine support - Disabled packing by default in all base configurations for TE compatibility - Updated 5B config to use fully sharded data parallel (FSDP) @@ -245,7 +334,7 @@ Note that with models that are particularly dataloading-bottlenecked (e.g. small - 3% speedup - 5B - 4.5% speedup - 175B -### 7/11/2023 +## 7/11/2023 - Updated 175B config. 
175B now trained on 32 nodes using fully sharded data parallel (FSDP) - A100 perf gains - 22% speedup - 126M diff --git a/rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub b/rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub index a3c4a63c5..df38bf990 100644 --- a/rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub +++ b/rosetta/rosetta/projects/pax/scripts/example_slurm_llama.sub @@ -1,8 +1,8 @@ #!/bin/bash #SBATCH -A example # slurm account #SBATCH -p partition # slurm partition name -#SBATCH -N 1 # number of nodes -#SBATCH -t 00:20:00 # wall time +#SBATCH -N 2 # number of nodes +#SBATCH -t 01:00:00 # wall time #SBATCH -J "paxml:test" # job name #SBATCH --exclusive # exclusive node access #SBATCH --mem=0 # all mem avail @@ -28,13 +28,13 @@ set -eux # File system and volume glue code #------------------------------------------------------------------------------- -CONTAINER="${CONTAINER:-ghcr.io/nvidia/jax:pax-llama2-2024-02-07}" +CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}" # << CHANGE ! >> BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to BASE_TFDS_DATA_DIR=${BASE_TFDS_DATA_DIR} BASE_VOCAB_PATH=${BASE_VOCAB_PATH} -BASE_CHECKPOINT_DIR=${BASE_CHECKPOINT_DIR} +BASE_CHECKPOINT_RESTORE_PATH=${BASE_CHECKPOINT_RESTORE_PATH} PAXML_DIR=${PAXML_DIR:-/opt/paxml} # Default env variables for paths required by pax training scripts @@ -44,7 +44,7 @@ GPT_VOCAB_PATH=/mnt/vocab CHECKPOINT_DIR=/opt/paxml/workspace/llama-checkpoint # Add the pax/JAX specific mounts -MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_DIR:$CHECKPOINT_DIR" +MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_RESTORE_PATH:$CHECKPOINT_DIR" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR @@ -53,16 +53,26 @@ EXPORTS="--export=ALL" #------------------------------------------------------------------------------- OUTPUT_DIR=${OUTPUT_DIR:-"output"} -PREC=${PREC} -GPUS_PER_NODE=${GPUS_PER_NODE} -PERCORE_BATCH_SIZE=${PERCORE_BATCH_SIZE} CONFIG=${CONFIG:-LLaMA7B} +LOG_DIR_LOCAL=${LOG_DIR_LOCAL} +USE_LORA=${USE_LORA:-False} +ENABLE_TE=${ENABLE_TE:-1} +EVAL_ONLY=${EVAL_ONLY:-0} cmd="$(cat <> -CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:23.10-paxml-py3}" +CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}" # << CHANGE ! 
>> BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to diff --git a/rosetta/rosetta/projects/pax/scripts/example_slurm_synthetic.sub b/rosetta/rosetta/projects/pax/scripts/launch_base_script.sub similarity index 67% rename from rosetta/rosetta/projects/pax/scripts/example_slurm_synthetic.sub rename to rosetta/rosetta/projects/pax/scripts/launch_base_script.sub index 8674af4a4..df38bf990 100644 --- a/rosetta/rosetta/projects/pax/scripts/example_slurm_synthetic.sub +++ b/rosetta/rosetta/projects/pax/scripts/launch_base_script.sub @@ -1,13 +1,13 @@ #!/bin/bash #SBATCH -A example # slurm account #SBATCH -p partition # slurm partition name -#SBATCH -N 8 # number of nodes -#SBATCH -t 00:30:00 # wall time +#SBATCH -N 2 # number of nodes +#SBATCH -t 01:00:00 # wall time #SBATCH -J "paxml:test" # job name #SBATCH --exclusive # exclusive node access #SBATCH --mem=0 # all mem avail #SBATCH --ntasks-per-node=8 # n tasks per machine (one task per gpu) -#SBATCH --overcommit +#SBATCH --overcommit #SBATCH --dependency=singleton # only run one instance at a time set -eux @@ -28,20 +28,23 @@ set -eux # File system and volume glue code #------------------------------------------------------------------------------- -# << CHANGE ! >> -CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:23.10-paxml-py3}" +CONTAINER="${CONTAINER:-nvcr.io/nvidia/jax:24.04-paxml-py3}" # << CHANGE ! >> BASE_WORKSPACE_DIR=${BASE_WORKSPACE_DIR} ## location to write logs and checkpoints to +BASE_TFDS_DATA_DIR=${BASE_TFDS_DATA_DIR} +BASE_VOCAB_PATH=${BASE_VOCAB_PATH} +BASE_CHECKPOINT_RESTORE_PATH=${BASE_CHECKPOINT_RESTORE_PATH} PAXML_DIR=${PAXML_DIR:-/opt/paxml} -ENABLE_TE=${ENABLE_TE:-1} -ENABLE_FP8=${ENABLE_FP8:-0} # Default env variables for paths required by pax training scripts WORKSPACE_DIR=/opt/paxml/workspace +TFDS_DATA_DIR=/mnt/datasets +GPT_VOCAB_PATH=/mnt/vocab +CHECKPOINT_DIR=/opt/paxml/workspace/llama-checkpoint # Add the pax/JAX specific mounts -MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR" +MOUNTS="--container-mounts=$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_VOCAB_PATH:$GPT_VOCAB_PATH,$BASE_TFDS_DATA_DIR:/$TFDS_DATA_DIR,$BASE_CHECKPOINT_RESTORE_PATH:$CHECKPOINT_DIR" # Make directories that may not exist mkdir -p $BASE_WORKSPACE_DIR @@ -49,26 +52,27 @@ mkdir -p $BASE_WORKSPACE_DIR EXPORTS="--export=ALL" #------------------------------------------------------------------------------- -CONFIG=${CONFIG:-"Synthetic126M"} OUTPUT_DIR=${OUTPUT_DIR:-"output"} -PREC=${PREC} -GPUS_PER_NODE=${GPUS_PER_NODE} -PERCORE_BATCH_SIZE=${PERCORE_BATCH_SIZE} +CONFIG=${CONFIG:-LLaMA7B} LOG_DIR_LOCAL=${LOG_DIR_LOCAL} - -if [[ -z "${BASE_SCRIPT:-}" ]]; then - export BASE_SCRIPT="${PAXML_DIR}/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh" - echo "Using default BASE_SCRIPT=$BASE_SCRIPT" -else - export BASE_SCRIPT="${WORKSPACE_DIR}/${BASE_SCRIPT}" - echo "Using custom BASE_SCRIPT=$BASE_SCRIPT" -fi +USE_LORA=${USE_LORA:-False} +ENABLE_TE=${ENABLE_TE:-1} +EVAL_ONLY=${EVAL_ONLY:-0} cmd="$(cat < to be set. (default: 'qkv_packed') ``` ### Usage Examples @@ -71,7 +86,7 @@ python converter/main.py \ --mlp-intermediate-dim=1024 ``` -2. Pax -> TE (Not Repeat): +4. Pax -> TE (Not Repeat): ```bash python converter/main.py \ --input-path=/your_path_to_src_ckpt \ @@ -84,6 +99,26 @@ python converter/main.py \ --mlp-intermediate-dim=1024 ``` +5. 
Pax -> TE (LLaMa2-70b): +```bash +python converter/main.py \ + --input-path=/your_path_to_src_ckpt \ + --output-path=/your_path_to_output_ckpt \ + --fw=pax \ + --direction=fw2te \ + --num-of-layer=80 \ + --num-of-head=64 \ + --num-gqa-groups=8 \ + --head-dim=128 \ + --mlp-intermediate-dim=28672 \ + --kernel-chunk-size=512 \ + --skip-bias \ + --weight-only \ + --use-gated-activations \ + --pax-split-qkv \ + --te-qkv-layout=kv_packed +``` + #### T5X 1. TE/FusedQKV -> T5X: ```bash @@ -154,7 +189,7 @@ restoring it to keep training. #### The folder structure of CKPT by Pax and T5X If you would like to run the converted CKPTs with frameworks, you may expect the converted CKPTs have the same folder -structure with CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same stucture as the +structure with CKPTs stored by frameworks. In this case, you could set `--output-path` to be the same stucture as the CKPTs from frameworks, and no need to pre-generate folders, since it would be generated when needed. For Pax, you could set `--output-path` be like ` /${your_path_to_output}/checkpoints/checkpoint_${step}`. For T5X, you could set `--output-path` be like `/${your_path_to_output}/checkpoint_${step}`. diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py index 31d52b6d7..9a9063bb7 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/main.py @@ -27,6 +27,9 @@ FW2TE = 'fw2te' TE2FW = 'te2fw' +QKV_PACKED='qkv_packed' +KV_PACKED='kv_packed' + # Key = (Direction, isRepeat) PAX_CONVERT_HELPER_DICT = { (FW2TE, False): Pax2TEConvertHelper, @@ -76,6 +79,12 @@ def parse_args(): type=int, required=True, help="the number of head of multi-head attention of the given source checkpoint.") + parser.add_argument( + '--num-gqa-groups', + type=int, + default=None, + help="the number of GQA groups (key-value heads) of the given source checkpoint. " + + "This must be set for --te-qkv-layout=kv_packed.") parser.add_argument( '--head-dim', type=int, @@ -104,32 +113,64 @@ def parse_args(): default=False, help="indicate if the source checkpoint only includes weights.") + parser.add_argument('--skip-bias', + action="store_true", + default=False, + help="indicate whether the source checkpoint has biases.") + parser.add_argument('--skip-ln', action="store_true", default=False, help="indicate if skip the conversion for LayerNorm.") + parser.add_argument( + '--use-gated-activations', + action="store_true", + default=False, + help="indicate if the model uses a gated activation function.") + parser.add_argument('--pax-repeat', action="store_true", default=False, help="indicate if the source Pax checkpoint enables Repeat.") + + parser.add_argument( + '--pax-split-qkv', + action="store_true", + default=False, + help="indicate if the source Pax checkpoint has split QKV parameters.") + parser.add_argument( '--t5x-fuse-qkv', action="store_true", default=False, help="indicate if the source T5X checkpoint enables fused_qkv_params of TE.") + parser.add_argument( + '--te-qkv-layout', + type=str, + choices=(QKV_PACKED, KV_PACKED), + default=QKV_PACKED, + help="indicate the QKV layout of the target TE checkpoint. 
" + + "--te-qkv-packed=kv_packed is supported only with --pax-split-qkv, and " + + "requires --num-gqa-groups to be set.") + args = parser.parse_args() if args.fw == T5X: assert args.embed_dim is not None + elif args.te_qkv_layout == KV_PACKED: + assert args.pax_split_qkv + assert args.num_gqa_groups is not None + return args def get_convert_helper(args): - model_config = ModelConfig(args.num_of_layer, args.embed_dim, args.num_of_head, args.head_dim, - args.mlp_intermediate_dim, args.kernel_chunk_size) + model_config = ModelConfig(args.num_of_layer, args.embed_dim, args.num_of_head, + args.num_gqa_groups, args.head_dim, args.mlp_intermediate_dim, + args.kernel_chunk_size) convert_helper_cls = None @@ -141,7 +182,9 @@ def get_convert_helper(args): assert convert_helper_cls is not None, "Not Supported." return convert_helper_cls(args.input_path, args.output_path, model_config, - args.weight_only, args.skip_ln) + args.weight_only, args.skip_bias, args.skip_ln, + args.use_gated_activations, args.pax_split_qkv, + args.te_qkv_layout) if __name__ == "__main__": diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py index 3596c0e0a..7737992ad 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/paxml_converters.py @@ -33,72 +33,39 @@ def _generate_ckpt_map(self): ckpt_map = {} num_of_head = self.model_config.num_of_head + num_gqa_groups = self.model_config.num_gqa_groups head_dim = self.model_config.head_dim hidden_dim = num_of_head * head_dim mlp_intermediate_dim = self.model_config.mlp_intermediate_dim for i in range(self.model_config.num_of_layer): ckpt_map.update({ - f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.bias.b": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_bias", - (mlp_intermediate_dim,), None, lambda x: jnp.reshape(x, (1, *x.shape))), f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.linear.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel", (hidden_dim, mlp_intermediate_dim), 0, - lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))), - f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.bias.b": + extra_src_paths = [f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1_gate.linear.w"], + stack_dim = -2) if self.use_gated_act else \ self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_bias", - None, - None, - just_copy=True), + f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_kernel", + (hidden_dim, mlp_intermediate_dim), 0, + lambda x: jnp.reshape(x, (*x.shape[:-1], 1, x.shape[-1]))), f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.linear.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_kernel", (mlp_intermediate_dim, hidden_dim), 1), - f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.bias": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.ln_bias", - None, - None, - just_copy=True), f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.scale": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.scale", None, None, just_copy=True), - f"lm.transformer.x_layers_{i}.layer_norm.bias": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.ln_bias", - None, - None, - just_copy=True), f"lm.transformer.x_layers_{i}.layer_norm.scale": 
self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.scale", None, None, just_copy=True), - f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.b": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.bias", - (3, num_of_head, head_dim), None, - lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1]))), - f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.w": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel", - (3, hidden_dim, num_of_head, head_dim), 0, - lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), - lambda x: jnp.transpose(x, (1, 0, 2))), - f"lm.transformer.x_layers_{i}.self_attention.post.b": - self._get_convert_pkg( - f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.bias", - None, - None, - just_copy=True), f"lm.transformer.x_layers_{i}.self_attention.post.w": self._get_convert_pkg( f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.kernel", @@ -107,6 +74,141 @@ def _generate_ckpt_map(self): lambda x: jnp.transpose(x, (1, 0))) }) + # Conversion map for QKV + if self.te_qkv_layout == 'qkv_packed': + if self.pax_split_qkv: + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.query.w": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel", + (hidden_dim, num_of_head, head_dim), 0, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), + extra_src_paths = [f"lm.transformer.x_layers_{i}.self_attention.key.w", + f"lm.transformer.x_layers_{i}.self_attention.value.w"], + stack_dim = -3) + }) + else: + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.w": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.kernel", + (3, hidden_dim, num_of_head, head_dim), 0, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), + lambda x: jnp.transpose(x, (1, 0, 2))) + }) + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.layer_norm.scale": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.scale", + None, + None, + just_copy=True), + }) + + elif self.te_qkv_layout == 'kv_packed': + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.query.w": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.query.kernel", + (hidden_dim, num_of_head, head_dim), 0, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1]))), + f"lm.transformer.x_layers_{i}.self_attention.key.w": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.kv.kernel", + (hidden_dim, num_gqa_groups, head_dim), 0, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), + extra_src_paths = [f"lm.transformer.x_layers_{i}.self_attention.value.w"], + stack_dim = -3), + f"lm.transformer.x_layers_{i}.layer_norm.scale": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.query.scale", + None, + None, + just_copy=True), + }) + else: + raise RuntimeError("Unrecognized TE QKV layout in --te-qkv-layout.") + + if not self.skip_bias: + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer1.bias.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wi_bias", + (mlp_intermediate_dim,), None, lambda x: jnp.reshape(x, (1, *x.shape))), + 
f"lm.transformer.x_layers_{i}.ff_layer.ffn_layer2.bias.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.wo_bias", + None, + None, + just_copy=True), + f"lm.transformer.x_layers_{i}.ff_layer.layer_norm.bias": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.mlp.ln_bias", + None, + None, + just_copy=True), + f"lm.transformer.x_layers_{i}.self_attention.post.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.out.bias", + None, + None, + just_copy=True), + }) + # QKV biases depend on PaxML and TE QKV layouts + if self.te_qkv_layout == 'qkv_packed': + if self.pax_split_qkv: + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.q.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.bias", + (num_of_head, head_dim), None, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), + extra_src_paths = [f"lm.transformer.x_layers_{i}.self_attention.key.b", + f"lm.transformer.x_layers_{i}.self_attention.value.b"], + stack_dim = -3) + }) + else: + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.combined_qkv.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.bias", + (3, num_of_head, head_dim), None, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1]))) + }) + + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.layer_norm.bias": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.qkv.ln_bias", + None, + None, + just_copy=True) + }) + + elif self.te_qkv_layout == 'kv_packed': + ckpt_map.update({ + f"lm.transformer.x_layers_{i}.self_attention.query.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.query.bias", + (num_of_head, head_dim), None, + lambda x: jnp.reshape(x, (x.shape[-2] * x.shape[-1], ))), + f"lm.transformer.x_layers_{i}.self_attention.key.b": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.kv.bias", + (num_gqa_groups, head_dim), None, + lambda x: jnp.reshape(x, (*x.shape[:-2], x.shape[-2] * x.shape[-1])), + extra_src_paths = [f"lm.transformer.x_layers_{i}.self_attention.value.b"], + stack_dim = -3), + f"lm.transformer.x_layers_{i}.layer_norm.bias": + self._get_convert_pkg( + f"lm.transformer.x_layers_{i}.transformerlayer.cld.attention.query.ln_bias", + None, + None, + just_copy=True) + }) + else: + raise RuntimeError("Unrecognized TE QKV layout in --te-qkv-layout.") + return ckpt_map diff --git a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py index c9fd3defb..952058939 100644 --- a/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py +++ b/rosetta/utils/te_pax_t5x_ckpt_converter/converter/utils.py @@ -37,6 +37,7 @@ class ModelConfig: num_of_layer: int embed_dim: int num_of_head: int + num_gqa_groups: int head_dim: int mlp_intermediate_dim: int kernel_chunk_size: int = None @@ -56,12 +57,17 @@ class ConvertPkg: class ConvertHelper: def __init__(self, input_path: str, output_path: str, model_config: ModelConfig, - weight_only: bool, skip_ln: bool): + weight_only: bool, skip_bias: bool, skip_ln: bool, use_gated_act: bool, + pax_split_qkv: bool, te_qkv_layout: str): self.input_path = input_path self.output_path = output_path self.model_config = model_config self.weight_only = weight_only + self.skip_bias = skip_bias self.skip_ln = skip_ln + 
self.use_gated_act = use_gated_act + self.pax_split_qkv = pax_split_qkv + self.te_qkv_layout = te_qkv_layout @property def catagories(self):
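
The kv_packed additions above boil down to some simple shape bookkeeping. The sketch below, which is illustrative and not part of the patch, walks through it with the LLaMA2-70b numbers from the README example (64 query heads, 8 GQA groups, head dim 128). The stack-then-flatten order and the exact axis layout of the packed KV kernel are assumptions for illustration; `_get_convert_pkg` in `converter/utils.py` defines the real target shapes.

```python
import jax.numpy as jnp

# LLaMA2-70b proxy settings from the README example above.
num_of_head = 64        # query heads
num_gqa_groups = 8      # key/value heads; each KV head serves 64 // 8 = 8 query heads
head_dim = 128
hidden_dim = num_of_head * head_dim                                           # 8192

# Split Pax attention weights (weight-only and bias-free, hence --skip-bias for LLaMA).
q_w = jnp.zeros((hidden_dim, num_of_head, head_dim), dtype=jnp.bfloat16)      # self_attention.query.w
k_w = jnp.zeros((hidden_dim, num_gqa_groups, head_dim), dtype=jnp.bfloat16)   # self_attention.key.w
v_w = jnp.zeros((hidden_dim, num_gqa_groups, head_dim), dtype=jnp.bfloat16)   # self_attention.value.w

# The query projection keeps its own kernel; only the trailing (heads, head_dim)
# axes are flattened, mirroring the reshape lambda in the kv_packed branch.
query_kernel = q_w.reshape(hidden_dim, num_of_head * head_dim)                # (8192, 8192)

# Key and value are combined via the extra_src_paths / stack_dim=-3 mechanism and
# then flattened the same way into a single packed KV kernel (assumed layout).
kv_stacked = jnp.stack([k_w, v_w], axis=-3)                                   # (8192, 2, 8, 128)
kv_kernel = kv_stacked.reshape(hidden_dim, 2, num_gqa_groups * head_dim)      # (8192, 2, 1024)

print(query_kernel.shape, kv_kernel.shape)
```

Mismatched `--num-of-head`, `--num-gqa-groups`, or `--head-dim` values would make these reshapes fail, which is why the updated CLI requires `--num-gqa-groups` (and `--pax-split-qkv`) whenever `--te-qkv-layout=kv_packed` is selected.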
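Since `ModelConfig` and the `ConvertHelper` constructor both gain new fields in this patch, call sites have to follow the updated argument order used in `get_convert_helper`. As a stand-alone reference, this sketch mirrors the updated dataclass from `converter/utils.py` (duplicated here, with `Optional` annotations, purely so the snippet runs on its own) and fills it with the LLaMA2-70b example values.

```python
from dataclasses import dataclass
from typing import Optional

# Mirror of the updated ModelConfig in converter/utils.py, duplicated only to keep
# this sketch self-contained.
@dataclass
class ModelConfig:
    num_of_layer: int
    embed_dim: Optional[int]
    num_of_head: int
    num_gqa_groups: Optional[int]
    head_dim: int
    mlp_intermediate_dim: int
    kernel_chunk_size: Optional[int] = None

# Positional order matches the updated call in get_convert_helper():
#   ModelConfig(num_of_layer, embed_dim, num_of_head, num_gqa_groups,
#               head_dim, mlp_intermediate_dim, kernel_chunk_size)
llama70b_proxy = ModelConfig(
    num_of_layer=80,
    embed_dim=None,           # only asserted for --fw=t5x
    num_of_head=64,
    num_gqa_groups=8,         # new field; left as None for plain multi-head checkpoints
    head_dim=128,
    mlp_intermediate_dim=28672,
    kernel_chunk_size=512,
)

# The convert helpers are then built with the widened signature:
#   convert_helper_cls(input_path, output_path, model_config,
#                      weight_only, skip_bias, skip_ln,
#                      use_gated_act, pax_split_qkv, te_qkv_layout)
print(llama70b_proxy)
```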
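Finally, the README's `--output-path` note can be captured in a tiny helper. The function below is hypothetical and not part of the converter; it only encodes the directory conventions quoted above so converted checkpoints land where Pax or T5X expects to restore them from.

```python
import os

def converted_ckpt_output_path(framework: str, output_root: str, step: int) -> str:
    """Hypothetical helper: build an --output-path matching each framework's layout."""
    if framework == "pax":
        # Pax: ${your_path_to_output}/checkpoints/checkpoint_${step}
        return os.path.join(output_root, "checkpoints", f"checkpoint_{step}")
    if framework == "t5x":
        # T5X: ${your_path_to_output}/checkpoint_${step}
        return os.path.join(output_root, f"checkpoint_{step}")
    raise ValueError(f"unknown framework: {framework}")

print(converted_ckpt_output_path("pax", "/your_path_to_output", 100000))
```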