Jax: jax-v0.4.36 Release

Release date:
December 19, 2024
Previous version:
jax-v0.4.35 (released October 22, 2024)
Magnitude:
72,394 Diff Delta
Contributors:
99 total committers
Commits:

886 Commits in this Release

Top Contributors in jax-v0.4.36

dougalm
jakevdp
apaszke
yashk2810
a-googler
superbobry
dfm
nitins17
hawkinsp
cperivol

Release Notes Published

  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind, and so on. The change should only affect users that use JAX internals.

    If you do use JAX internals then you may need to update your code (see https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find that this change breaks code that does not use JAX internals, try the config.jax_data_dependent_tracing_fallback flag as a workaround, and if you need help updating your code, please file a bug.

    • jax.experimental.jax2tf.convert with native_serialization=False or with enable_xla=False has been deprecated since July 2024 (JAX v0.4.31). Support for these use cases has now been removed. jax2tf with native serialization is still supported.
    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed after being deprecated in JAX v0.4.31. Instead, use xb = jax.lib.xla_bridge, xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.
    • The deprecated module jax.experimental.export has been removed. It was replaced by jax.export in JAX v0.4.30. See the migration guide for information on migrating to the new API.
    • The initial argument to jax.nn.softmax and jax.nn.log_softmax has been removed, after being deprecated in v0.4.27.
    • Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key) now raises an error. Previously, this returned a scalar object array.
    • The following deprecated methods and functions in jax.export have been removed:
        • jax.export.DisabledSafetyCheck.shape_assertions: it already had no effect.
        • jax.export.Exported.lowering_platforms: use platforms.
        • jax.export.Exported.mlir_module_serialization_version: use calling_convention_version.
        • jax.export.Exported.uses_shape_polymorphism: use uses_global_constants.
        • The lowering_platforms kwarg for jax.export.export: use platforms instead.
    • The kwargs symbolic_scope and symbolic_constraints of jax.export.symbolic_args_specs have been removed. They were deprecated in June 2024. Use scope and constraints instead.
    • Hashing of tracers, which has been deprecated since version 0.4.30, now results in a TypeError.
    • Refactor: the JAX build CLI (build/build.py) now uses a subcommand structure that replaces the previous build.py usage. Run python build/build.py --help for more details. Brief overview of the new subcommand options:
        • build: Builds JAX wheel packages. For example, python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
        • requirements_update: Updates requirements_lock.txt files.
    • jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call jax.numpy.ravel on the function inputs.
    • jax.scipy.special.gamma and jax.scipy.special.gammasgn now return NaN for negative integer inputs, to match SciPy's behavior (see https://github.com/scipy/scipy/pull/21827).
    • jax.clear_backends was removed after being deprecated in v0.4.26.
    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom calls for which we guarantee export stability, because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the disabled_checks parameter. See the documentation for more details.

  • New Features

    • jax.jit got a new compiler_options: dict[str, Any] argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux.
    • jax.tree_util.register_dataclass now allows metadata fields to be declared inline via dataclasses.field. See the function documentation for examples.
    • Added jax.numpy.put_along_axis.
    • jax.lax.linalg.eig and the related jax.numpy functions (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now supported on GPU. See #24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition could hit an indexing overflow for batch sizes close to the int32 maximum. See #24843 for more details.

  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated; use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.