Google JAX

Google JAX is a machine learning framework for transforming numerical functions in Python. It is described as bringing together a modified version of autograd (which automatically obtains the gradient function of a function by differentiating it) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch. The primary functions of JAX are:


 * grad: automatic differentiation
 * jit: compilation
 * vmap: auto-vectorization
 * pmap: SPMD programming
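
The transformations listed above take a function and return a new function, so they can be nested. The following is a minimal sketch of that idea, not an official example: the function square_sum and the input batch are hypothetical names and values chosen for illustration.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


def square_sum(x):
    # Hypothetical scalar-valued function: sum of squares of a vector.
    return jnp.sum(x ** 2)


# grad differentiates square_sum, vmap applies that gradient to every row
# of a batch, and jit compiles the combined function.
per_example_grad = jit(vmap(grad(square_sum)))

batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # illustrative data
per_example_grads = per_example_grad(batch)

# The gradient of sum(x ** 2) is 2 * x, so each row is doubled.
print(per_example_grads)
```

Each transformation is applied to the result of the previous one, so the order of nesting determines what is differentiated, batched and compiled.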

grad
The grad function performs automatic differentiation: given a function that returns a scalar, it returns a new function that evaluates the gradient of the original function at a given input.
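
As an illustration of grad, the following sketch differentiates the logistic function. The choice of function, the names logistic and grad_logistic, and the evaluation point x = 1 are assumptions made for this example.

```python
import jax.numpy as jnp
from jax import grad


def logistic(x):
    # Logistic (sigmoid) function written with jax.numpy operations.
    return jnp.exp(x) / (jnp.exp(x) + 1)


# grad returns a new function that computes d(logistic)/dx.
grad_logistic = grad(logistic)

# Evaluate the derivative at x = 1.0 (grad requires a floating-point input).
gradient_at_one = grad_logistic(1.0)
print(gradient_at_one)
```

Analytically, the derivative of the logistic function is logistic(x) * (1 - logistic(x)), which is approximately 0.1966 at x = 1, so the printed value should be close to that number.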

jit
The jit function compiles a function with XLA, which can optimize it by fusing its operations. The computation time for a compiled function such as jit_cube should be noticeably shorter than that for its uncompiled counterpart cube, once the one-time compilation on the first call is excluded. Increasing the size of the input data will increase the difference.
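
A minimal sketch of such a comparison follows. The names cube and jit_cube are taken from the text; the input size, the helper time_call and the timing approach are assumptions, and the measured times depend on the hardware and backend.

```python
import time

import jax.numpy as jnp
from jax import jit


def cube(x):
    # Two elementwise multiplications; jit can fuse them into one kernel.
    return x * x * x


x = jnp.ones((1000, 1000))  # assumed input size; larger arrays widen the gap

jit_cube = jit(cube)

# The first call traces and compiles the function, so run it once
# before timing to keep compilation out of the measurement.
jit_cube(x).block_until_ready()


def time_call(fn, arg):
    # Hypothetical helper: block_until_ready waits for JAX's
    # asynchronous dispatch to finish before the clock is stopped.
    start = time.perf_counter()
    fn(arg).block_until_ready()
    return time.perf_counter() - start


eager_seconds = time_call(cube, x)
jit_seconds = time_call(jit_cube, x)
print(f"cube: {eager_seconds:.6f} s, jit_cube: {jit_seconds:.6f} s")
```

Both versions compute the same values; only the way the computation is executed differs.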

vmap
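The vmap function vectorizes a function: it takes a function written for a single input and returns a version that is applied along an axis of a whole batch of inputs at once, without an explicit Python loop.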

The GIF on the right of this section illustrates the notion of vectorized addition.
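
The following sketch applies vmap to addition. The function add_pair and the sample arrays are hypothetical, chosen only to illustrate the idea.

```python
import jax.numpy as jnp
from jax import vmap


def add_pair(a, b):
    # Written as if a and b were single numbers.
    return a + b


xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([10.0, 20.0, 30.0])

# Map over the leading axis of both arguments: element i of xs is
# added to element i of ys.
sums = vmap(add_pair)(xs, ys)
print(sums)  # [11. 22. 33.]

# in_axes=(0, None) maps over xs only and passes the same scalar
# to every call.
shifted = vmap(add_pair, in_axes=(0, None))(xs, 100.0)
print(shifted)  # [101. 102. 103.]
```

The in_axes argument selects, for each argument, which axis is mapped over, with None meaning that the argument is not mapped.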



pmap
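The pmap function parallelizes a computation across multiple devices in the SPMD (single program, multiple data) style: the same function is compiled once and executed on each device on that device's slice of the input. It can be used, for example, to perform a separate matrix multiplication on each device in parallel.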
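
A minimal sketch of parallel matrix multiplication with pmap follows. The matrix shape, the random seed and the variable names are assumptions made for this example. The number of matrices is taken from jax.local_device_count() so that the code also runs on a machine with a single device, in which case no actual parallelism takes place.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# pmap needs one slice of the input per device, so size the batch
# to the number of devices that are available locally.
n_devices = jax.local_device_count()

# One random key per device (seed 0 is an arbitrary choice).
keys = random.split(random.PRNGKey(0), n_devices)

# Generate one 200 x 300 random matrix on each device.
matrices = pmap(lambda key: random.normal(key, (200, 300)))(keys)

# On each device, multiply the local matrix by its own transpose.
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# On each device, reduce the local product to its mean.
means = pmap(jnp.mean)(outputs)
print(means)  # one value per device
```

The printed values depend on the random seed and on the number of devices, so no specific output is given here.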

Libraries using JAX
Several Python libraries use JAX as a backend, including:

 * Flax, a high-level neural network library initially developed by Google Brain.
 * Equinox, a library that extends the Flax struct module to build neural networks as PyTrees.
 * Optax, a library for gradient processing and optimisation developed by DeepMind.
 * RLax, a library for developing reinforcement learning agents, developed by DeepMind.
 * jraph, a library for graph neural networks, developed by DeepMind.
 * jaxtyping, a library for adding type annotations for the shape and data type ("dtype") of arrays or tensors.

Some R libraries use JAX as a backend as well, including:
 * fastrerandomize, a library that uses the linear-algebra-optimized compiler in JAX to speed up the selection of balanced randomizations in a design of experiments procedure known as rerandomization.