diff --git a/.github/container/Dockerfile.maxtext.amd64 b/.github/container/Dockerfile.maxtext.amd64
index c3337b9e6..26370d2a5 100644
--- a/.github/container/Dockerfile.maxtext.amd64
+++ b/.github/container/Dockerfile.maxtext.amd64
@@ -20,13 +20,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
 echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
 EOF
 
-###############################################################################
-## Apply patch
-###############################################################################
-
-ADD maxtext-mha.patch /opt
-RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
-
 ###############################################################################
 ## Add test script to the path
 ###############################################################################
diff --git a/.github/container/Dockerfile.maxtext.arm64 b/.github/container/Dockerfile.maxtext.arm64
index ad0f43b69..3000541c0 100644
--- a/.github/container/Dockerfile.maxtext.arm64
+++ b/.github/container/Dockerfile.maxtext.arm64
@@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
 echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
 EOF
 
-###############################################################################
-## Apply patch
-###############################################################################
-
-ADD maxtext-mha.patch /opt
-RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
-
 ###############################################################################
 ## Add test script to the path
 ###############################################################################
diff --git a/.github/container/Dockerfile.pax.amd64 b/.github/container/Dockerfile.pax.amd64
index ede475a10..52a7723ab 100644
--- a/.github/container/Dockerfile.pax.amd64
+++ b/.github/container/Dockerfile.pax.amd64
@@ -29,6 +29,9 @@ for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do
     pushd ${src}
     sed -i "s| @ git+https://github.com/google/flax||g" requirements.in
     sed -i "s| @ git+https://github.com/google/jax||g" requirements.in
+    ## we pin etils because newer etils versions are not compatible with the
+    ## version of TFDS required by Pax
+    sed -i "s/etils/etils==1.7.0/g" requirements.in
     if git diff --quiet; then
         echo "URL specs no longer present in select dependencies for ${src}"
         exit 1
diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh
index 73aee4163..f3e4e0715 100755
--- a/.github/container/install-nsight.sh
+++ b/.github/container/install-nsight.sh
@@ -17,11 +17,12 @@ apt-get clean
 
 rm -rf /var/lib/apt/lists/*
 
-NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
-if [[ -d "${NSYS202451}" ]]; then
-    # * can match at least sbsa-armv8 and x86
-    (cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
-fi
+for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
+    if [[ -d "${NSYS}" ]]; then
+        # * can match at least sbsa-armv8 and x86
+        (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
+    fi
+done
 
 # Install extra dependencies needed for `nsys recipe ...` commands. These are
 # used by the nsys-jax wrapper script.
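Note on the install-nsight.sh change above: instead of patching a single hard-coded Nsight Systems version, the script now loops over every supported version and applies the same TID-export patch wherever that version is installed. A rough Python rendering of the same loop, for illustration only (paths and patch name are taken from the script; this file is not part of the change):

```python
# Illustrative sketch of the version loop added to install-nsight.sh.
import glob
import pathlib
import subprocess

PATCH = "/opt/nvidia/nsys-2024.5-tid-export.patch"
for version in ("2024.5.1", "2024.6.1"):
    root = pathlib.Path("/opt/nvidia/nsight-systems-cli") / version
    if not root.is_dir():
        continue  # this Nsight Systems version is not installed in the image
    # target-linux-* can match at least sbsa-armv8 and x86
    for pkg_dir in glob.glob(str(root / "target-linux-*" / "python" / "packages")):
        subprocess.run(["git", "apply", PATCH], cwd=pkg_dir, check=True)
```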
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
index 9e3aaee4f..4e72a33fb 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py
@@ -6,7 +6,7 @@ import pathlib
 from typing import Any
 
-from .protobuf import HloProto, xla_module_metadata
+from .protobuf import HloProto, _host_memory_space, xla_module_metadata
 from .utils import make_child_mask, ProfilerData
 
 pd.options.mode.copy_on_write = True
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
     # Determine which collective size will be used for the alignment
     num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
     max_collective_size = comm_df["CollectiveSize"].max()
+    if max_collective_size == 1:
+        print(
+            f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
+        )
+        return frames, {}
     assert (
         num_profiled_devices == max_collective_size
     ), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
             "all-to-all",
             "collective-broadcast",
             "collective-permute-start",
+            "dynamic-slice",
+            "dynamic-update-slice",
             "reduce-scatter",
         }
     ), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
+
+    def _byte_size(inst) -> int:
+        size_bits = math.prod(
+            inst.shape.dimensions,
+            start=element_type_width(inst.shape.element_type),
+        )
+        size_bytes, rem = divmod(size_bits, 8)
+        assert rem == 0
+        return size_bytes
+
     if comm_inst.opcode == "collective-permute-start":
         # See https://openxla.org/xla/operation_semantics#collectivepermute, which
         # generates pair-wise send+recv between devices
         collective_size = 2
+    elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
+        # Label host-device transfers orchestrated by dynamic[-update]-slice as
+        # single-device collectives.
+        collective_size = 1
+        if comm_inst.opcode == "dynamic-update-slice":
+            # For dynamic-update-slice the second operand is the one being copied
+            _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
+            transfer_size = _byte_size(src_inst.proto())
+        else:
+            # For dynamic-slice the return type size is the transfer size
+            assert comm_inst.opcode == "dynamic-slice"
+            _, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
+            transfer_size = _byte_size(comm_inst)
+        dest_on_host = _host_memory_space(comm_inst)
+        src_on_host = _host_memory_space(src_inst.proto())
+        assert src_on_host != dest_on_host, (
+            'dynamic[-update]-slice is only considered "communication" if it '
+            "represents a host-device transfer"
+        )
+        return (
+            transfer_size,
+            "device-to-host" if dest_on_host else "host-to-device",
+            1,  # collective size
+            1.0,  # bw_correction
+            1.0,  # bus_correction
+        )
     else:
         # replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
         # devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@ def _get_message_size(
     total_msg_size = 0
     for operand_id in comm_inst.operand_ids:
         _, operand = module_proto.find_instruction_by_id(operand_id)
-        msg_size_bits = math.prod(
-            operand.proto().shape.dimensions,
-            start=element_type_width(operand.proto().shape.element_type),
-        )
+        msg_size_bytes = _byte_size(operand.proto())
         if comm_inst.opcode == "reduce-scatter":
             # NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
             # https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
-            msg_size_bits, rem = divmod(msg_size_bits, collective_size)
+            msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
             assert rem == 0
-        msg_size_bytes, rem = divmod(msg_size_bits, 8)
-        assert rem == 0
         total_msg_size += msg_size_bytes
 
     collective = comm_inst.opcode.removesuffix("-start")
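Note on the analysis.py change above: the new `_byte_size` helper centralizes the bits-to-bytes conversion that was previously inlined in the operand loop, and the dynamic[-update]-slice branch reuses it to size host-device transfers. A self-contained sketch of the same arithmetic, with a hypothetical width table standing in for the package's `element_type_width` lookup:

```python
import math

# Hypothetical width table standing in for jax_nsys's element_type_width lookup.
ELEMENT_TYPE_WIDTH_BITS = {"bf16": 16, "f32": 32, "s8": 8}

def byte_size(dimensions: list[int], element_type: str) -> int:
    # Multiply the dimensions together, starting from the per-element bit width,
    # then convert to bytes exactly once, mirroring _byte_size in the diff.
    size_bits = math.prod(dimensions, start=ELEMENT_TYPE_WIDTH_BITS[element_type])
    size_bytes, rem = divmod(size_bits, 8)
    assert rem == 0, "shape does not occupy a whole number of bytes"
    return size_bytes

# bf16[2,8,128,2048], the shape used in the offloading examples in protobuf.py:
assert byte_size([2, 8, 128, 2048], "bf16") == 2 * 8 * 128 * 2048 * 2

# NCCL convention for reduce-scatter: report the output-buffer size, i.e. the
# operand size divided by the collective size (here, a collective of 8).
msg_size_bytes, rem = divmod(byte_size([8, 1024], "f32"), 8)
assert rem == 0 and msg_size_bytes == 1024 * 4
```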
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
index 6c25cb2ee..d6e4464bd 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py
@@ -103,6 +103,9 @@ def is_communication(row):
     return _calculate_overlap(thunk_df)
 
 
+compile_prefix = "XlaCompile:#module="
+
+
 def _load_nvtx_gpu_proj_trace_single(
     prefix: pathlib.Path,
     file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
         unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
         if len(unique_pid_tid_pairs) == 1:
             main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
+    # Even if the profile includes N>1 modules, we may still be able to identify
+    # the main thread as the one responsible for XlaCompile ranges projected onto
+    # the GPU timeline
+    compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
+        tsl_prefix + compile_prefix
+    )
+    compile_range_ids = compile_ranges[compile_ranges].index
+    unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
+    if len(unique_pid_tid_pairs) == 1:
+        main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
     assert len(main_pid_tid_candidates) < 2
     if len(main_pid_tid_candidates) == 1:
         # Possibly not correct if len(device_by_pid_tid) > 1
         assert len(device_by_pid_tid) > 0
+        # Associate the main thread with the 0th device in device_by_pid_tid
         main_thread_df = device_by_pid_tid.iloc[:1]
         main_thread_df.index = pd.MultiIndex.from_tuples(
             main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
     return output
 
 
-compile_prefix = "TSL:XlaCompile:#module="
-
-
 def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
     # When parallel compilation is enabled, we end up with worker threads that
     # emit NVTX ranges but which are not accounted for in the RangeStack tree.
     # Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
     # drop everything else.
     retain_mask = pd.Series(False, index=compile_df.index)
-    compile_mask = compile_df["Name"].str.startswith(compile_prefix)
+    compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
     for compile_range in compile_df[compile_mask].itertuples():
         # Identify the slice of `compile_df` that overlaps in time with this XlaCompile
         # range
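Note on the data_loaders.py change above: when more than one module is profiled, the loader now falls back to finding the unique (PID, TID) pair that emitted XlaCompile ranges projected onto the GPU timeline. A toy pandas sketch of that fallback, using synthetic rows, and assuming `tsl_prefix` is "TSL:" to match the `"TSL:" + compile_prefix` usage in `_splice_parallel_ranges`:

```python
import pandas as pd

tsl_prefix = "TSL:"  # assumed value, matching the "TSL:" + compile_prefix usage
compile_prefix = "XlaCompile:#module="

# Synthetic projected-trace rows: two modules, so the per-module heuristic
# alone cannot single out the main thread.
df = pd.DataFrame(
    {
        "Name": [
            tsl_prefix + compile_prefix + "jit_train_step#",
            "thunk-from-module-1",
            "thunk-from-module-2",
        ],
        "PID": [100, 100, 100],
        "TID": [7, 8, 9],
    }
)
all_thunks = df["Name"].str.startswith("thunk")

# Fallback: the main thread is the unique (PID, TID) pair that emitted
# XlaCompile ranges projected onto the GPU timeline.
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(tsl_prefix + compile_prefix)
compile_range_ids = compile_ranges[compile_ranges].index
unique_pid_tid_pairs = df.loc[compile_range_ids, ["PID", "TID"]].drop_duplicates()
if len(unique_pid_tid_pairs) == 1:
    print("main thread:", tuple(unique_pid_tid_pairs.iloc[0]))  # (100, 7)
```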
diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
index ef74165fd..4feae6038 100644
--- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
+++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py
@@ -1,10 +1,13 @@
-from collections import defaultdict
 import functools
 import lzma
 import pathlib
 import typing
 
 
+def _host_memory_space(inst):
+    return inst.shape.layout.memory_space == 5
+
+
 class StackFrame(typing.NamedTuple):
     column: int
     file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
         # proto representing the actual collective, which will be different if the
         # async launch is handled by an async-start op
         # TODO: can any of copy-start, custom-call, recv, send represent communication?
+        # This also aims to identify, and (for now) flag as communication, kernels that
+        # implement device-to-host and host-to-device copies for memory offloading.
+        # For example, a device-to-host offload might look like
+        #   computation {
+        #     ...
+        #     ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
+        #   }
+        #   async_computation {
+        #     ...
+        #     ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
+        #   }
+        #   start = (...) async-start(...), calls=async_computation
+        # where the :S(5) annotation shows that a buffer is in host memory.
+        # A host-to-device load might look like
+        #   computation {
+        #     param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
+        #     ...
+        #     ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
+        #   }
+        #   async_computation {
+        #     param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
+        #     ...
+        #     ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
+        #   }
+        #   start = (...) async-start(...), calls=async_computation
+        # where the :S(5) memory space annotation is on a parameter instead of on the
+        # return value.
+        # For now, handling host-device kernels as single-device "collective"
+        # communication should be sufficient.
         self._comm_proto = None
         comm_opcodes = {
             "all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
             "all-reduce-start",
             "collective-permute-start",
         }
+
+        def _is_offloading_instruction(inst):
+            host_dest = _host_memory_space(inst)
+
+            def _host_operand(i):
+                _, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
+                return _host_memory_space(op.proto())
+
+            if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
+                return True
+            elif (
+                inst.opcode == "dynamic-update-slice"
+                and host_dest == _host_operand(0)
+                and host_dest != _host_operand(1)
+            ):
+                return True
+            return False
+
         if self._proto.opcode in comm_opcodes | comm_start_opcodes:
             self._comm_proto = self._proto
-        elif self._proto.opcode == "async-start":
+        elif self._proto.opcode in {"async-start", "fusion"}:
+            # fusion example:
+            #   computation {
+            #     param_0 = f32[...]{...:S(5)} parameter(0)
+            #     ...
+            #     ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
+            #   }
+            #   inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
             # This might be thinly wrapping an opcode in `comm_opcodes`
-            other_opcodes = defaultdict(int)
-            for called_id in self._proto.called_computation_ids:
-                for called_inst in wrapped_hlo_proto.find_computation(
-                    called_id
-                ).instructions:
-                    if called_inst.opcode in comm_opcodes:
+            def _visit_computation(computation_id):
+                computation = wrapped_hlo_proto.find_computation(computation_id)
+                for called_inst in computation.instructions:
+                    for called_id in called_inst.called_computation_ids:
+                        _visit_computation(called_id)
+                    if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
+                        called_inst
+                    ):
                         assert (
                             self._comm_proto is None
                         ), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
                         self._comm_proto = called_inst
-                    else:
-                        other_opcodes[called_inst.opcode] += 1
-            assert (
-                other_opcodes.keys() == {"parameter"}
-            ), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"
+
+            for called_id in self._proto.called_computation_ids:
+                _visit_computation(called_id)
 
     def communication_proto(self):
         return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
         a little more complicated than you might hope, because async communications
         are not handled uniformly.
         """
-        if self._comm_proto is None:
-            return False
-        assert (
-            self._comm_proto.channel_id != 0
-        ), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
-        return True
+        return self._comm_proto is not None
 
     def proto(self):
         """
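Note on the protobuf.py change above: `_host_memory_space` keys on the `:S(5)` layout annotation (memory space 5 marks host-resident buffers), and `_is_offloading_instruction` compares the memory spaces of a dynamic[-update]-slice's result and operands. A toy sketch of that classification with stand-in types instead of real HLO protos, for illustration only:

```python
from dataclasses import dataclass, field

HOST_MEMORY_SPACE = 5  # the :S(5) layout annotation marks host-resident buffers

@dataclass
class ToyInst:
    opcode: str
    memory_space: int  # stands in for inst.shape.layout.memory_space
    operands: list = field(default_factory=list)

def on_host(inst: ToyInst) -> bool:
    return inst.memory_space == HOST_MEMORY_SPACE

def is_offloading(inst: ToyInst) -> bool:
    # dynamic-slice copies between memory spaces when the result and source
    # spaces differ (host-to-device load, or the reverse).
    if inst.opcode == "dynamic-slice":
        return on_host(inst) != on_host(inst.operands[0])
    # dynamic-update-slice: operand 0 is the destination buffer (same space as
    # the result); operand 1 is the slice being copied in from the other space.
    if inst.opcode == "dynamic-update-slice":
        return on_host(inst) == on_host(inst.operands[0]) and on_host(
            inst
        ) != on_host(inst.operands[1])
    return False

# Device-to-host offload: result and destination on host, copied slice on device.
dest = ToyInst("parameter", HOST_MEMORY_SPACE)
src = ToyInst("parameter", 0)
assert is_offloading(ToyInst("dynamic-update-slice", HOST_MEMORY_SPACE, [dest, src]))
```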
""" - if self._comm_proto is None: - return False - assert ( - self._comm_proto.channel_id != 0 - ), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}" - return True + return self._comm_proto is not None def proto(self): """ diff --git a/.github/container/manifest.yaml b/.github/container/manifest.yaml index 60ef1a001..e9d30a3bc 100644 --- a/.github/container/manifest.yaml +++ b/.github/container/manifest.yaml @@ -177,3 +177,8 @@ orbax-checkpoint: tracking_ref: main latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f mode: pip-vcs +pathwaysutils: + url: https://github.com/google/pathways-utils.git + tracking_ref: main + latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9 + mode: pip-vcs diff --git a/.github/container/maxtext-mha.patch b/.github/container/maxtext-mha.patch deleted file mode 100644 index af2f2feb0..000000000 --- a/.github/container/maxtext-mha.patch +++ /dev/null @@ -1,15 +0,0 @@ -diff --git a/requirements.txt b/requirements.txt -index cae6c73..4b7a214 100644 ---- a/requirements.txt -+++ b/requirements.txt -@@ -17,8 +17,8 @@ pylint - pytest - pytype - sentencepiece==0.1.97 --tensorflow-text>=2.13.0 --tensorflow>=2.13.0 -+tensorflow-text==2.13.0 -+tensorflow==2.13.0 - tensorflow-datasets - tensorboardx - tensorboard-plugin-profile diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 80ad5b02f..6afbaace1 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -109,9 +109,6 @@ else fi for t in $*; do - if [[ "$t" != "//tests:"* ]]; then - t="//tests:${t}" - fi BAZEL_TARGET="${BAZEL_TARGET} $t" done diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index 0492aa388..381ca2f16 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -528,13 +528,14 @@ jobs: STATISTICS_SCRIPT: | summary_line=$(tail -n1 test-te.log) errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}') - passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "passed") | .outcome' | wc -l) - failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "failed") | .outcome' | wc -l) + passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l) + failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l) total_tests=$((failed_tests + passed_tests)) echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT echo "ERRORS=${errors}" >> $GITHUB_OUTPUT echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT + TIMEOUT_MINUTES: 120 ARTIFACTS: | test-te.log pytest-report.jsonl diff --git a/.github/workflows/_test_unit.yaml b/.github/workflows/_test_unit.yaml index 2edfb2249..fa29557e0 100644 --- a/.github/workflows/_test_unit.yaml +++ b/.github/workflows/_test_unit.yaml @@ -19,6 +19,10 @@ on: type: string description: 'Test artifacts to collect' required: false + TIMEOUT_MINUTES: + type: number + description: 'Maximum test runtime, in minutes' + default: 60 jobs: runner: @@ -26,7 +30,7 @@ jobs: with: NAME: "A100" LABELS: "A100,${{ github.run_id }}" - TIME: "01:00:00" + TIME: "${{ inputs.TIMEOUT_MINUTES }}:00" secrets: inherit run-unit-test: @@ -67,6 +71,7 @@ jobs: - name: Run tests shell: bash -x -e {0} 
diff --git a/.github/workflows/_test_unit.yaml b/.github/workflows/_test_unit.yaml
index 2edfb2249..fa29557e0 100644
--- a/.github/workflows/_test_unit.yaml
+++ b/.github/workflows/_test_unit.yaml
@@ -19,6 +19,10 @@ on:
         type: string
         description: 'Test artifacts to collect'
         required: false
+      TIMEOUT_MINUTES:
+        type: number
+        description: 'Maximum test runtime, in minutes'
+        default: 60
 
 jobs:
   runner:
@@ -26,7 +30,7 @@ jobs:
     with:
       NAME: "A100"
       LABELS: "A100,${{ github.run_id }}"
-      TIME: "01:00:00"
+      TIME: "${{ inputs.TIMEOUT_MINUTES }}:00"
     secrets: inherit
 
   run-unit-test:
@@ -67,6 +71,7 @@ jobs:
       - name: Run tests
         shell: bash -x -e {0}
         continue-on-error: true
+        timeout-minutes: ${{ inputs.TIMEOUT_MINUTES }}
         run: |
           ${{ inputs.EXECUTE }}
 
diff --git a/README.md b/README.md
index 1764c5f00..054f49ae8 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,23 @@
-# JAX Toolbox
+# **JAX Toolbox**
+[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md)
+[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status)
+
+JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).
+
+## Frameworks and Supported Models
+We support and test the following JAX frameworks and model architectures. More details about each model and the available containers can be found in their respective READMEs.
+
+| Framework | Models | Use cases | Container |
+| :--- | :---: | :---: | :---: |
+| [maxtext](./rosetta/rosetta/projects/maxtext) | GPT, LLaMA, Gemma, Mistral, Mixtral | pre-training | `ghcr.io/nvidia/jax:maxtext` |
+| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pre-training, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
+| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
+| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
+| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
+| levanter | GPT, LLaMA, MPT, Backpacks | pre-training, fine-tuning | `ghcr.io/nvidia/jax:levanter` |
+
+# Build Pipeline Status
[The remainder of the README diff rewrites the HTML table behind the "Build Pipeline Status" section; the raw table markup was garbled during extraction and is omitted here. The recoverable content: the table has one row per container, covering ghcr.io/nvidia/jax:base (marked "[no tests]"), jax, levanter, equinox (marked "[tests disabled]"), triton, upstream-t5x, t5x, upstream-pax, pax, maxtext, and gemma, with each row's build/test status cells reworked.]