diff --git a/.github/triage/triage b/.github/triage/triage
new file mode 100755
index 000000000..b0efe4f64
--- /dev/null
+++ b/.github/triage/triage
@@ -0,0 +1,688 @@
+#!/usr/bin/env python3
+import argparse
+import datetime
+import getpass
+import json
+import logging
+import os
+import subprocess
+import tempfile
+import time
+from typing import List, Optional, Tuple
+
+# Because this script needs to run on compute clusters *outside* the containers that it
+# orchestrates, it tries to tolerate extremely old Python versions. 3.6 was tested.
+
+parser = argparse.ArgumentParser(
+    description="""
+    Triage failures in JAX/XLA-related tests. The expectation is that the given
+    test command fails in recent versions, but that it passed in the past. The
+    script first triages the regression with a search of the nightly containers,
+    and then refines the search to a particular commit of JAX or XLA.""",
+    prog="triage",
+)
+
+container_search_args = parser.add_argument_group(
+    title="Container-level search",
+    description="""
+    First, it is verified that the test command fails on the given end date, unless
+    both --end-date and --skip-precondition-checks were passed. Then, the program
+    searches backwards to find a container in which the given test passed. The
+    --start-date option can be used to speed up this search, if you already know a
+    date on which the test was passing. The earliest failure is located to within
+    --threshold-days days.""",
+)
+commit_search_args = parser.add_argument_group(
+    title="Commit-level search",
+    description="""
+    Second, the failure is localised to a commit of JAX or XLA by re-building and
+    re-testing inside the earliest container that demonstrates the failure. At each
+    point, the oldest JAX commit that is newer than XLA is used.""",
+)
+parser.add_argument(
+    "--container",
+    help="""
+    Container to use. Example: jax, pax, triton. Used to construct the URLs of
+    nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
+    required=True,
+)
+parser.add_argument(
+    "--output-prefix",
+    default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"),
+    help="""
+    Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS.
+    An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is
+    written as PREFIX-debug.log, and a JSON summary is written as
+    PREFIX-summary.json.""",
+)
+parser.add_argument(
+    "--skip-precondition-checks",
+    action="store_true",
+    help="""
+    Skip checks that should pass by construction. This saves time, but may yield
+    incorrect results if you are not careful. Specifically, this means that the test
+    is assumed to fail on --end-date (if specified), pass on --start-date (if
+    specified), and fail after recompilation in the earliest-known-failure
+    container. Careful use of this option, along with --start-date, --end-date and
+    --threshold-days, allows the container-level search to be skipped.""",
+)
+parser.add_argument(
+    "test_command",
+    nargs="+",
+    help="""
+    Command to execute inside the container. This should be as targeted as
+    possible.""",
+)
+container_search_args.add_argument(
+    "--end-date",
+    help="""
+    Initial estimate of the earliest nightly container date where the test case
+    fails. Defaults to the newest available nightly container date.
+    If this and --skip-precondition-checks are both set then it will not be
+    verified that the test case fails on this date.""",
+    type=lambda s: datetime.date.fromisoformat(s),
+)
+container_search_args.add_argument(
+    "--start-date",
+    help="""
+    Initial estimate of the latest nightly container date where the test case
+    passes. Defaults to the day before --end-date, but setting this to a date
+    further in the past may lead to faster convergence. If this and
+    --skip-precondition-checks are both set then it will not be verified that the
+    test case passes on this date.""",
+    type=lambda s: datetime.date.fromisoformat(s),
+)
+container_search_args.add_argument(
+    "--threshold-days",
+    default=1,
+    help="""
+    Convergence threshold. The container-level search continues until the number of
+    days separating the last known success and the first known failure is no larger
+    than this value. The minimum, and default, value is 1. Note that if some
+    nightly builds failed, the search may finish without reaching this
+    threshold.""",
+    type=int,
+)
+commit_search_args.add_argument(
+    "--bazel-cache",
+    default=os.path.join(
+        tempfile.gettempdir(), f"{getpass.getuser()}-bazel-triage-cache"
+    ),
+    help="""
+    Bazel cache to use when [re-]building JAX/XLA during the fine search. This can
+    be a remote cache server or a local directory. Using a persistent cache can
+    significantly speed up the commit-level search. By default, uses a temporary
+    directory including the name of the current user.""",
+)
+args = parser.parse_args()
+
+
+if (
+    args.bazel_cache.startswith("http://")
+    or args.bazel_cache.startswith("https://")
+    or args.bazel_cache.startswith("grpc://")
+):
+    # Remote cache, no mount needed
+    bazel_cache_mount = []
+elif os.path.isabs(args.bazel_cache):
+    os.makedirs(args.bazel_cache, exist_ok=True)
+    bazel_cache_mount = ["-v", f"{args.bazel_cache}:{args.bazel_cache}"]
+else:
+    raise Exception(
+        "--bazel-cache should be an http/https/grpc URL or an absolute path"
+    )
+
+
+def get_logger() -> logging.Logger:
+    logger = logging.getLogger("triage")
+    logger.setLevel(logging.DEBUG)
+    formatter = logging.Formatter(
+        fmt="[%(levelname)s] %(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    console = logging.StreamHandler()
+    trace_file = logging.FileHandler(filename=args.output_prefix + ".log", mode="w")
+    debug_file = logging.FileHandler(
+        filename=args.output_prefix + "-debug.log", mode="w"
+    )
+    console.setLevel(logging.INFO)
+    trace_file.setLevel(logging.INFO)
+    debug_file.setLevel(logging.DEBUG)
+    console.setFormatter(formatter)
+    trace_file.setFormatter(formatter)
+    debug_file.setFormatter(formatter)
+    logger.addHandler(console)
+    logger.addHandler(trace_file)
+    logger.addHandler(debug_file)
+    return logger
+
+
+logger = get_logger()
+
+
+def container_url(date: datetime.date) -> str:
+    """
+    Construct the URL for --container on the given date.
+
+    Arguments:
+        date: nightly container date
+    """
+    # Around 2024-02-09 the naming scheme changed.
+    if date > datetime.date(year=2024, month=2, day=9):
+        return f"ghcr.io/nvidia/jax:{args.container}-{date.isoformat()}"
+    else:
+        return f"ghcr.io/nvidia/{args.container}:nightly-{date.isoformat()}"
+
+
+def container_exists(date: datetime.date) -> bool:
+    """
+    Check if the given container exists.
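+
+    The check is implemented as a `docker pull`, so a successful call also leaves
+    the image available locally for the subsequent `docker run`.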
+    """
+    result = subprocess.run(
+        ["docker", "pull", container_url(date)],
+        stderr=subprocess.STDOUT,
+        stdout=subprocess.PIPE,
+        encoding="utf-8",
+    )
+    logger.debug(result.stdout)
+    return result.returncode == 0
+
+
+def as_datetime(date: datetime.date) -> datetime.datetime:
+    return datetime.datetime.combine(date, datetime.time())
+
+
+def adjust_date(
+    date: datetime.datetime,
+    before: Optional[datetime.date] = None,
+    after: Optional[datetime.date] = None,
+    max_steps: int = 100,
+) -> Optional[datetime.date]:
+    """
+    Given a datetime that may have non-zero hour/minute/second/... parts, and where
+    container_url(date.date()) might be a container that does not exist due to job
+    failure, return a similar date where container_url(new_date) does exist, or
+    None if no such container can be found.
+
+    Arguments:
+        date: date to adjust
+        before: the returned date will be before this [optional]
+        after: the returned date will be after this [optional]
+        max_steps: maximum number of days away from the start date to venture
+    """
+    round_up = date.time() > datetime.time(12)
+    down, up = (date.date(), -1), (date.date() + datetime.timedelta(days=1), +1)
+    options = [up, down] if round_up else [down, up]
+    n = 0
+    while n < max_steps:
+        plausible_directions = 0
+        for start, direction in options:
+            candidate = start + n * direction * datetime.timedelta(days=1)
+            if (before is None or candidate < before) and (
+                after is None or candidate > after
+            ):
+                plausible_directions += 1
+                if container_exists(candidate):
+                    if date.date() != candidate:
+                        logger.debug(f"Adjusted {date} to {candidate}")
+                    return candidate
+                else:
+                    logger.debug(
+                        f"{args.container} container {candidate} does not exist"
+                    )
+        n += 1
+        if plausible_directions == 0:
+            logger.info(
+                f"Could not adjust {date} given before={before} and after={after}"
+            )
+            return None
+    logger.info(f"Could not find an adjusted {date} within {max_steps} steps")
+    return None
+
+
+class Container:
+    def __init__(self, date: datetime.date):
+        self._date = date
+
+    def __enter__(self):
+        result = subprocess.run(
+            [
+                "docker",
+                "run",
+                "--detach",
+                # Otherwise bazel shutdown hangs.
+                "--init",
+                "--gpus=all",
+                "--shm-size=1g",
+            ]
+            + bazel_cache_mount
+            + [
+                container_url(self._date),
+                "sleep",
+                "infinity",
+            ],
+            check=True,
+            encoding="utf-8",
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+        self._id = result.stdout.strip()
+        return self
+
+    def __exit__(self, *exc_info):
+        subprocess.run(
+            ["docker", "stop", self._id],
+            check=True,
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+
+    def exec(self, command: List[str], workdir=None) -> subprocess.CompletedProcess:
+        """
+        Run a command inside a persistent container.
+        """
+        workdir = [] if workdir is None else ["--workdir", workdir]
+        return subprocess.run(
+            ["docker", "exec"] + workdir + [self._id] + command,
+            encoding="utf-8",
+            stderr=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+        )
+
+    def check_exec(self, cmd: List[str], **kwargs) -> subprocess.CompletedProcess:
+        result = self.exec(cmd, **kwargs)
+        if result.returncode != 0:
+            logger.fatal(f"{' '.join(cmd)} exited with return code {result.returncode}")
+            logger.fatal(result.stdout)
+            logger.fatal(result.stderr)
+            result.check_returncode()
+        return result
+
+
+def add_summary_record(section, record, scalar=False):
+    """
+    Add a record to the output JSON file. This is intended to provide a useful
+    record even in case of a fatal error.
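+
+    The summary file is re-read and re-written on every call, so records from
+    earlier steps survive even if a later step raises.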
+    """
+    summary_filename = args.output_prefix + "-summary.json"
+    try:
+        with open(summary_filename, "r") as ifile:
+            data = json.load(ifile)
+    except FileNotFoundError:
+        data = {}
+    if scalar:
+        if section in data:
+            logger.warning(f"Overwriting summary data in section {section}")
+        data[section] = record
+    else:
+        if section not in data:
+            data[section] = []
+        data[section].append(record)
+    with open(summary_filename, "w") as ofile:
+        json.dump(data, ofile)
+
+
+def get_commit(container: Container, repo: str) -> Tuple[str, str]:
+    """
+    Get the commit of the given repository that was used in the given nightly
+    container, along with the directory it is checked out in.
+
+    Arguments:
+        container: container to query
+        repo: repository, must be jax or xla
+    """
+    assert repo in {"jax", "xla"}
+    # Older containers used /opt/jax-source etc.
+    for suffix in ["", "-source"]:
+        dirname = f"/opt/{repo}{suffix}"
+        result = container.exec(["git", "rev-parse", "HEAD"], workdir=dirname)
+        if result.returncode == 0:
+            commit = result.stdout.strip()
+            if len(commit) == 40:
+                return commit, dirname
+    raise Exception(
+        f"Could not extract commit of {repo} from {args.container} container {container._date}"
+    )
+
+
+def check_container(date: datetime.date) -> bool:
+    """
+    See if the test passes in the given container.
+    """
+    before = time.monotonic()
+    with Container(date) as worker:
+        result = worker.exec(args.test_command)
+        test_time = time.monotonic() - before
+        jax_commit, _ = get_commit(worker, "jax")
+        xla_commit, _ = get_commit(worker, "xla")
+
+    logger.debug(result.stdout)
+    logger.info(f"Ran test case in {date} in {test_time:.1f}s")
+    test_pass = result.returncode == 0
+    add_summary_record(
+        "container",
+        {
+            "container": container_url(date),
+            "jax": jax_commit,
+            "result": test_pass,
+            "test_time": test_time,
+            "xla": xla_commit,
+        },
+    )
+    return test_pass
+
+
+# Figure out the end date of the search
+if args.end_date is not None:
+    # --end-date was passed
+    if not container_exists(args.end_date):
+        raise Exception(f"--end-date={args.end_date} is not a valid container")
+    end_date = args.end_date
+    skip_end_date_check = args.skip_precondition_checks
+else:
+    # Default to the most recent container
+    now = datetime.datetime.now()
+    end_date = adjust_date(now)
+    if end_date is None:
+        raise Exception(f"Could not find a valid container from {now}")
+    skip_end_date_check = False
+
+# Check preconditions; the test is supposed to fail on the end date.
+if skip_end_date_check:
+    logger.info(f"Skipping check for end-of-range failure in {end_date}")
+else:
+    logger.info(f"Checking end-of-range failure in {end_date}")
+    if check_container(end_date):
+        raise Exception(
+            "Could not reproduce failure of `{}` on {} ({})".format(
+                " ".join(args.test_command), end_date, container_url(end_date)
+            )
+        )
+
+# Start the coarse, container-level search for a starting point for the bisection range.
+earliest_failure = end_date
+if args.start_date is None:
+    # Start from the day before the end date.
+    search_date = adjust_date(
+        as_datetime(end_date) - datetime.timedelta(days=1), before=end_date
+    )
+    if search_date is None:
+        raise Exception(f"Could not find a valid nightly before {end_date}")
+    logger.info(
+        (
+            f"Starting coarse search with {search_date} based on "
+            f"--end-date={args.end_date} and end_date={end_date}"
+        )
+    )
+    # We just found a starting value; we still need to check whether the test
+    # passes or fails on it.
+    skip_first_phase = False
+else:
+    # If a --start-date seed was given, use it.
+    if not container_exists(args.start_date):
+        raise Exception(f"--start-date={args.start_date} is not a valid container")
+    search_date = args.start_date
+    assert search_date is not None  # for mypy
+    # If --skip-precondition-checks and --start-date are both passed, we assume that
+    # the test passed on the given --start-date and the first phase of the search
+    # can be skipped.
+    skip_first_phase = args.skip_precondition_checks
+    if not skip_first_phase:
+        logger.info(f"Starting coarse search with {search_date} based on --start-date")
+
+if skip_first_phase:
+    logger.info(
+        f"Skipping check that the test passes on --start-date={args.start_date}"
+    )
+else:
+    # The check_container() call in the loop condition logs an INFO message on
+    # every iteration.
+    while not check_container(search_date):
+        # Test failed on `search_date`, so go further into the past.
+        earliest_failure = search_date
+        new_search_date = adjust_date(
+            as_datetime(end_date) - 2 * (end_date - search_date),
+            before=search_date,
+        )
+        if new_search_date is None:
+            raise Exception(f"Could not find a valid nightly before {search_date}")
+        search_date = new_search_date
+
+# Continue the container-level search, refining the range until it meets the criterion
+# set by args.threshold_days. The test passed at range_start and not at range_end.
+range_start, range_end = search_date, earliest_failure
+logger.info(f"Coarse container-level search yielded [{range_start}, {range_end}]...")
+while range_end - range_start > datetime.timedelta(days=args.threshold_days):
+    range_mid = adjust_date(
+        as_datetime(range_start) + 0.5 * (range_end - range_start),
+        before=range_end,
+        after=range_start,
+    )
+    if range_mid is None:
+        # It wasn't possible to refine further.
+        break
+    result = check_container(range_mid)
+    if result:
+        range_start = range_mid
+    else:
+        range_end = range_mid
+logger.info(f"Refined container-level range to [{range_start}, {range_end}]")
+
+# The container-level search is now complete. Triage proceeds inside the range_end
+# container. First, we check that rebuilding at the commits already used in the
+# range_end container still reproduces the failure, and that rewinding JAX and XLA
+# to the commits used in the range_start container reproduces the success.
+with Container(range_start) as worker:
+    start_jax_commit, _ = get_commit(worker, "jax")
+    start_xla_commit, _ = get_commit(worker, "xla")
+
+# Fire up the container that will be used for the fine search.
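+# All of the rebuilds and test runs below happen inside this one container, so the
+# environment stays fixed while only the JAX and XLA commits vary.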
+with Container(range_end) as worker:
+    end_jax_commit, jax_dir = get_commit(worker, "jax")
+    end_xla_commit, xla_dir = get_commit(worker, "xla")
+    logger.info(
+        (
+            f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
+            f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
+        )
+    )
+
+    # Get the full lists of JAX/XLA commits and dates
+    def commits(start, end, dirname):
+        result = worker.check_exec(
+            [
+                "git",
+                "log",
+                "--first-parent",
+                "--reverse",
+                "--format=%H %cI",
+                f"{start}^..{end}",
+            ],
+            workdir=dirname,
+        )
+        data = []
+        for line in result.stdout.splitlines():
+            commit, date = line.split()
+            date = datetime.datetime.fromisoformat(date).astimezone(
+                datetime.timezone.utc
+            )
+            data.append((commit, date))
+        return data
+
+    # Get lists of (commit_hash, commit_date) pairs
+    jax_commits = commits(start_jax_commit, end_jax_commit, jax_dir)
+    xla_commits = commits(start_xla_commit, end_xla_commit, xla_dir)
+    # Confirm they're sorted by commit date
+    assert all(b[1] >= a[1] for a, b in zip(jax_commits, jax_commits[1:]))
+    assert all(b[1] >= a[1] for a, b in zip(xla_commits, xla_commits[1:]))
+    # Confirm the endpoint values are included as expected
+    assert start_jax_commit == jax_commits[0][0]
+    assert start_xla_commit == xla_commits[0][0]
+    assert end_jax_commit == jax_commits[-1][0]
+    assert end_xla_commit == xla_commits[-1][0]
+
+    def build_and_test(jax_commit: str, xla_commit: str) -> subprocess.CompletedProcess:
+        """
+        The main body of the bisection loop. Update the JAX/XLA commits, build XLA
+        and jaxlib, and run the test command. Raises on error when checking out or
+        building, and returns the result of the test command.
+        """
+        worker.check_exec(["git", "checkout", xla_commit], workdir=xla_dir)
+        worker.check_exec(["git", "checkout", jax_commit], workdir=jax_dir)
+        logger.info(f"Checked out XLA {xla_commit} and JAX {jax_commit}")
+        # Build JAX
+        before = time.monotonic()
+        # The next two commands are workarounds for bugs in old containers.
+        worker.check_exec(["sh", "-c", f"rm -v {jax_dir}/dist/jaxlib-*.whl"])
+        worker.check_exec(
+            ["cp", f"{jax_dir}/jax/version.py", f"{jax_dir}/build/lib/jax/version.py"]
+        )
+        # Do not cache test results; cached results seemed to be a source of flaky
+        # behaviour.
+        worker.check_exec(
+            ["sh", "-c", "echo 'test --cache_test_results=no' > /root/.bazelrc"]
+        )
+        build_jax = [
+            "build-jax.sh",
+            # Leave the editable /opt/jax[-source] installation alone. Otherwise
+            # test-jax.sh is broken by having a /usr/... installation directory.
+            "--jaxlib_only",
+            # Work around bugs in old containers where the default was wrong.
+            "--src-path-jax",
+            jax_dir,
+            f"--bazel-cache={args.bazel_cache}",
+        ]
+        worker.check_exec(build_jax, workdir=jax_dir)
+        middle = time.monotonic()
+        logger.info(f"Build completed in {middle - before:.1f}s")
+        # Run the test
+        test_result = worker.exec(args.test_command)
+        test_time = time.monotonic() - middle
+        add_summary_record(
+            "commit",
+            {
+                "build_time": middle - before,
+                "container": container_url(range_end),
+                "jax": jax_commit,
+                "result": test_result.returncode == 0,
+                "test_time": test_time,
+                "xla": xla_commit,
+            },
+        )
+        logger.info(f"Test completed in {test_time:.1f}s")
+        logger.debug(
+            f"Test stdout:\n{test_result.stdout}\nTest stderr:\n{test_result.stderr}"
+        )
+        return test_result
+
+    if args.skip_precondition_checks:
+        logger.info(
+            f"Skipping check that building + testing in {range_end} reproduces failure"
+        )
+    else:
+        # Verify we can build successfully and that the test fails as expected.
+        # These commits are the ones already checked out in the container, but
+        # specifying them explicitly is good for the summary JSON.
+        logger.info(f"Building in the range-ending {range_end} container...")
+        range_end_result = build_and_test(
+            jax_commit=end_jax_commit, xla_commit=end_xla_commit
+        )
+        if range_end_result.returncode != 0:
+            logger.info(f"Verified test failure after rebuilding in {range_end}")
+        else:
+            logger.fatal(
+                f"Could not reproduce test failure after rebuilding in {range_end} container"
+            )
+            logger.fatal(range_end_result.stdout)
+            logger.fatal(range_end_result.stderr)
+            raise Exception(
+                f"Could not reproduce test failure after rebuilding in {range_end}"
+            )
+
+        # Verify that we can build the commit at the start of the range and
+        # reproduce the test success there in the end-of-range container.
+        range_start_result = build_and_test(
+            jax_commit=start_jax_commit, xla_commit=start_xla_commit
+        )
+        if range_start_result.returncode == 0:
+            logger.info(
+                f"Test passes after rebuilding commits from {range_start} in {range_end}"
+            )
+        else:
+            logger.fatal(
+                f"Test failed after rebuilding commits from {range_start} in {range_end}"
+            )
+            logger.fatal(range_start_result.stdout)
+            logger.fatal(range_start_result.stderr)
+            raise Exception(
+                f"Could not reproduce test success after rebuilding commits from "
+                f"{range_start} in {range_end}"
+            )
+
+    # Finally, start bisecting. This is XLA-centric; JAX is moved too, but is secondary.
+    while len(xla_commits) > 2:
+        middle = len(xla_commits) // 2
+        xla_hash, xla_date = xla_commits[middle]
+        # Find the oldest JAX commit that is newer than this XLA commit.
+        for jax_index, (jax_hash, jax_date) in enumerate(jax_commits):
+            if jax_date >= xla_date:
+                break
+        bisect_result = build_and_test(jax_commit=jax_hash, xla_commit=xla_hash)
+        if bisect_result.returncode == 0:
+            # Test passed, continue searching in the second half
+            xla_commits = xla_commits[middle:]
+            jax_commits = jax_commits[jax_index:]
+        else:
+            # Test failed, continue searching in the first half
+            xla_commits = xla_commits[: middle + 1]
+            jax_commits = jax_commits[: jax_index + 1]
+
+    # XLA bisection converged. xla_commits has two entries. jax_commits may be a little
+    # longer, if it was more active than XLA at the relevant time. For example, here
+    # xla_commits is {oX, nX} and jax_commits is {oJ, mJ, nJ}, and the test passes with
+    # {oX, oJ} and fails with {nX, nJ}. Naming: o=old, m=medium, n=new, X=XLA, J=JAX.
+    #      pass    fail
+    # XLA: oX -------- nX
+    # JAX: oJ -- mJ -- nJ
+    #
+    # To figure out whether to blame XLA or JAX, we now test {oX, nJ}.
+    old_xla_hash = xla_commits[0][0]
+    new_jax_hash = jax_commits[-1][0]
+    blame_result = build_and_test(jax_commit=new_jax_hash, xla_commit=old_xla_hash)
+    if blame_result.returncode == 0:
+        # The test passed with {oX, nJ} but was known to fail with {nX, nJ}.
+        # Therefore, XLA commit nX is responsible and JAX is innocent.
+        results = (old_xla_hash, xla_commits[1][0])
+        logger.info(
+            "Bisected failure to XLA {}..{} with JAX {}".format(*results, new_jax_hash)
+        )
+        add_summary_record(
+            "result",
+            {
+                "container": container_url(range_end),
+                "jax_ref": new_jax_hash,
+                "xla_bad": xla_commits[1][0],
+                "xla_good": old_xla_hash,
+            },
+            scalar=True,
+        )
+        print("xla", *results, "jax", new_jax_hash)
+    else:
+        # The test failed with {oX, nJ} but was known to pass with {oX, oJ}, so JAX
+        # is responsible and we should bisect between oJ (pass) and nJ (fail). This
+        # yields a single JAX commit to blame, either mJ or nJ in the example above.
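+        # Invariant at this point: the test passes at jax_commits[0] and fails at
+        # jax_commits[-1], with XLA pinned to old_xla_hash throughout.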
+        while len(jax_commits) > 2:
+            middle = len(jax_commits) // 2
+            jax_hash, _ = jax_commits[middle]
+            bisect_result = build_and_test(jax_commit=jax_hash, xla_commit=old_xla_hash)
+            if bisect_result.returncode == 0:
+                # Test passed, continue searching in the second half
+                jax_commits = jax_commits[middle:]
+            else:
+                # Test failed, continue searching in the first half
+                jax_commits = jax_commits[: middle + 1]
+        results = (jax_commits[0][0], jax_commits[1][0])
+        logger.info(
+            "Bisected failure to JAX {}..{} with XLA {}".format(*results, old_xla_hash)
+        )
+        add_summary_record(
+            "result",
+            {
+                "container": container_url(range_end),
+                "jax_bad": jax_commits[1][0],
+                "jax_good": jax_commits[0][0],
+                "xla_ref": old_xla_hash,
+            },
+            scalar=True,
+        )
+        print("jax", *results, "xla", old_xla_hash)
diff --git a/docs/triage.md b/docs/triage.md
index 27e10c617..9ec5eafe5 100644
--- a/docs/triage.md
+++ b/docs/triage.md
@@ -4,6 +4,12 @@ There is a Github Action Workflow called [_triage.yaml](../.github/workflows/_tr
 be used to help determine if a test failure was due to a change in (t5x or pax) or
 further-up, e.g., in (Jax or CUDA). This workflow is not the end-all, and further
 investigation is usually needed, but this automates the investigation of questions
 like "what state of library X works with Jax at state Y?"
+__Note__: There is also a utility, [triage](../.github/triage/triage), which can be
+used for more granular bisection of failures in specific tests. Run it with `--help`
+for usage instructions. Given a test expression that can be run inside the nightly
+containers (*e.g.* `test-jax.sh jet_test_gpu`), it first identifies the nightly
+container where the failure first appeared, and then attributes the failure to a
+specific commit of JAX or XLA.
 ## Algorithm
 The pseudocode for the triaging algorithm is as follows: