Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump the jax group with 2 updates #1461

Merged
merged 6 commits into from
Dec 12, 2024
Merged

Conversation

dependabot[bot]
Copy link
Contributor

@dependabot dependabot bot commented on behalf of github Dec 11, 2024

Resolves #1434

Updates the requirements on jax and diffrax to permit the latest version.
Updates jax to 0.4.37

Release notes

Sourced from jax's releases.

JAX v0.4.37

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#25329).
    • Fix a bug that will throw index out of range error in jax.lax.while_loop if the user registers pytree node class with different aux data for the flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
Changelog

Sourced from jax's changelog.

jax 0.4.37 (Dec 9, 2024)

This is a patch release of jax 0.4.36. Only "jax" was released at this version.

  • Bug fixes
    • Fixed a bug where jit would error if an argument was named f (#25329).
    • Fix a bug that will throw index out of range error in {func}jax.lax.while_loop if the user register pytree node class with different aux data for the flatten and flatten_with_path.
    • Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.

jax 0.4.36 (Dec 5, 2024)

  • Breaking Changes
    • This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind, and so on. The change should only affect users that use JAX internals.

      If you do use JAX internals then you may need to update your code (see jax-ml/jax@c36e1f7 for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find this change breaks your non-JAX-internals-using code then try the config.jax_data_dependent_tracing_fallback flag as a workaround, and if you need help updating your code then please file a bug.

    • {func}jax.experimental.jax2tf.convert with native_serialization=False or with enable_xla=False have been deprecated since July 2024, with JAX version 0.4.31. Now we removed support for these use cases. jax2tf with native serialization will still be supported.

    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge, xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.

    • The deprecated module jax.experimental.export has been removed. It was replaced by {mod}jax.export in JAX v0.4.30. See the migration guide for information on migrating to the new API.

    • The initial argument to {func}jax.nn.softmax and {func}jax.nn.log_softmax has been removed, after being deprecated in v0.4.27.

    • Calling np.asarray on typed PRNG keys (i.e. keys produced by :func:jax.random.key) now raises an error. Previously, this returned a scalar object array.

    • The following deprecated methods and functions in {mod}jax.export have been removed:

      • jax.export.DisabledSafetyCheck.shape_assertions: it had no effect already.
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version: use calling_convention_version.

... (truncated)

Commits
  • ffb07cd Update versions for v0.4.37 release.
  • 95892fd Use private names for args in api_util to avoid shadowing kwargs keys.
  • 65b6088 Avoid index out of range error in carry structure check
  • 259194a [Pallas] Fix shard_axis in dma_start interpret mode rule.
  • 7e6620a JAX release 0.4.36.
  • 23d5c10 [Mosaic:TPU] Fix fully replicated relayout
  • 2a4a0e8 [jax:custom_partitioning] Implement SdyShardingRule to support
  • f73fa7a Merge pull request #25290 from jakevdp:reduction-where
  • a71f9a6 Merge pull request #25271 from jakevdp:fix-vector-norm
  • e20a483 [JAX] Add end-to-end execution support in colocated Python API
  • Additional commits viewable in compare view

Updates diffrax to 0.6.1

Release notes

Sourced from diffrax's releases.

Diffrax v0.6.1

Features

  • Compatibility with JAX 0.4.36.

  • New solvers! Added stochastic Runge--Kutta methods for solving the underdamped Langevin equation. We now have:

    • diffrax.AbstractFosterLangevinSRK
    • diffrax.ALIGN
    • diffrax.QUICSORT
    • diffrax.ShOULD

    and these are used with the corresponding

    • diffrax.UnderdampedLangevinDriftTerm
    • diffrax.UnderdampedLangevinDiffusionTerm

    huge thanks to @​andyElking for carefully implementing all of these, which was a huge technical task. (#453 and 2000 new lines of code!) See the Underdamped Langevin Diffusion example for more on how to use these.

Bugfixes

  • If t0 == t1 and we have SaveAt(ts=...) then we now correctly output len(ts) copies of y0. (Thanks @​dkweiss31! #488, #494)
  • When using diffrax.VirtualBrownianTree on the GPU then floating point fluctuations would sometimes produce evaluations outside of the valid [t0, t1] region, which would raise a spurious runtime error. This is now fixed. (Thanks @​mattlevine22! jax-ml/jax#24807, #524, #526)
  • Complex fixes in SDEs (Thanks @​Randl! #454)
  • Improvements to errors, warnings, and some typo fixes (Thanks @​lockwo @​ddrous! #468#478, #495, #530)

New Contributors

Full Changelog: patrick-kidger/diffrax@v0.6.0...v0.6.1

Commits
  • 78531fa Bump minimum Equinox version to one that is compatible with latest JAX
  • 825e4e0 Fixed where a nonbatchable check was being called.
  • 1ae1d58 version bump
  • 72deb78 Updated pre-commit to handle jaxtyping update
  • d3490e6 Fixes for JAX 0.4.36 which changes the name of an error.
  • 3c21d15 Updates to the t0==t1 case to handle SubSaveAt(fn=...) and nonstandard dtyp...
  • ebd7980 Save fix for t0==t1 (#494)
  • 965f6b4 Compatibility with JAX 0.4.36, which removes ConcreteArray
  • beadc78 bump doc building pipeline
  • 0cf67d1 small fix of docs in all three and a return type in quicsort
  • Additional commits viewable in compare view

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


Dependabot commands and options

You can trigger Dependabot actions by commenting on this PR:

  • @dependabot rebase will rebase this PR
  • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
  • @dependabot merge will merge this PR after your CI passes on it
  • @dependabot squash and merge will squash and merge this PR after your CI passes on it
  • @dependabot cancel merge will cancel a previously requested merge and block automerging
  • @dependabot reopen will reopen this PR if it is closed
  • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • @dependabot show <dependency name> ignore conditions will show all of the ignore conditions of the specified dependency
  • @dependabot ignore <dependency name> major version will close this group update PR and stop Dependabot creating any more for the specific dependency's major version (unless you unignore this specific dependency's major version or upgrade to it yourself)
  • @dependabot ignore <dependency name> minor version will close this group update PR and stop Dependabot creating any more for the specific dependency's minor version (unless you unignore this specific dependency's minor version or upgrade to it yourself)
  • @dependabot ignore <dependency name> will close this group update PR and stop Dependabot creating any more for the specific dependency (unless you unignore this specific dependency or upgrade to it yourself)
  • @dependabot unignore <dependency name> will remove all of the ignore conditions of the specified dependency
  • @dependabot unignore <dependency name> <ignore condition> will remove the ignore condition of the specified dependency and ignore conditions

Updates the requirements on [jax](https://github.com/jax-ml/jax) and [diffrax](https://github.com/patrick-kidger/diffrax) to permit the latest version.

Updates `jax` to 0.4.37
- [Release notes](https://github.com/jax-ml/jax/releases)
- [Changelog](https://github.com/jax-ml/jax/blob/main/CHANGELOG.md)
- [Commits](jax-ml/jax@jax-v0.4.24...jax-v0.4.37)

Updates `diffrax` to 0.6.1
- [Release notes](https://github.com/patrick-kidger/diffrax/releases)
- [Commits](patrick-kidger/diffrax@v0.4.1...v0.6.1)

---
updated-dependencies:
- dependency-name: jax
  dependency-type: direct:production
  dependency-group: jax
- dependency-name: diffrax
  dependency-type: direct:production
  dependency-group: jax
...

Signed-off-by: dependabot[bot] <[email protected]>
@dependabot dependabot bot added dependencies Issue related to libraries we depend on and how we interface with them python Pull requests that update Python code labels Dec 11, 2024
@f0uriest f0uriest added the skip_changelog No need to update changelog on this PR label Dec 11, 2024
@@ -1,6 +1,6 @@
jax >= 0.4.24, <= 0.4.35
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should skip jax 0.4.36

Copy link

codecov bot commented Dec 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.57%. Comparing base (eff8a82) to head (9374e07).
Report is 7 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1461      +/-   ##
==========================================
- Coverage   95.58%   95.57%   -0.02%     
==========================================
  Files          98       98              
  Lines       25156    25156              
==========================================
- Hits        24045    24042       -3     
- Misses       1111     1114       +3     

see 3 files with indirect coverage changes

Copy link
Contributor

github-actions bot commented Dec 12, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +2.99 +/- 6.35     | +1.60e-02 +/- 3.41e-02 |  5.52e-01 +/- 2.8e-02  |  5.36e-01 +/- 1.9e-02  |
 test_equilibrium_init_medres            |     -0.77 +/- 1.90     | -3.27e-02 +/- 8.06e-02 |  4.20e+00 +/- 6.5e-02  |  4.24e+00 +/- 4.8e-02  |
 test_equilibrium_init_highres           |     +0.01 +/- 1.39     | +3.44e-04 +/- 7.66e-02 |  5.49e+00 +/- 6.0e-02  |  5.49e+00 +/- 4.8e-02  |
 test_objective_compile_dshape_current   |     +0.12 +/- 6.10     | +4.63e-03 +/- 2.42e-01 |  3.97e+00 +/- 1.8e-02  |  3.97e+00 +/- 2.4e-01  |
 test_objective_compute_dshape_current   |     +0.62 +/- 3.39     | +3.18e-05 +/- 1.75e-04 |  5.19e-03 +/- 1.4e-04  |  5.16e-03 +/- 1.0e-04  |
 test_objective_jac_dshape_current       |     -0.31 +/- 6.20     | -1.31e-04 +/- 2.67e-03 |  4.29e-02 +/- 2.3e-03  |  4.30e-02 +/- 1.4e-03  |
 test_perturb_2                          |     -0.38 +/- 2.03     | -7.69e-02 +/- 4.08e-01 |  2.00e+01 +/- 3.0e-01  |  2.00e+01 +/- 2.8e-01  |
 test_proximal_freeb_jac                 |     -0.00 +/- 1.12     | -3.65e-04 +/- 8.32e-02 |  7.45e+00 +/- 3.8e-02  |  7.45e+00 +/- 7.4e-02  |
 test_solve_fixed_iter                   |     -0.38 +/- 2.00     | -1.29e-01 +/- 6.73e-01 |  3.35e+01 +/- 5.1e-01  |  3.36e+01 +/- 4.4e-01  |
 test_LinearConstraintProjection_build   |     -0.17 +/- 2.90     | -1.77e-02 +/- 3.02e-01 |  1.04e+01 +/- 2.4e-01  |  1.04e+01 +/- 1.8e-01  |
 test_build_transform_fft_midres         |     +0.23 +/- 3.28     | +1.39e-03 +/- 1.95e-02 |  5.97e-01 +/- 1.8e-02  |  5.95e-01 +/- 8.4e-03  |
 test_build_transform_fft_highres        |     -0.17 +/- 1.10     | -1.68e-03 +/- 1.06e-02 |  9.61e-01 +/- 6.5e-03  |  9.62e-01 +/- 8.4e-03  |
 test_equilibrium_init_lowres            |     -1.49 +/- 2.69     | -5.74e-02 +/- 1.03e-01 |  3.79e+00 +/- 2.5e-02  |  3.84e+00 +/- 1.0e-01  |
 test_objective_compile_atf              |     +0.31 +/- 4.34     | +2.52e-02 +/- 3.49e-01 |  8.06e+00 +/- 2.6e-01  |  8.04e+00 +/- 2.3e-01  |
 test_objective_compute_atf              |     -0.04 +/- 1.51     | -6.33e-06 +/- 2.37e-04 |  1.57e-02 +/- 1.7e-04  |  1.57e-02 +/- 1.7e-04  |
 test_objective_jac_atf                  |     -0.97 +/- 2.20     | -1.92e-02 +/- 4.36e-02 |  1.96e+00 +/- 3.0e-02  |  1.98e+00 +/- 3.2e-02  |
 test_perturb_1                          |     +0.57 +/- 1.47     | +8.17e-02 +/- 2.11e-01 |  1.45e+01 +/- 4.9e-02  |  1.44e+01 +/- 2.1e-01  |
 test_proximal_jac_atf                   |     -0.27 +/- 0.73     | -2.20e-02 +/- 6.03e-02 |  8.23e+00 +/- 5.1e-02  |  8.25e+00 +/- 3.2e-02  |
 test_proximal_freeb_compute             |     -0.81 +/- 0.71     | -1.62e-03 +/- 1.41e-03 |  1.97e-01 +/- 9.6e-04  |  1.99e-01 +/- 1.0e-03  |
 test_solve_fixed_iter_compiled          |     -0.83 +/- 0.74     | -1.80e-01 +/- 1.62e-01 |  2.16e+01 +/- 1.1e-01  |  2.18e+01 +/- 1.2e-01  |

@f0uriest f0uriest merged commit 4ed87cb into master Dec 12, 2024
25 checks passed
@dependabot dependabot bot deleted the dependabot/pip/jax-1cca3be8ee branch December 12, 2024 20:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
dependencies Issue related to libraries we depend on and how we interface with them python Pull requests that update Python code skip_changelog No need to update changelog on this PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Address possible breaking changes from jax==0.4.36
2 participants