Jax: jax-v0.5.0 Release

Release date: February 15, 2025
Previous version: jax-v0.4.38 (released January 16, 2025)
Magnitude: 40,607 Diff Delta
Contributors: 63 total committers

472 Commits in this Release

Top Contributors in jax-v0.5.0

jakevdp
a-googler
yashk2810
dfm
hawkinsp
apaszke
bchetioui
tlongeri
superbobry
justinjfu

Release Notes Published

As of this release, JAX uses effort-based versioning. Because this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the "meso" version of JAX to signal this.

  • Breaking changes

    Two key factors motivated the decision to stop shipping Mac x86 releases:

      • The Mac x86 build (only) has a number of test failures and crashes. We would rather ship no release than a broken one.
      • Mac x86 hardware is end-of-life and can no longer be easily obtained by developers, so it is difficult for us to fix this kind of problem even if we wanted to.

    We are open to re-adding support for Mac x86 if the community is willing to help support that platform. In particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.

  • Changes:

    • The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
    • The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
    • jax.numpy.einsum now defaults to optimize='auto' rather than optimize='optimal'. This avoids exponentially-scaling trace-time in the case of many arguments (#25214).
    • jax.numpy.linalg.solve no longer supports batched 1D arguments on the right-hand side. To recover the previous behavior in these cases, use solve(a, b[..., None]).squeeze(-1).
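
The einsum and solve changes above can be illustrated with a short sketch. It assumes JAX 0.5.0 or later is installed; the array shapes and variable names are made up for the example and are not part of the release notes.

```python
import jax.numpy as jnp

# Hypothetical batch of three 2x2 systems, each with a 1D right-hand side.
a = jnp.stack([jnp.eye(2) * (i + 1) for i in range(3)])  # shape (3, 2, 2)
b = jnp.ones((3, 2))                                      # shape (3, 2)

# Workaround from the note above: add a trailing axis so each right-hand
# side becomes a column matrix, solve, then drop the trailing axis again.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)         # shape (3, 2)

# einsum: the default is now optimize='auto'. Passing optimize='optimal'
# explicitly requests the previous default contraction-order search.
m1, m2, m3 = jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4, 5))
y_auto = jnp.einsum("ij,jk,kl->il", m1, m2, m3)
y_optimal = jnp.einsum("ij,jk,kl->il", m1, m2, m3, optimize="optimal")
```

Both einsum calls compute the same product; only the strategy used to choose the contraction order at trace time differs.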
  • New Features

    • jax.numpy.fft.fftn, jax.numpy.fft.rfftn, jax.numpy.fft.ifftn, and jax.numpy.fft.irfftn now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
    • Support added for user-defined state in the FFI via the new jax.ffi.register_ffi_type_id function.
    • The AOT lowering .as_text() method now supports the debug_info option to include debugging information, e.g., source location, in the output.
  • Deprecations

    • From jax.interpreters.xla, abstractify and pytype_aval_mappings are now deprecated, having been replaced by symbols of the same name in jax.core.
    • jax.scipy.special.lpmn and jax.scipy.special.lpmn_values are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
    • The jax.extend.ffi submodule was moved to jax.ffi, and the previous import path is deprecated.
  • Deletions

    • The jax_enable_memories flag has been deleted, and the behavior it enabled is now on by default.
    • From jax.lib.xla_client, the previously-deprecated Device and XlaRuntimeError symbols have been removed; instead use jax.Device and jax.errors.JaxRuntimeError respectively.
    • The jax.experimental.array_api module has been removed after being deprecated in JAX v0.4.32. Since that release, jax.numpy supports the array API directly.