JAX: jax-v0.5.1 Release

Release date:
February 26, 2025
Previous version:
jax-v0.5.0 (released February 15, 2025)
Magnitude:
60,641 Diff Delta
Contributors:
85 total committers
Commits:

664 Commits in this Release


Top Contributors in jax-v0.5.1

yashk2810
jakevdp
gnecula
apaszke
dfm
dimitar-asenov
a-googler
hawkinsp
superbobry
jburnim


Release Notes Published

  • New Features

    • Added an experimental jax.experimental.custom_dce.custom_dce decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See #25956 for more details.
    • Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum, jax.lax.reduce_prod, jax.lax.reduce_max, jax.lax.reduce_min, jax.lax.reduce_and, jax.lax.reduce_or, and jax.lax.reduce_xor.
    • jax.lax.linalg.qr and jax.scipy.linalg.qr now support column pivoting on CPU and GPU. See #20282 and #25955 for more details.
  • Changes

    • JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES can now be set as environment variables. Previously they could only be specified via jax.config or flags.
    • JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning multi-process CPU communication works out of the box.
    • The jax[tpu] TPU extra no longer depends on the libtpu-nightly package. This package may safely be removed if it is present on your machine; JAX now uses libtpu instead.
  • Deprecations

    • The internal function linear_util.wrap_init and the constructor core.Jaxpr must now take a non-empty core.DebugInfo keyword argument. For a limited time, a DeprecationWarning is emitted if jax.extend.linear_util.wrap_init is used without debugging info. As a downstream effect, several other internal functions also need debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more details.
  • Bug fixes

    • TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer (from around 17s to around 8s). If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'). We hope to improve this further in future releases.
    • The persistent compilation cache no longer writes the access-time file if JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e., if the LRU eviction policy isn't enabled. This should improve performance when using the cache with large-scale network storage.
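Before applying the sudo command from the TPU startup note, you might first want to see which transparent hugepage mode is active. The kernel lists all options in /sys/kernel/mm/transparent_hugepage/enabled and brackets the active one (for example "always [madvise] never"). The helpers active_thp_mode and read_thp_mode are hypothetical and only read the setting; changing it still requires root.

```python
from pathlib import Path
from typing import Optional

THP_PATH = Path("/sys/kernel/mm/transparent_hugepage/enabled")


def active_thp_mode(text: str) -> str:
    """Return the bracketed (active) option, e.g. 'madvise' from 'always [madvise] never'."""
    for token in text.split():
        if token.startswith("[") and token.endswith("]"):
            return token[1:-1]
    raise ValueError(f"no active mode found in {text!r}")


def read_thp_mode(path: Path = THP_PATH) -> Optional[str]:
    """Read the active THP mode, or return None if the sysfs file is unavailable."""
    try:
        return active_thp_mode(path.read_text())
    except OSError:
        # Missing on non-Linux hosts or some containers.
        return None


mode = read_thp_mode()
if mode is not None and mode != "always":
    print(f"transparent hugepages are set to '{mode}', not 'always'")
```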
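The compilation cache fix applies when LRU eviction is off. A minimal configuration sketch for that case follows. It assumes the config options jax_compilation_cache_dir, jax_compilation_cache_max_size (the config form of JAX_COMPILATION_CACHE_MAX_SIZE), and jax_persistent_cache_min_compile_time_secs. The temporary directory and the function f are placeholders for a real cache location and workload.

```python
import tempfile

import jax
import jax.numpy as jnp

# Placeholder cache location; in practice this might be shared network storage.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

# Enable the persistent compilation cache.
jax.config.update("jax_compilation_cache_dir", cache_dir)
# -1 (or leaving the option unset) means no size limit, so LRU eviction is
# disabled: the case in which access-time files are no longer written.
jax.config.update("jax_compilation_cache_max_size", -1)
# Allow even quick compilations to be cached so this small example qualifies.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def f(x):
    return jnp.tanh(x) @ x.T


out = f(jnp.ones((8, 8), dtype=jnp.float32))
```

Setting a positive byte limit instead would turn LRU eviction back on, at the cost of tracking access times for each cache entry.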