diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 55af1d727..810109f32 100755 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -86,9 +86,10 @@ def check_container(date: datetime.date) -> bool: jax_commit = get_commit(worker, "jax") xla_commit = get_commit(worker, "xla") - logger.debug(result.stdout) - logger.info(f"Ran test case in {date} in {test_time:.1f}s") test_pass = result.returncode == 0 + logger.info(f"Ran test case in {date} in {test_time:.1f}s, pass={test_pass}") + logger.debug(result.stdout) + logger.debug(result.stderr) add_summary_record( "container", { diff --git a/README.md b/README.md index 054f49ae8..54e67d460 100644 --- a/README.md +++ b/README.md @@ -407,3 +407,4 @@ Docker has traditionally used Docker Schema V2.2 for multi-arch manifest lists b * [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/) * [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html) * [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) +* [Triaging regressions](docs/triage-tool.md) diff --git a/docs/triage-tool.md b/docs/triage-tool.md new file mode 100644 index 000000000..71cd35a33 --- /dev/null +++ b/docs/triage-tool.md @@ -0,0 +1,249 @@ +# Triage tool + +`jax-toolbox-triage` is a tool to automate the process of attributing regressions to an +individual commit of JAX or XLA. +It takes as input a command that returns an error (non-zero) code when run in "recent" +containers, but which returns a success (zero) code when run in some "older" container. +The command must be executable within the containers, *i.e.* it cannot refer to files +that only exist on the host system. + +The tool follows a three-step process: + 1. A container-level search backwards from the "recent" container where the test is + known to fail, which identifies an "older" container where the test passes. This + search proceeds with an exponentially increasing step size and is based on the + `YYYY-MM-DD` tags under `ghcr.io/nvidia/jax`. + 2. A container-level binary search to refine this to the **latest** available + container where test passes and the **earliest** available container where it + fails. + 3. A commit-level binary search, repeatedly building + testing inside the same + container, to identify a single commit of JAX (XLA) that causes the test to start + failing, and a reference commit of XLA (JAX) that can be used to reproduce the + regression. + +## Installation + +The triage tool can be installed using `pip`: +```bash +pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage +``` +or directly from a checkout of the JAX-Toolbox repository. +Because the tool needs to orchestrate running commands in multiple containers, it is +most convenient to install it in a virtual environment on the host system, rather than +attempting to install it inside a container. + +The tool should be invoked on a machine with `docker` available and whatever GPUs are +needed to execute the test case. + +## Usage + +To use the tool, there are two compulsory arguments: + * `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container + families to execute the test command in. Example: `jax` for a JAX unit test + failure, `maxtext` for a MaxText model execution failure + * A test command to triage. + +The test command will be executed directly in the container, not inside a shell, so be +sure not to add excessive quotation marks (*i.e.* run +`jax-toolbox-triage --container=jax test-jax.sh foo` not +`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it +as fast and targeted as possible. +The expectation is that the test case will be executed successfully several times as +part of the triage, so you may want to tune some parameters to reduce the execution +time in the successful case. +For example, if `text-maxtext.sh --steps=500 ...` is failing on step 0, you should +probably reduce `--steps` to optimise execution time in the successful case. + +A JSON status file and both info-level and debug-level logfiles are written to the +directory given by `--output-prefix`. + +### Optimising container-level search performance + +By default, the container-level search starts from the most recent available container, +if you already know that the test has been failing for a while, you can pass +`--end-date` to start the search further in the past. +If you are sure that the test is failing on the `--end-date` you have passed, you can +skip verification of that fact by passing `--skip-precondition-checks` (but see below +for other checks that this skips). + +By default, the container-level backwards search for a date on which the test passed +tries the containers approximately [1, 2, 4, ...] days before `--end-date`. +This can be tuned by passing `--start-date`, which overrides the "end date minus one" +start value (but leaves the exponential growth of the search range width). +If you are sure that the test is passing on the `--start-date` you have passed, you can +skip verification of that fact by passing `--skip-precondition-checks`. + +The combination of `--start-date`, `--end-date` and `--skip-precondition-checks` can be +used to skip the entire first stage of the bisection process. + +The second stage of the triage process can be made to abort early using the +`--threshold-days` option; this stage will terminate once the delta between the latest +known-good and earliest known-bad containers is below the threshold. + +If you need to re-start the tool for some reason, use of these options can help +bootstrap the tool using the results of a previous (partial) run. + +### Optimising commit-level search performance + +The third stage of the triage process involves repeatedly building JAX and XLA, which +can be sped up significantly using a Bazel cache. +By default, a local directory on the host machine (where the tool is being executed) +will be used, but it may be more efficient to use a persistent and/or pre-heated cache. +This can be achieved by passing the `--bazel-cache` option, which accepts absolute +paths and `http`/`https`/`grpc` URLs. + +If `--skip-precondition-checks` is passed, a sanity check that the failure can be +reproduced after rebuilding the JAX/XLA commits from the first-known-bad container +inside that container will be skipped. + +## Example + +Here is an example execution for a JAX unit test failure, with some annotation: +```console +user@gpu-machine $ jax-toolbox-triage --container jax test-jax.sh //tests:nn_test_gpu +``` +`--end-date` was not passed, and 2024-10-15 is the most recent available container +at the time of execution +``` +[INFO] 2024-10-16 00:31:41 Checking end-of-range failure in 2024-10-15 +``` +`--skip-precondition-checks` was not passed, so the tool checks that the test does, in +fact, fail in the 2024-10-15 container +``` +[INFO] 2024-10-16 00:33:36 Ran test case in 2024-10-15 in 114.8s, pass=False +``` +`--start-date` was not passed, so the first (backwards search) stage of the triage +process starts with the container 1 day before the end of the range, *i.e.* 2024-10-14 +``` +[INFO] 2024-10-16 00:33:37 Starting coarse search with 2024-10-14 based on end_date=2024-10-15 +[INFO] 2024-10-16 00:35:35 Ran test case in 2024-10-14 in 118.1s, pass=False +``` +`end_date - 2 * (end_date - search_date)` = `2024-10-15 - 2 days` = `2024-10-13` +``` +[INFO] 2024-10-16 00:38:11 Ran test case in 2024-10-13 in 122.4s, pass=False +``` +In principle this would be 4 days before the end date, but the 2024-10-11 container +does not exist, so the tool chooses a nearby container that does exist and is older +than 2024-10-13 +``` +[INFO] 2024-10-16 00:40:53 Ran test case in 2024-10-12 in 127.7s, pass=False +``` +Steps in date start to increase significantly +``` +[INFO] 2024-10-16 00:43:28 Ran test case in 2024-10-09 in 119.3s, pass=False +[INFO] 2024-10-16 00:45:29 Ran test case in 2024-10-03 in 120.7s, pass=False +[INFO] 2024-10-16 00:47:27 Ran test case in 2024-09-21 in 116.3s, pass=False +``` +The first stage of the triage process successfully identifies an old container where +this test passed +``` +[INFO] 2024-10-16 00:51:22 Ran test case in 2024-08-28 in 194.0s, pass=True +[INFO] 2024-10-16 00:51:22 Coarse container-level search yielded [2024-08-28, 2024-09-21]... +``` +The second stage of the triage process refines the container-level range by bisection +``` +[INFO] 2024-10-16 00:53:19 Ran test case in 2024-09-09 in 115.5s, pass=True +[INFO] 2024-10-16 00:53:19 Refined container-level range to [2024-09-09, 2024-09-21] +[INFO] 2024-10-16 00:56:03 Ran test case in 2024-09-15 in 125.4s, pass=True +[INFO] 2024-10-16 00:56:03 Refined container-level range to [2024-09-15, 2024-09-21] +[INFO] 2024-10-16 00:58:07 Ran test case in 2024-09-18 in 122.9s, pass=True +[INFO] 2024-10-16 00:58:07 Refined container-level range to [2024-09-18, 2024-09-21] +``` +The second stage of the triage process converges +``` +[INFO] 2024-10-16 01:00:09 Ran test case in 2024-09-19 in 121.2s, pass=False +[INFO] 2024-10-16 01:00:09 Refined container-level range to [2024-09-18, 2024-09-19] +``` +The third stage of the triage process begins, using: + - the first-known-bad container 2024-09-19 + - first-known-bad commits (JAX 9d2e9... and XLA 42b04...) + - last-known-good commits (JAX 988ed... and XLA 88935...) +``` +[INFO] 2024-10-16 01:00:10 Bisecting JAX [988ed2bd75df5fe25b74eaf38075aadff19be207, 9d2e9c688c4e8b733e68467d713091436a672ac0] and XLA [8893550a604fe39aae2eeae49a836e92eed497d1, 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7] using ghcr.io/nvidia/jax:jax-2024-09-19 +[INFO] 2024-10-16 01:00:10 Building in the range-ending container... +``` +Sanity check that re-building the first-known-bad commits in the first-known-bad +container reproduces the failure +``` +[INFO] 2024-10-16 01:00:12 Checking out XLA 42b04a6739dc648a80dd4f3b4e1322f1b2c7f3a7 JAX 9d2e9c688c4e8b733e68467d713091436a672ac0 +``` +No Bazel cache was passed, and this is the first build in the triage session, so it is +slow -- a full rebuild of JAX and XLA was needed +``` +[INFO] 2024-10-16 01:13:56 Build completed in 824.9s +[INFO] 2024-10-16 01:15:25 Test completed in 88.5s +[INFO] 2024-10-16 01:15:25 Verified test failure after vanilla rebuild +``` +Verification that the last-known-good commits still pass when rebuilt in the +first-known-bad container; this is a bit faster because the Bazel cache is warmer +``` +[INFO] 2024-10-16 01:15:25 Checking out XLA 8893550a604fe39aae2eeae49a836e92eed497d1 JAX 988ed2bd75df5fe25b74eaf38075aadff19be207 +[INFO] 2024-10-16 01:26:43 Build completed in 677.5s +[INFO] 2024-10-16 01:27:36 Test completed in 53.7s +[INFO] 2024-10-16 01:27:36 Test passed after rebuilding commits from start container in end container +``` +Binary search in commits continues, with progressively faster build times +``` +[INFO] 2024-10-16 01:27:37 Checking out XLA b976dd94f11ab130c5f718b360fcfb5ac6d6b875 JAX b51c65357f0ae9659e58e2ff0df871542124cddf +[INFO] 2024-10-16 01:32:24 Build completed in 287.7s +[INFO] 2024-10-16 01:33:19 Test completed in 54.4s +[INFO] 2024-10-16 01:33:19 Checking out XLA e291dfe0a12ec5907636a722c545c19d43f04c8b JAX 9dd363da1298e4810b693a918fc2e8199094acdb +[INFO] 2024-10-16 01:34:58 Build completed in 98.9s +[INFO] 2024-10-16 01:35:52 Test completed in 54.1s +[INFO] 2024-10-16 01:35:53 Checking out XLA 6e652a5d91657cfbe9fbcdff4a0ccd1b803675a7 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03 +[INFO] 2024-10-16 01:36:54 Build completed in 61.3s +[INFO] 2024-10-16 01:37:47 Test completed in 52.7s +[INFO] 2024-10-16 01:37:47 Checking out XLA a1299f86507c79c8acf877344d545f10329f8515 JAX b164d67d4a9bd094426ff450fe1f1335d3071d03 +[INFO] 2024-10-16 01:38:39 Build completed in 52.5s +[INFO] 2024-10-16 01:39:32 Test completed in 52.5s +[INFO] 2024-10-16 01:39:32 Checking out XLA 2d1f7b70740649a57ec4988702ae1dbdfeee6e9c JAX b164d67d4a9bd094426ff450fe1f1335d3071d03 +[INFO] 2024-10-16 01:40:24 Build completed in 52.2s +[INFO] 2024-10-16 01:41:17 Test completed in 52.9s +[INFO] 2024-10-16 01:41:17 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX 016c49951f670256ce4750cdfea182e3a2a15325 +[INFO] 2024-10-16 01:42:08 Build completed in 50.9s +[INFO] 2024-10-16 01:43:12 Test completed in 64.2s +``` +The XLA commit has stopped changing; the initial bisection is XLA-centric (with JAX +kept roughly in sync), but when this converges on a single XLA commit, the tool will +run extra tests to decide whether to blame that XLA commit or a nearby JAX commit +``` +[INFO] 2024-10-16 01:43:13 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX b164d67d4a9bd094426ff450fe1f1335d3071d03 +[INFO] 2024-10-16 01:44:01 Build completed in 48.8s +[INFO] 2024-10-16 01:45:02 Test completed in 60.8s +[INFO] 2024-10-16 01:45:03 Checking out XLA 662eb45a17c76df93e5a386929653ae4c1f593da JAX cd04d0f32e854aa754e37e4b676725655a94e731 +[INFO] 2024-10-16 01:45:52 Build completed in 49.4s +[INFO] 2024-10-16 01:46:53 Test completed in 60.7s +[INFO] 2024-10-16 01:46:53 Bisected failure to JAX cd04d0f32e854aa754e37e4b676725655a94e731..b164d67d4a9bd094426ff450fe1f1335d3071d03 with XLA 662eb45a17c76df93e5a386929653ae4c1f593da +``` + +Where the final result should be read as saying that the test passes with +[xla@662eb](https://github.com/openxla/xla/commit/662eb45a17c76df93e5a386929653ae4c1f593da) +and +[jax@cd04d](https://github.com/jax-ml/jax/commit/cd04d0f32e854aa754e37e4b676725655a94e731), +but that if JAX is moved forward to include +[jax@b164d](https://github.com/jax-ml/jax/commit/b164d67d4a9bd094426ff450fe1f1335d3071d03) +then the test fails. +This failure is fixed in [jax#24427](https://github.com/jax-ml/jax/pull/24427). + +## Limitations + +This tool aims to target the common case that regressions are due to commits in JAX or +XLA, so if the root cause is different it may not converge, although the partial results +may still be helpful. + +For example, if the regression is due to a new version of some other dependency +`SomeProject` that was first installed in the `2024-10-15` container, then the first +two stages of the triage process will correctly identify that `2024-10-15` is the +critical date, but the third stage will fail because it will try and fail to reproduce +test success by building the JAX/XLA commits from `2024-10-14` in the `2024-10-15` +container. + +Other limitations include that only `docker` is supported as a container runtime, which +also implies that it is not currently possible to triage a test that requires a +multi-node or multi-process test. + +The tool also does not currently handle skipping commits that do not compile, or test +cases that require copying files (*e.g.* script files) into the container. + +If you run into these limitations in real-world usage of this tool, please file a bug +against JAX-Toolbox including details of manual steps you took to root-case the test +regression. diff --git a/docs/triage.md b/docs/triage.md index 9ec5eafe5..7a9c905a6 100644 --- a/docs/triage.md +++ b/docs/triage.md @@ -4,12 +4,11 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr be used to help determine if a test failure was due to a change in (t5x or pax) or further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further investigation is usually needed, but this automates the investigation of questions like "what state of library X works with Jax at state Y?" -__Note__: There is also a utility, [triage](../.github/triage/triage), which can be -used for more granular bisection of failures in specific tests. Run it with `--help` -for usage instructions. Given a test expression that can be run inside the nightly -containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly -container where the failure first appeared, and second attributes the failure to a -specific commit of JAX or XLA. +__Note__: There is also a [triage tool](triage-tool.md), which can be used for +more granular bisection of failures in specific tests. Given a test expression that can +be run inside the nightly containers (*e.g.* `test-jax.sh jet_test_gpu`), it first +identifies the nightly container where the failure first appeared, and second attributes +the failure to a specific commit of JAX or XLA. ## Algorithm The pseudocode for the triaging algorithm is as follows: