Skip to content

Commit

Permalink
Reduce JAX unit test to 8 tasks per GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
yhtang authored Sep 13, 2023
1 parent ae245a6 commit 87fe7ed
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion .github/container/test-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ case "${BATTERY}" in
BAZEL_TARGET="${BAZEL_TARGET} //tests:image_test_gpu //tests:scipy_stats_test_gpu"
;;
gpu)
JOBS_PER_GPU=16
JOBS_PER_GPU=8
JOBS=$((NGPUS * JOBS_PER_GPU))
EXTRA_FLAGS="--jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
Expand Down

0 comments on commit 87fe7ed

Please sign in to comment.