Skip to content

Commit

Permalink
Refactor to make more testable, add tests/CI.
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Oct 8, 2024
1 parent 32e2e8b commit 62dccd1
Show file tree
Hide file tree
Showing 10 changed files with 1,288 additions and 688 deletions.
3 changes: 3 additions & 0 deletions .github/triage/jax_toolbox_triage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .main import main

__all__ = ["main"]
112 changes: 112 additions & 0 deletions .github/triage/jax_toolbox_triage/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import argparse
import datetime
import getpass
import os
import pathlib
import tempfile


def parse_args():
parser = argparse.ArgumentParser(
description="""
Triage failures in JAX/XLA-related tests. The expectation is that the given
test command is failing in recent versions, but that it passed in the past. The
script first triages the regression with a search of the nightly containers,
and then refines the search to a particular commit of JAX or XLA.""",
)

container_search_args = parser.add_argument_group(
title="Container-level search",
description="""
First, it is verified that the test command fails on the given end date, unless
both --end-date and --skip-precondition-checks were passed. Then, the program
searches backwards to find a container when the given test did pass. The
--start-date option can be used to speed up this search, if you already know a
date on which the test was passing. The earliest failure is located to within
--threshold-days days.""",
)
commit_search_args = parser.add_argument_group(
title="Commit-level search",
description="""
Second, the failure is localised to a commit of JAX or XLA by re-building and
re-testing inside the earliest container that demonstrates the failure. At each
point, the oldest JAX commit that is newer than XLA is used.""",
)
parser.add_argument(
"--container",
help="""
Container to use. Example: jax, pax, triton. Used to construct the URLs of
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
required=True,
)
parser.add_argument(
"--output-prefix",
default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"),
help="""
Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS.
An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is
written as PREFIX-debug.log, and a JSON summary is written as
PREFIX-summary.json""",
type=pathlib.Path,
)
parser.add_argument(
"--skip-precondition-checks",
action="store_true",
help="""
Skip checks that should pass by construction. This saves time, but may yield
incorrect results if you are not careful. Specifically this means that the test
is assumed to fail on --end-date (if specified), pass on --start-date (if
specified), and fail after recompilation in the earliest-known-failure
container. Careful use of this option, along with --start-date, --end-date and
--threshold-days, allows the container-level search to be skipped.""",
)
parser.add_argument(
"test_command",
nargs="+",
help="""
Command to execute inside the container. This should be as targeted as
possible.""",
)
container_search_args.add_argument(
"--end-date",
help="""
Initial estimate of the earliest nightly container date where the test case
fails. Defaults to the newest available nightly container date. If this and
--skip-precondition-checks are both set then it will not be verified that the
test case fails on this date.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--start-date",
help="""
Initial estimate of the latest nightly container date where the test case
passes. Defaults to the day before --end-date, but setting this to a date
further in the past may lead to faster convergence of the initial backwards
search for a date when the test case passed. If this and
--skip-precondition-checks are both set then the test case *must* pass on
this date, which will *not* be verified.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--threshold-days",
default=1,
help="""
Convergence threshold. Ideally, the container-level search will continue while
the number of days separating the last known success and first known failure is
smaller than this value. The minimum, and default, value is 1. Note that in
case of nightly build failures the search may finish without reaching this
threshold.""",
type=int,
)
commit_search_args.add_argument(
"--bazel-cache",
default=os.path.join(
tempfile.gettempdir(), f"{getpass.getuser()}-bazel-triage-cache"
),
help="""
Bazel cache to use when [re-]building JAX/XLA during the fine search. This can
be a remote cache server or a local directory. Using a persistent cache can
significantly speed up the commit-level search. By default, uses a temporary
directory including the name of the current user.""",
)
return parser.parse_args()
79 changes: 79 additions & 0 deletions .github/triage/jax_toolbox_triage/docker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
import pathlib
import subprocess
import typing


class DockerContainer:
def __init__(
self,
url: str,
*,
logger: logging.Logger,
mounts: typing.List[typing.Tuple[pathlib.Path, pathlib.Path]],
):
self._logger = logger
self._mount_args = []
for src, dst in mounts:
self._mount_args += ["-v", f"{src}:{dst}"]
self._url = url

def __enter__(self):
result = subprocess.run(
[
"docker",
"run",
"--detach",
# Otherwise bazel shutdown hangs.
"--init",
"--gpus=all",
"--shm-size=1g",
]
+ self._mount_args
+ [
self._url,
"sleep",
"infinity",
],
check=True,
encoding="utf-8",
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)
self._id = result.stdout.strip()
return self

def __exit__(self, *exc_info):
subprocess.run(
["docker", "stop", self._id],
check=True,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)

def exec(
self, command: typing.List[str], workdir=None
) -> subprocess.CompletedProcess:
"""
Run a command inside a persistent container.
"""
workdir = [] if workdir is None else ["--workdir", workdir]
return subprocess.run(
["docker", "exec"] + workdir + [self._id] + command,
encoding="utf-8",
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)

def check_exec(
self, cmd: typing.List[str], **kwargs
) -> subprocess.CompletedProcess:
result = self.exec(cmd, **kwargs)
if result.returncode != 0:
self._logger.fatal(
f"{' '.join(cmd)} exited with return code {result.returncode}"
)
self._logger.fatal(result.stdout)
self._logger.fatal(result.stderr)
result.check_returncode()
return result
Loading

0 comments on commit 62dccd1

Please sign in to comment.