diff --git a/.github/container/Dockerfile.maxtext.amd64 b/.github/container/Dockerfile.maxtext.amd64 index c3337b9e6..26370d2a5 100644 --- a/.github/container/Dockerfile.maxtext.amd64 +++ b/.github/container/Dockerfile.maxtext.amd64 @@ -20,13 +20,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in EOF -############################################################################### -## Apply patch -############################################################################### - -ADD maxtext-mha.patch /opt -RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff - ############################################################################### ## Add test script to the path ############################################################################### diff --git a/.github/container/Dockerfile.maxtext.arm64 b/.github/container/Dockerfile.maxtext.arm64 index ad0f43b69..3000541c0 100644 --- a/.github/container/Dockerfile.maxtext.arm64 +++ b/.github/container/Dockerfile.maxtext.arm64 @@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT} echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in EOF -############################################################################### -## Apply patch -############################################################################### - -ADD maxtext-mha.patch /opt -RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff - ############################################################################### ## Add test script to the path ############################################################################### diff --git a/.github/container/Dockerfile.pax.amd64 b/.github/container/Dockerfile.pax.amd64 index ede475a10..52a7723ab 100644 --- a/.github/container/Dockerfile.pax.amd64 +++ b/.github/container/Dockerfile.pax.amd64 @@ -29,6 +29,9 @@ for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do pushd ${src} sed -i "s| @ git+https://github.com/google/flax||g" requirements.in sed -i "s| @ git+https://github.com/google/jax||g" requirements.in + ## we pin etils because newer etils versions are not compatible with the + ## version of TFDS required by Pax + sed -i "s/etils/etils==1.7.0/g" requirements.in if git diff --quiet; then echo "URL specs no longer present in select dependencies for ${src}" exit 1 diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh index 73aee4163..f3e4e0715 100755 --- a/.github/container/install-nsight.sh +++ b/.github/container/install-nsight.sh @@ -17,11 +17,12 @@ apt-get clean rm -rf /var/lib/apt/lists/* -NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1 -if [[ -d "${NSYS202451}" ]]; then - # * can match at least sbsa-armv8 and x86 - (cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) -fi +for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do + if [[ -d "${NSYS}" ]]; then + # * can match at least sbsa-armv8 and x86 + (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) + fi +done # Install extra dependencies needed for `nsys recipe ...` commands. These are # used by the nsys-jax wrapper script. 
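The etils pin above rewrites the requirement in place; a minimal sketch of the effect (the requirements.in line shown is illustrative, not copied from the Paxml/Praxis sources):

    # Before: requirements.in contains an unpinned entry such as
    #   etils
    # After `sed -i "s/etils/etils==1.7.0/g" requirements.in` it becomes
    #   etils==1.7.0
    printf 'etils\n' | sed 's/etils/etils==1.7.0/g'   # prints: etils==1.7.0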
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py index 9e3aaee4f..4e72a33fb 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py @@ -6,7 +6,7 @@ import pathlib from typing import Any -from .protobuf import HloProto, xla_module_metadata +from .protobuf import HloProto, _host_memory_space, xla_module_metadata from .utils import make_child_mask, ProfilerData pd.options.mode.copy_on_write = True @@ -38,6 +38,11 @@ def align_profiler_data_timestamps( # Determine which collective size will be used for the alignment num_profiled_devices = len(comm_df.index.get_level_values("Device").unique()) max_collective_size = comm_df["CollectiveSize"].max() + if max_collective_size == 1: + print( + f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1" + ) + return frames, {} assert ( num_profiled_devices == max_collective_size ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented" @@ -193,13 +198,51 @@ def _get_message_size( "all-to-all", "collective-broadcast", "collective-permute-start", + "dynamic-slice", + "dynamic-update-slice", "reduce-scatter", } ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated" + + def _byte_size(inst) -> int: + size_bits = math.prod( + inst.shape.dimensions, + start=element_type_width(inst.shape.element_type), + ) + size_bytes, rem = divmod(size_bits, 8) + assert rem == 0 + return size_bytes + if comm_inst.opcode == "collective-permute-start": # See https://openxla.org/xla/operation_semantics#collectivepermute, which # generates pair-wise send+recv between devices collective_size = 2 + elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}: + # Label host-device transfers orchestrated by dynamic[-update]-slice as single + # device collectives. 
+ collective_size = 1 + if comm_inst.opcode == "dynamic-update-slice": + # For dynamic-update-slice the second operand is the one being copied + _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1]) + transfer_size = _byte_size(src_inst.proto()) + else: + # For dynamic-slice the return type size is the transfer size + assert comm_inst.opcode == "dynamic-slice" + _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0]) + transfer_size = _byte_size(comm_inst) + dest_on_host = _host_memory_space(comm_inst) + src_on_host = _host_memory_space(src_inst.proto()) + assert src_on_host != dest_on_host, ( + 'dynamic[-update]-slice is only considered "communication" if it ' + "represents a host-device transfer" + ) + return ( + transfer_size, + "device-to-host" if dest_on_host else "host-to-device", + 1, # collective size + 1.0, # bw_correction + 1.0, # bus_correction + ) else: # replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8 # devices that are doing pair-wise collectives @@ -220,17 +263,12 @@ total_msg_size = 0 for operand_id in comm_inst.operand_ids: _, operand = module_proto.find_instruction_by_id(operand_id) - msg_size_bits = math.prod( - operand.proto().shape.dimensions, - start=element_type_width(operand.proto().shape.element_type), - ) + msg_size_bytes = _byte_size(operand.proto()) if comm_inst.opcode == "reduce-scatter": # NCCL's convention is that the message size of a reduce-scatter is the size of output buffer: # https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122 - msg_size_bits, rem = divmod(msg_size_bits, collective_size) + msg_size_bytes, rem = divmod(msg_size_bytes, collective_size) assert rem == 0 - msg_size_bytes, rem = divmod(msg_size_bits, 8) - assert rem == 0 total_msg_size += msg_size_bytes collective = comm_inst.opcode.removesuffix("-start") diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py index 6c25cb2ee..d6e4464bd 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py @@ -103,6 +103,9 @@ def is_communication(row): return _calculate_overlap(thunk_df) +compile_prefix = "XlaCompile:#module=" + + def _load_nvtx_gpu_proj_trace_single( prefix: pathlib.Path, file: pathlib.Path, @@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single( unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates() if len(unique_pid_tid_pairs) == 1: main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0])) + # If the profile only includes N>1 modules, we may still be able to identify the + # main thread as the one responsible for XlaCompile ranges projected onto the GPU + # timeline + compile_ranges = df.loc[~all_thunks, "Name"].str.startswith( + tsl_prefix + compile_prefix + ) + compile_range_ids = compile_ranges[compile_ranges].index + unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates() + if len(unique_pid_tid_pairs) == 1: + main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0])) assert len(main_pid_tid_candidates) < 2 if len(main_pid_tid_candidates) == 1: # Possibly not correct if len(device_by_pid_tid) > 1 assert len(device_by_pid_tid) > 0 + # Associate the main thread with the 0th device in device_by_pid_tid main_thread_df = device_by_pid_tid.iloc[:1] main_thread_df.index =
pd.MultiIndex.from_tuples( main_pid_tid_candidates, names=["PID", "TID"] @@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace( return output -compile_prefix = "TSL:XlaCompile:#module=" - - def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame: # When parallel compilation is enabled, we end up with worker threads that # emit NVTX ranges but which are not accounted for in the RangeStack tree. # Splice these in under the relevant XlaCompile ranges in the RangeStack tree and # drop everything else. retain_mask = pd.Series(False, index=compile_df.index) - compile_mask = compile_df["Name"].str.startswith(compile_prefix) + compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix) for compile_range in compile_df[compile_mask].itertuples(): # Identify the slice of `compile_df` that overlaps in time with this XlaCompile # range diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py index ef74165fd..4feae6038 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py @@ -1,10 +1,13 @@ -from collections import defaultdict import functools import lzma import pathlib import typing +def _host_memory_space(inst): + return inst.shape.layout.memory_space == 5 + + class StackFrame(typing.NamedTuple): column: int file: str @@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto): # proto representing the actual collective, which will be different if the # async launch is handled by an async-start op # TODO: can any of copy-start, custom-call, recv, send represent communication? + # This also aims to identify, and (for now) flag as communication, kernels that + # implement device-to-host and host-to-device copies for memory offloading. + # For example, a device-to-host offload might look like + # computation { + # ... + # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...) + # } + # async_computation { + # ... + # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation + # } + # start = (...) async-start(...), calls=async_computation + # where the :S(5) annotation shows that a buffer is in host memory. + # A host-to-device load might look like + # computation { + # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0) + # ... + # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...) + # } + # async_computation { + # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0) + # ... + # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation + # } + # start = (...) async-start(...), calls=async_computation + # where the :S(5) memory space annotation is in a parameter instead of in the + # return value. + # For now, handling host-device kernels as single-device "collective" + # communication should be sufficient. 
self._comm_proto = None comm_opcodes = { "all-gather", @@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto): "all-reduce-start", "collective-permute-start", } + + def _is_offloading_instruction(inst): + host_dest = _host_memory_space(inst) + + def _host_operand(i): + _, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i]) + return _host_memory_space(op.proto()) + + if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0): + return True + elif ( + inst.opcode == "dynamic-update-slice" + and host_dest == _host_operand(0) + and host_dest != _host_operand(1) + ): + return True + return False + if self._proto.opcode in comm_opcodes | comm_start_opcodes: self._comm_proto = self._proto - elif self._proto.opcode == "async-start": + elif self._proto.opcode in {"async-start", "fusion"}: + # fusion example: + # computation { + # param_0 = f32[...]{...:S(5)} parameter(0) + # ... + # ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...) + # } + # inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation # This might be thinly wrapping an opcode in `comm_opcodes` - other_opcodes = defaultdict(int) - for called_id in self._proto.called_computation_ids: - for called_inst in wrapped_hlo_proto.find_computation( - called_id - ).instructions: - if called_inst.opcode in comm_opcodes: + def _visit_computation(computation_id): + computation = wrapped_hlo_proto.find_computation(computation_id) + for called_inst in computation.instructions: + for called_id in called_inst.called_computation_ids: + _visit_computation(called_id) + if called_inst.opcode in comm_opcodes or _is_offloading_instruction( + called_inst + ): assert ( self._comm_proto is None ), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}" self._comm_proto = called_inst - else: - other_opcodes[called_inst.opcode] += 1 - assert ( - other_opcodes.keys() == {"parameter"} - ), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}" + + for called_id in self._proto.called_computation_ids: + _visit_computation(called_id) def communication_proto(self): return self._comm_proto @@ -68,12 +125,7 @@ def is_communication(self) -> bool: a little more complicated than you might hope, because async communications are not handled uniformly. 
""" - if self._comm_proto is None: - return False - assert ( - self._comm_proto.channel_id != 0 - ), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}" - return True + return self._comm_proto is not None def proto(self): """ diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 60ef1a001..e9d30a3bc 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -177,3 +177,8 @@ orbax-checkpoint: tracking_ref: main latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f mode: pip-vcs +pathwaysutils: + url: https://github.com/google/pathways-utils.git + tracking_ref: main + latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9 + mode: pip-vcs diff --git a/.github/container/maxtext-mha.patch b/.github/container/maxtext-mha.patch deleted file mode 100644 index af2f2feb0..000000000 --- a/.github/container/maxtext-mha.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git a/requirements.txt b/requirements.txt -index cae6c73..4b7a214 100644 ---- a/requirements.txt -+++ b/requirements.txt -@@ -17,8 +17,8 @@ pylint - pytest - pytype - sentencepiece==0.1.97 --tensorflow-text>=2.13.0 --tensorflow>=2.13.0 -+tensorflow-text==2.13.0 -+tensorflow==2.13.0 - tensorflow-datasets - tensorboardx - tensorboard-plugin-profile diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 80ad5b02f..6afbaace1 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -109,9 +109,6 @@ else fi for t in $*; do - if [[ "$t" != "//tests:"* ]]; then - t="//tests:${t}" - fi BAZEL_TARGET="${BAZEL_TARGET} $t" done diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 0492aa388..381ca2f16 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -528,13 +528,14 @@ jobs: STATISTICS_SCRIPT: | summary_line=$(tail -n1 test-te.log) errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}') - passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "passed") | .outcome' | wc -l) - failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "failed") | .outcome' | wc -l) + passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l) + failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l) total_tests=$((failed_tests + passed_tests)) echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT echo "ERRORS=${errors}" >> $GITHUB_OUTPUT echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT + TIMEOUT_MINUTES: 120 ARTIFACTS: | test-te.log pytest-report.jsonl diff --git a/.github/workflows/_test_unit.yaml b/.github/workflows/_test_unit.yaml index 2edfb2249..fa29557e0 100644 --- a/.github/workflows/_test_unit.yaml +++ b/.github/workflows/_test_unit.yaml @@ -19,6 +19,10 @@ on: type: string description: 'Test artifacts to collect' required: false + TIMEOUT_MINUTES: + type: number + description: 'Maximum test runtime, in minutes' + default: 60 jobs: runner: @@ -26,7 +30,7 @@ jobs: with: NAME: "A100" LABELS: "A100,${{ github.run_id }}" - TIME: "01:00:00" + TIME: "${{ inputs.TIMEOUT_MINUTES }}:00" secrets: inherit run-unit-test: @@ -67,6 +71,7 @@ jobs: - name: Run tests shell: bash -x -e {0} 
continue-on-error: true + timeout-minutes: ${{ inputs.TIMEOUT_MINUTES }} run: | ${{ inputs.EXECUTE }} diff --git a/README.md b/README.md index 1764c5f00..054f49ae8 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,23 @@ -# JAX Toolbox +# **JAX Toolbox** +[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md) +[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status) + +JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html). + +## Frameworks and Supported Models +We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs. + +| Framework | Models | Use cases | Container | +| :--- | :---: | :---: | :---: | +| [maxtext](./rosetta/rosetta/projects/maxtext)| GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` | +| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` | +| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` | +| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` | +| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` | +| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` | + +# Build Pipeline Status @@ -22,242 +40,294 @@ - + + - @@ -267,26 +337,9 @@ In all of the above cases, `ghcr.io/nvidia/jax:XXX` points to the most recent nightly build of the container for `XXX`. These containers are also tagged as `ghcr.io/nvidia/jax:XXX-YYYY-MM-DD`, if a stable reference is required. -## Note -This repo currently hosts a public CI for JAX on NVIDIA GPUs and covers some JAX libraries like: [T5x](https://github.com/google-research/t5x), [PAXML](https://github.com/google/paxml), [Transformer Engine](https://github.com/NVIDIA/TransformerEngine), [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html) and others to come soon. - -## Frameworks and Supported Models -We currently support the following frameworks and models. More details about each model and the available containers can be found in their respective READMEs. - -| Framework | Supported Models | Use-cases | Container | -| :--- | :---: | :---: | :---: | -| [Paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` | -| [T5X](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` | -| [T5X](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` | -| [Big Vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` | -| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` | -| maxtext| LLaMA, Gemma | pretraining | `ghcr.io/nvidia/jax:maxtext` | - -We will update this table as new models become available, so stay tuned. 
- ## Environment Variables -The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning: +The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning of XLA and NCCL: | XLA Flags | Value | Explanation | | --------- | ----- | ----------- | @@ -302,10 +355,10 @@ There are various other XLA flags users can set to improve performance. For a de For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page. -## Profiling JAX programs on GPU +## Profiling See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU. -## FAQ (Frequently Asked Questions) +## Frequently asked questions (FAQ)
`bus error` when running JAX in a docker container @@ -340,7 +393,6 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b
## JAX on Public Clouds - * AWS * [Add EFA integration](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-efa.html) * [SageMaker code sample](https://github.com/aws-samples/aws-samples-for-ray/tree/main/sagemaker/jax_alpa_language_model) diff --git a/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env new file mode 100644 index 000000000..d999f5b5e --- /dev/null +++ b/rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env @@ -0,0 +1,24 @@ +set -x +NUM_NODES=1 +NUM_GPUS=8 +THRESHOLD_BYTES=1073741824 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_gpu_enable_triton_gemm=false \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_all_gather_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS))) \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=$((THRESHOLD_BYTES/(NUM_NODES*NUM_GPUS*2))) \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset NUM_NODES NUM_GPUS THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/common.env b/rosetta/rosetta/projects/pax/xla_flags/common.env new file mode 100644 index 000000000..26c819143 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/common.env @@ -0,0 +1,13 @@ +set -x +THRESHOLD_BYTES=51200 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-126m64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/glam-64b64e.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env new file mode 100644 index 000000000..e5b97b466 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-126m.env @@ -0,0 +1,14 @@ +set -x +THRESHOLD_BYTES=33554432 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + 
--xla_gpu_all_reduce_combine_threshold_bytes=${THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_enable_cudnn_fmha=false \ + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +unset THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-175b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/gpt-5b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env new file mode 100644 index 000000000..e48b76dcf --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -0,0 +1,25 @@ +set -x +ALL_REDUCE_THRESHOLD_BYTES=3221225472 +ALL_GATHER_THRESHOLD_BYTES=3221225472 +REDUCE_SCATTER_THRESHOLD_BYTES=402653184 +export XLA_FLAGS="\ + --xla_gpu_enable_latency_hiding_scheduler=true \ + --xla_allow_excess_precision \ + --xla_gpu_enable_highest_priority_async_stream=true \ + --xla_gpu_enable_triton_softmax_fusion=false \ + --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ + --xla_gpu_graph_level=0 \ + --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \ + --xla_gpu_enable_pipelined_all_gather=true \ + --xla_gpu_enable_pipelined_reduce_scatter=true \ + --xla_gpu_enable_pipelined_all_reduce=true \ + --xla_gpu_enable_while_loop_double_buffering=true \ + --xla_gpu_enable_all_gather_combine_by_dim=false \ + --xla_gpu_enable_reduce_scatter_combine_by_dim=false \ + --xla_disable_hlo_passes=rematerialization \ + --xla_gpu_enable_custom_fusions=true + " +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env new file mode 100644 index 000000000..8b0237170 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-70b.env @@ -0,0 +1,3 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +source $SCRIPT_DIR/common.env +unset SCRIPT_DIR diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env new file mode 100644 index 000000000..d1568e92c --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b-lora.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 +set +x diff --git a/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/pax/xla_flags/llama-7b.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git 
a/rosetta/rosetta/projects/t5x/xla_flags/t5.env b/rosetta/rosetta/projects/t5x/xla_flags/t5.env new file mode 100644 index 000000000..bd4ae50d5 --- /dev/null +++ b/rosetta/rosetta/projects/t5x/xla_flags/t5.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env new file mode 100644 index 000000000..45140ed88 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base-highgbs.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.75 +set +x diff --git a/rosetta/rosetta/projects/vit/xla_flags/vit-base.env b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env new file mode 100644 index 000000000..882c9e9e8 --- /dev/null +++ b/rosetta/rosetta/projects/vit/xla_flags/vit-base.env @@ -0,0 +1,4 @@ +set -x +echo "$0 uses default XLA_FLAGS='${XLA_FLAGS:-}'" +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +set +x
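The .env files added above appear intended to be sourced before launching a run; a hedged usage sketch (the actual training launch command is omitted as it is workload-specific):

    # Source the per-model file so that XLA_FLAGS and XLA_PYTHON_CLIENT_MEM_FRACTION
    # are exported into the current shell; helper variables (NUM_NODES, NUM_GPUS,
    # THRESHOLD_BYTES) are unset again by the file itself.
    source rosetta/rosetta/projects/maxtext/xla_flags/llama2-7b-1N8G.env
    echo "XLA_FLAGS=${XLA_FLAGS}"
    echo "XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION}"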
[Build Pipeline Status table diff (README.md): the badge images and workflow links are not recoverable from this extract. The updated table lists the nightly containers ghcr.io/nvidia/jax:base (marked "[no tests]"), jax, levanter, equinox (marked "[tests disabled]"), triton, upstream-t5x, t5x, upstream-pax, pax, maxtext, and gemma.]