Collect recursively and filter GPU tests using jax_test_gpu tag #1091

Merged
5 commits merged into main from aportnoy/use-jax_test_gpu
Oct 21, 2024

andportnoy
Contributor

No description provided.

@andportnoy
Contributor Author

Needs jax-ml/jax#24218.

@olupton closed this Oct 10, 2024
@olupton reopened this Oct 10, 2024
Collaborator

@yhtang left a comment

It looks like JAX unit tests resulted in 20k+ errors and my browser just spent the last 10 mins loading the error log...

@andportnoy
Contributor Author

Yeah this confuses me because the errors that I see are all of the form

RuntimeError: Backend 'tpu' failed to initialize: INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory. Available backends are ['cpu', 'cuda']

in tests like //tests:blocked_sampler_test_tpu. But one would think these tests shouldn't be collected when filtering with --test_tag_filters=jax_test_gpu...

@andportnoy
Contributor Author

andportnoy commented Oct 10, 2024

The full command is

bazel test \
  --//jax:build_jaxlib=false \
  //tests/... \
  --test_tag_filters=jax_test_gpu \
  --@local_config_cuda//:enable_cuda \
  --cache_test_results=no \
  --test_timeout=600 \
  --test_tag_filters=-multiaccelerator \
  --test_env=JAX_SKIP_SLOW_TESTS=1 \
  --test_env=JAX_ACCELERATOR_COUNT=8 \
  --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
  --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false \
  --test_output=errors \
  --java_runtime_version=remotejdk_11 \
  --run_under /opt/jax/build/parallel_accelerator_execute.sh \
  --local_test_jobs=64 \
  --test_env=JAX_TESTS_PER_ACCELERATOR=8 \
  --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow

I wonder if --test_tag_filters cannot be repeated and so --test_tag_filters=-multiaccelerator overrides --test_tag_filters=jax_test_gpu?

Indeed, --test_tag_filters cannot be repeated: bazelbuild/bazel#7322.
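
Since the flag does not accumulate, one way to express both constraints is to pass them as a single comma-separated value, with a leading `-` marking the excluded tag. The sketch below is only an illustration: `join_tag_filters` and `TAG_FILTERS` are hypothetical names, not part of Bazel or of this repository, and the script only prints the combined flag rather than invoking `bazel`.

```shell
#!/bin/sh
# Hypothetical helper (not part of Bazel or this repository): join all
# arguments with commas. The body runs in a subshell, so the IFS change
# does not leak into the caller.
join_tag_filters() (
  IFS=,
  printf '%s' "$*"
)

# Include tests tagged jax_test_gpu and exclude those tagged
# multiaccelerator, using ONE flag value instead of two separate flags.
TAG_FILTERS="$(join_tag_filters jax_test_gpu -multiaccelerator)"

# Print the single flag that would stand in for both occurrences.
printf '%s\n' "--test_tag_filters=${TAG_FILTERS}"
```

The printed flag would replace both `--test_tag_filters=...` occurrences in the command above; the remaining options stay unchanged.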

@yhtang previously approved these changes Oct 14, 2024
@andportnoy
Contributor Author

jax-ml/jax#24322 fixes //tests/mosaic:flash_attention_gpu on V100/A100.

Collaborator

@olupton left a comment

LGTM
//tests/mosaic:flash_attention_gpu is failing, but I don't think that should block merging this

@olupton merged commit cfc3f74 into main Oct 21, 2024
139 of 143 checks passed
@olupton deleted the aportnoy/use-jax_test_gpu branch October 21, 2024 11:43
@andportnoy
Contributor Author

> //tests/mosaic:flash_attention_gpu is failing

Fixed in jax-ml/jax#24467.
