Jax: jax-v0.4.34 Release

Release date:
October 2, 2024
Previous version:
jax-v0.4.33 (released September 16, 2024)

Release Notes

  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the formerly private XlaRuntimeError type.
  • Breaking changes

    • The jax_pmap_no_rank_reduction flag is now set to True by default.
    • array[0] on a pmap result now introduces a reshape (use array[0:1] instead).
    • The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...) axis. Update code that accesses shards directly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as jit and avoids costly reshapes when passing results from pmap into jit.
    • jax.experimental.host_callback has been deprecated since March 2024 (JAX 0.4.26). The default value of the --jax_host_callback_legacy configuration option is now False, which means that if your code uses jax.experimental.host_callback APIs, those calls are implemented in terms of the new jax.experimental.io_callback API. If this breaks your code, for a very limited time you can set --jax_host_callback_legacy to True. That option will be removed soon, so you should transition to the new JAX callback APIs. See #20385 for a discussion.
  • Deprecations

    • In jax.numpy.trim_zeros, non-arraylike arguments or arraylike arguments with ndim != 1 are now deprecated, and in the future will result in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use jax.errors.JaxRuntimeError instead.
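A short sketch of the non-deprecated forms, assuming JAX 0.4.34 or later: pass jax.numpy.trim_zeros a 1-D array, and type-check devices against jax.Device.

```python
import jax
import jax.numpy as jnp

# Pass trim_zeros a 1-D array; non-array-like inputs and arrays with
# ndim != 1 are deprecated.
filt = jnp.array([0, 0, 1, 2, 0, 3, 0])
print(jnp.trim_zeros(filt))            # trims both ends: [1 2 0 3]
print(jnp.trim_zeros(filt, trim="f"))  # leading zeros only: [1 2 0 3 0]

# Check device types against jax.Device, not jax.lib.xla_client.Device.
dev = jax.devices()[0]
print(isinstance(dev, jax.Device), dev.platform)
```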
  • Deletions

    • jax.xla_computation is deleted. It has been 3 months since its deprecation in the JAX 0.4.30 release. Please use the AOT APIs to get the same functionality as jax.xla_computation.
    • jax.xla_computation(fn)(*args, **kwargs) can be replaced with jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
    • You can also use the .out_info property of jax.stages.Lowered to get output information (tree structure, shape, and dtype).
    • For cross-backend lowering, you can replace jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • jax.ShapeDtypeStruct no longer accepts the named_shape argument. The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a DeprecationWarning, now raises an error. None is only a tree-prefix of itself. To preserve the previous behavior, you can ask jax.tree.map to treat None as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use jax.sharding.Sharding.
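The following sketch, assuming JAX 0.4.34 or later, shows the AOT replacement for jax.xla_computation and the explicit is_leaf workaround for jax.tree.map. The functions `fn` and `f` and the trees `a` and `b` are hypothetical examples.

```python
import jax
import jax.numpy as jnp


def fn(x, y):
    return jnp.sin(x) * y


x = jnp.ones((2, 3), dtype=jnp.float32)

# Instead of jax.xla_computation(fn)(x, 2.0):
lowered = jax.jit(fn).lower(x, 2.0)
hlo = lowered.compiler_ir("hlo")
print(hlo.as_hlo_text())

# Output metadata (tree structure, shape, dtype) from the Lowered object.
info = lowered.out_info
print(info.shape, info.dtype)


# jax.tree.map with None in the first tree: treat None as a leaf explicitly.
def f(u, v):
    return u + v


a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}
result = jax.tree.map(lambda u, v: None if u is None else f(u, v),
                      a, b, is_leaf=lambda u: u is None)
print(result)  # 'b' stays None; 'w' becomes 1.0 + 2.0
```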
  • Bug fixes

    • Fixed a bug where jax.numpy.cumsum would produce incorrect outputs if a non-boolean input was provided and dtype=bool was specified.
    • Changed the implementation of jax.numpy.ldexp so that it produces correct gradients.