Overview
JAX is a Python library for high-performance, accelerator-oriented array computation and composable program transformations, supporting automatic differentiation, vectorization, and compilation (JIT) to GPUs/TPUs.
What is JAX?
- JAX enables fast numerical computing and large-scale machine learning using Python and NumPy code.
- It supports automatic differentiation of native Python/NumPy functions, including reverse-mode (backpropagation) and forward-mode.
- JAX uses XLA to compile and scale programs on TPUs, GPUs, and other accelerators.
- The library allows arbitrary composition of its transformations, such as differentiation, vectorization, and compilation.
- JAX is designed as an extensible system for composable function transformations.
Core Transformations in JAX
Automatic Differentiation with grad
jax.grad efficiently computes reverse-mode gradients of functions.
- Differentiation can be nested to compute higher-order derivatives.
- Differentiation works with standard Python control flow, including loops and conditionals.
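As a minimal sketch (the helper functions here are illustrative, not JAX APIs), jax.grad composes with itself for higher-order derivatives and handles ordinary Python branching:

```python
import jax
import jax.numpy as jnp

def tanh(x):
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

print(jax.grad(tanh)(1.0))            # first derivative, ~0.4199743
print(jax.grad(jax.grad(tanh))(1.0))  # nesting grad gives the second derivative

def abs_val(x):
    # Plain Python control flow works under grad.
    return x if x > 0 else -x

print(jax.grad(abs_val)(-2.0))        # -1.0
```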
Compilation with jit
jax.jit uses XLA to compile pure Python functions for efficient execution.
- JIT compilation significantly boosts performance, especially for operations amenable to fusion.
- Inside jit-compiled functions, Python control flow that depends on traced array values is restricted; such branches need structured primitives (e.g., jax.lax.cond) or static arguments.
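A minimal sketch of jit on an element-wise activation (the selu helper below is illustrative, not a JAX API):

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    # Element-wise ops like these fuse into a single kernel under XLA.
    return lam * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

x = jax.random.normal(jax.random.key(0), (1_000_000,))
selu_jit = jax.jit(selu)

selu_jit(x).block_until_ready()  # first call compiles; later calls reuse the compiled code
```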
Auto-vectorization with vmap
jax.vmap automatically maps a function across array axes for batching and parallelism.
- vmap pushes the loop down into primitive operations, e.g., turning repeated matrix-vector products into a single matrix-matrix multiplication, which usually yields large performance gains.
- Combining vmap with grad and jit enables efficient computation of Jacobians and per-example gradients, as sketched below.
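A minimal sketch of per-example gradients, assuming a toy squared-error loss (the loss function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example; purely illustrative.
    return (jnp.dot(x, w) - y) ** 2

# grad differentiates w.r.t. w; vmap maps over the batch axis of x and y
# (w is shared, so in_axes=None for it); jit compiles the whole pipeline.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 examples with 3 features each
ys = jnp.ones(4)
print(per_example_grads(w, xs, ys).shape)  # (4, 3): one gradient per example
```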
Scaling Computations
- JAX supports scaling from a single device up to thousands, using automatic parallelization, explicit sharding, and manual per-device programming.
- Users can shard data and computation and use collectives for device communication.
- JAX offers tools to inspect data sharding and manage distributed computations, as sketched below.
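A hedged sketch using jax.sharding (the 8-row array and the single 'data' mesh axis are arbitrary choices; the leading dimension should be divisible by your device count):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
sharding = NamedSharding(mesh, P('data'))  # shard the leading axis across 'data'

x = jnp.arange(32.0).reshape(8, 4)
x_sharded = jax.device_put(x, sharding)    # rows are split across devices

print(x_sharded.sharding)                  # inspect how the array is laid out
# jax.debug.visualize_array_sharding(x_sharded) renders a device map
# (requires the optional `rich` package).

# jit-compiled computations follow the sharding of their inputs.
y = jax.jit(lambda a: jnp.sin(a) * 2.0)(x_sharded)
print(y.sharding)
```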
Installation
- JAX runs on Linux, macOS, and Windows, and supports NVIDIA, AMD, Intel, and Apple GPUs as well as Google TPUs.
- For CPU:
pip install -U jax
- For NVIDIA GPU:
pip install -U "jax[cuda12]"
- For TPU:
pip install -U "jax[tpu]"
- Specialized instructions exist for AMD, Apple, and Intel GPUs.
Key Terms & Definitions
- Automatic Differentiation — Computing derivatives algorithmically for arbitrary code.
- Reverse-mode — Backpropagation; efficient gradient computation for scalar outputs.
- XLA — Accelerated Linear Algebra compiler for optimizing array operations.
- JIT (Just-In-Time compilation) — Compiling code at runtime for speed.
- vmap — Automatic vectorization across a batch dimension.
- Sharding — Splitting data or computation across devices.
Action Items / Next Steps
- Review the JAX documentation for further details and tutorials.
- Install JAX on your target hardware following the platform-specific instructions.
- Explore the Gotchas Notebook for common pitfalls.
- Try out example transformations (grad, jit, vmap) on your own Python functions.