diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index fff704174..e1a7513b0 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -130,7 +130,7 @@ case "${BATTERY}" in BAZEL_TARGET="${BAZEL_TARGET} //tests:image_test_gpu //tests:scipy_stats_test_gpu" ;; gpu) - JOBS_PER_GPU=16 + JOBS_PER_GPU=8 JOBS=$((NGPUS * JOBS_PER_GPU)) EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow" BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"