JAX Overview and Core Features

Jul 9, 2025

Overview

JAX is a Python library for high-performance, accelerator-oriented array computation and composable program transformations, supporting automatic differentiation, vectorization, and compilation (JIT) to GPUs/TPUs.

What is JAX?

  • JAX enables fast numerical computing and large-scale machine learning using Python and NumPy code.
  • It supports automatic differentiation of native Python/NumPy functions, including reverse-mode (backpropagation) and forward-mode.
  • JAX uses XLA to compile and scale programs on TPUs, GPUs, and other accelerators.
  • These transformations compose arbitrarily, so differentiation and compilation can be freely nested (a minimal example follows this list).
  • JAX is designed as an extensible system for composable function transformations.
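
To make the composition concrete, here is a minimal sketch (using a hypothetical tanh implementation chosen for illustration, not taken from the JAX docs) that differentiates a NumPy-style function and then compiles the gradient with XLA:

    import jax
    import jax.numpy as jnp

    # A NumPy-style function written with jax.numpy (hypothetical example).
    def tanh(x):
        y = jnp.exp(-2.0 * x)
        return (1.0 - y) / (1.0 + y)

    # Transformations compose: take the gradient, then JIT-compile it with XLA.
    fast_grad_tanh = jax.jit(jax.grad(tanh))
    print(fast_grad_tanh(1.0))  # ~0.4199743, i.e. 1 - tanh(1)**2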

Core Transformations in JAX

Automatic Differentiation with grad

  • jax.grad computes reverse-mode gradients of scalar-valued functions efficiently.
  • Differentiation can be nested to compute higher-order derivatives.
  • Differentiation works with standard Python control flow, including loops and conditionals (see the sketch after this list).
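
As a small illustration (a hypothetical cubic and absolute-value function, assumed here purely for demonstration), nesting grad yields higher-order derivatives, and plain Python branching works because the input value is concrete during the gradient evaluation:

    import jax

    # Higher-order derivatives by nesting grad on a simple cubic.
    f = lambda x: x ** 3 + 2.0 * x
    print(jax.grad(f)(2.0))            # 3*x**2 + 2 -> 14.0
    print(jax.grad(jax.grad(f))(2.0))  # 6*x        -> 12.0

    # Differentiation through ordinary Python control flow.
    def abs_val(x):
        if x > 0:   # the branch is chosen using the concrete input value
            return x
        return -x

    print(jax.grad(abs_val)(-1.0))     # -1.0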

Compilation with jit

  • jax.jit uses XLA to compile pure Python functions into optimized executables.
  • JIT compilation significantly boosts performance, especially for sequences of operations that XLA can fuse into a single kernel.
  • JIT imposes constraints on Python control flow: branches and loops that depend on traced values must use static arguments or structured primitives such as jax.lax.cond (a small sketch follows this list).
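
A small sketch of the speedup pattern (the SELU activation is a common illustration; the array size here is an arbitrary choice):

    import jax
    import jax.numpy as jnp

    # Element-wise operations like these can be fused by XLA into one kernel.
    def selu(x, alpha=1.67, lmbda=1.05):
        return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))

    selu_jit = jax.jit(selu)

    x = jax.random.normal(jax.random.PRNGKey(0), (1_000_000,))
    selu_jit(x).block_until_ready()  # first call traces and compiles
    selu_jit(x).block_until_ready()  # later calls reuse the cached executable

    # Note: inside jit, Python `if`/`while` on traced values raises an error;
    # use static arguments or jax.lax.cond / jax.lax.while_loop instead.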

Auto-vectorization with vmap

  • jax.vmap automatically maps a function across an array axis, batching it without manual loops.
  • Instead of looping in Python, vmap pushes the loop into the function's primitive operations, e.g., turning repeated matrix-vector products into a single matrix-matrix multiplication.
  • Combining vmap with grad and jit enables efficient computation of Jacobians and per-example gradients (see the sketch after this list).
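
For example, per-example gradients of a (hypothetical) squared-error loss can be written by composing grad, vmap, and jit:

    import jax
    import jax.numpy as jnp

    # Hypothetical per-example loss: squared error of a linear model.
    def loss(w, x, y):
        return (jnp.dot(x, w) - y) ** 2

    # grad differentiates w.r.t. w; vmap maps over the batch axes of x and y
    # (but not over w); jit compiles the whole pipeline.
    per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

    w = jnp.ones(3)
    xs = jnp.arange(12.0).reshape(4, 3)
    ys = jnp.arange(4.0)
    print(per_example_grads(w, xs, ys).shape)  # (4, 3): one gradient per example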

Scaling Computations

  • JAX scales from a single device to thousands, via automatic compiler-driven parallelization, explicit sharding annotations, or fully manual per-device programming.
  • Users can shard data and computation across devices and use collectives for cross-device communication.
  • JAX offers tools to inspect data sharding and manage distributed computations (a small sketch follows this list).
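
A minimal sharding sketch (assuming the leading array dimension is divisible by the number of available devices; the 'data' mesh-axis name is an arbitrary choice for illustration):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Build a 1-D device mesh over whatever devices are available.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    # Shard the leading axis of an array across the mesh.
    x = jnp.arange(8 * 128.0).reshape(8, 128)
    x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
    print(x.sharding)  # inspect placement (jax.debug.visualize_array_sharding also works)

    # Computations on sharded inputs are parallelized by the compiler,
    # which inserts collectives for cross-device communication as needed.
    y = jnp.sin(x) @ x.T
    print(y.sharding)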

Installation

  • JAX runs on Linux, macOS, and Windows, and supports hardware including NVIDIA, AMD, Intel, and Apple GPUs as well as Google TPUs.
  • For CPU: pip install -U jax
  • For NVIDIA GPU: pip install -U "jax[cuda12]"
  • For TPU: pip install -U "jax[tpu]"
  • Specialized instructions exist for AMD, Apple, and Intel GPUs.

Key Terms & Definitions

  • Automatic Differentiation — Computing derivatives algorithmically for arbitrary code.
  • Reverse-mode — Backpropagation; computes gradients efficiently for functions with many inputs and a scalar output.
  • XLA — Accelerated Linear Algebra compiler for optimizing array operations.
  • JIT (Just-In-Time compilation) — Compiling code at runtime for speed.
  • vmap — Automatic vectorization across a batch dimension.
  • Sharding — Splitting data or computation across devices.

Action Items / Next Steps

  • Review the JAX documentation for further details and tutorials.
  • Install JAX on your target hardware following the platform-specific instructions.
  • Explore the Gotchas Notebook for common pitfalls.
  • Try out example transformations (grad, jit, vmap) on your own Python functions.