Using JAX to accelerate our research

DeepMind engineers accelerate our research by building tools, scaling up algorithms, and creating challenging virtual and physical worlds for training and testing artificial intelligence (AI) systems. As part of this work, we constantly evaluate new machine learning libraries and frameworks.

Recently, we've found that an increasing number of projects are well served by JAX, a machine learning framework developed by Google Research teams. JAX resonates well with our engineering philosophy and has been widely adopted by our research community over the last year. Here we share our experience of working with JAX, outline why we find it useful for our AI research, and give an overview of the ecosystem we are building to support researchers everywhere.

Why JAX?

JAX is a Python library designed for high-performance numerical computing, especially machine learning research. Its API for numerical functions is based on NumPy, a collection of functions used in scientific computing. Both Python and NumPy are widely used and familiar, making JAX simple, flexible, and easy to adopt.

In addition to its NumPy API, JAX includes an extensible system of composable function transformations that help support machine learning research, including:

  • Differentiation: Gradient-based optimisation is fundamental to ML. JAX natively supports both forward and reverse mode automatic differentiation of arbitrary numerical functions, via function transformations such as grad, hessian, jacfwd and jacrev.
  • Vectorisation: In ML research we often apply a single function to lots of data, e.g. calculating the loss across a batch or evaluating per-example gradients for differentially private learning. JAX provides automatic vectorisation via the vmap transformation that simplifies this form of programming. For example, researchers don't need to reason about batching when implementing new algorithms. JAX also supports large scale data parallelism via the related pmap transformation, elegantly distributing data that is too large for the memory of a single accelerator.
  • JIT-compilation: XLA is used to just-in-time (JIT)-compile and execute JAX programs on GPU and Cloud TPU accelerators. JIT-compilation, together with JAX's NumPy-consistent API, allows researchers with no previous experience in high-performance computing to easily scale to one or many accelerators.
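The three transformations above compose with one another. The following is a minimal sketch, assuming a toy linear model with a squared-error loss (the function `loss` and the example data are illustrative, not taken from any DeepMind project): it computes per-example gradients by combining grad, vmap and jit, and a Hessian via hessian.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a toy linear model on a single example (illustrative)."""
    return (jnp.dot(w, x) - y) ** 2


# Differentiation: gradient of the loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Vectorisation: map over the leading axis of x and y while sharing w,
# which yields one gradient per example without hand-written batching.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT-compilation: compile the composed function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.ones((4, 3))   # a batch of 4 examples with 3 features each
ys = jnp.zeros((4,))

grads = fast_per_example_grads(w, xs, ys)   # shape (4, 3)

# Second derivatives of the same function for a single example, shape (3, 3).
hess = jax.hessian(loss)(w, xs[0], ys[0])
```

Note that `loss` is written for a single example only; the batch dimension is introduced entirely by vmap.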

We have found that JAX has enabled rapid experimentation with novel algorithms and architectures, and it now underpins many of our recent publications. To learn more, please consider joining our JAX Roundtable, Wednesday December 9th 7:00pm GMT, at the NeurIPS virtual conference.

JAX at DeepMind

Supporting state-of-the-art AI research means balancing rapid prototyping and quick iteration with the ability to deploy experiments at a scale traditionally associated with production systems. What makes these kinds of projects particularly challenging is that the research landscape evolves rapidly and is difficult to forecast. At any point, a new research breakthrough may, and regularly does, change the trajectory and requirements of entire teams. Within this ever-changing landscape, a core responsibility of our engineering team is to make sure that the lessons learned and the code written for one research project are reused effectively in the next.

One approach that has proven successful is modularisation: we extract the most important and critical building blocks developed in each research project into well-tested and efficient components. This empowers researchers to focus on their research while also benefiting from code reuse, bug fixes and performance improvements in the algorithmic ingredients implemented by our core libraries. We’ve also found that it’s important to give each library a clearly defined scope and to ensure that the libraries are interoperable but independent. Incremental buy-in, the ability to pick and choose features without being locked into others, is critical to providing maximum flexibility for researchers and always supporting them in choosing the right tool for the job.

Other considerations that have gone into the development of our JAX Ecosystem include making sure that it remains consistent (where possible) with the design of our existing TensorFlow libraries (e.g. Sonnet and TRFL). We’ve also aimed to build components that (where relevant) match their underlying mathematics as closely as possible, to be self-descriptive and minimise mental hops "from paper to code". Finally, we’ve chosen to open source our libraries to facilitate sharing of research outputs and to encourage the broader community to explore the JAX Ecosystem.

Our Ecosystem today

Haiku 

The JAX programming model of composable function transformations can make dealing with stateful objects complicated, e.g. neural networks with trainable parameters. Haiku is a neural network library that allows users to use familiar object-oriented programming models while harnessing the power and simplicity of JAX's pure functional paradigm.

Haiku is actively used by hundreds of researchers across DeepMind and Google, and has already found adoption in several external projects (e.g. Coax, DeepChem, NumPyro). It builds on the API for Sonnet, our module-based programming model for neural networks in TensorFlow, and we’ve aimed to make porting from Sonnet to Haiku as simple as possible.
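To make the difficulty with stateful objects concrete, here is a minimal pure-JAX sketch of the functional pattern that JAX transformations expect: parameters are created by one function and passed explicitly into another. This is not Haiku's API; the names `init` and `apply` and the dictionary layout are assumptions made for illustration. A neural network library in this style lets users write modules as objects while still exposing pure functions of this kind to JAX.

```python
import jax
import jax.numpy as jnp


def init(rng, in_dim, out_dim):
    """Creates the parameters of a linear layer as a plain dict (hypothetical layout)."""
    return {
        "w": jax.random.normal(rng, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros((out_dim,)),
    }


def apply(params, x):
    """Pure function: the output depends only on the arguments params and x."""
    return x @ params["w"] + params["b"]


def loss(params, x, y):
    """Mean squared error between the layer output and the targets."""
    return jnp.mean((apply(params, x) - y) ** 2)


params = init(jax.random.PRNGKey(0), in_dim=3, out_dim=2)
x = jnp.ones((5, 3))
y = jnp.zeros((5, 2))

# Because the trainable state is an explicit argument, grad and jit compose
# without any hidden mutation; grads has the same structure as params.
grads = jax.jit(jax.grad(loss))(params, x, y)
```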

Find out more on GitHub


Optax 

Gradient-based optimisation is fundamental to ML. Optax provides a library of gradient transformations, together with composition operators (e.g. chain) that allow implementing many standard optimisers (e.g. RMSProp or Adam) in just a single line of code.

The compositional nature of Optax naturally supports recombining the same basic ingredients in custom optimisers. It additionally offers a number of utilities for stochastic gradient estimation and second order optimisation.
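The idea of composing gradient transformations can be sketched in a few lines of plain JAX. The code below is a simplified reimplementation for illustration only, not Optax's source: the names `GradientTransformation`, `scale`, `clip` and `chain` mirror the concepts described above, but the signatures here are assumptions and are simpler than those of the real library.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
    """A pair of pure functions (illustrative, simplified signatures)."""
    init: Callable     # params -> state
    update: Callable   # (updates, state) -> (new_updates, new_state)


def scale(factor):
    """Multiplies every update by a constant; stateless."""
    def init(params):
        return ()

    def update(updates, state):
        return jax.tree_util.tree_map(lambda g: factor * g, updates), state

    return GradientTransformation(init, update)


def clip(max_abs):
    """Clips every update element to [-max_abs, max_abs]; stateless."""
    def init(params):
        return ()

    def update(updates, state):
        clipped = jax.tree_util.tree_map(
            lambda g: jnp.clip(g, -max_abs, max_abs), updates)
        return clipped, state

    return GradientTransformation(init, update)


def chain(*transforms):
    """Applies the given transformations in order, threading each one's state."""
    def init(params):
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_state = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)
            new_state.append(s)
        return updates, tuple(new_state)

    return GradientTransformation(init, update)


# A clipped gradient-descent optimiser with step size 0.1, built in one line.
optimiser = chain(clip(1.0), scale(-0.1))

params = {"w": jnp.array([1.0, 2.0])}
grads = {"w": jnp.array([5.0, -0.5])}

state = optimiser.init(params)
updates, state = optimiser.update(grads, state)
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```

Swapping `clip` for another ingredient, or adding a stateful transformation such as a running average, changes the optimiser without touching the training loop.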

Many Optax users have adopted Haiku, but in line with our incremental buy-in philosophy, any library that represents parameters as JAX tree structures is supported (e.g. Elegy, Flax and Stax). Please see here for more information on this rich ecosystem of JAX libraries.

Find out more on GitHub


RLax

Many of our most successful projects are at the intersection of deep learning and reinforcement learning (RL), also known as deep reinforcement learning. RLax is a library that provides useful building blocks for constructing RL agents.

The components in RLax cover a broad spectrum of algorithms and ideas: TD-learning, policy gradients, actor critics, MAP, proximal policy optimisation, non-linear value transformation, general value functions, and a number of exploration methods.
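As an illustration of what such a building block can look like, here is a minimal pure-JAX sketch of a one-step TD error. It is not RLax's implementation; the function name `td_error` and its argument names are assumptions chosen for this example. The function is written for a single transition and batched with vmap.

```python
import jax
import jax.numpy as jnp


def td_error(v_tm1, r_t, discount_t, v_t):
    """One-step TD error for a single transition (illustrative sketch).

    v_tm1: value estimate at the previous step.
    r_t: reward received on the transition.
    discount_t: discount applied to the next value (0 at episode end).
    v_t: value estimate at the current step, treated as a fixed target.
    """
    target = r_t + discount_t * jax.lax.stop_gradient(v_t)
    return target - v_tm1


# The block is defined per transition; vmap adds the batch dimension.
batched_td_error = jax.vmap(td_error)

v_tm1 = jnp.array([1.0, 0.5])
r_t = jnp.array([0.0, 1.0])
discount_t = jnp.array([0.9, 0.0])   # second transition ends the episode
v_t = jnp.array([2.0, 3.0])

errors = batched_td_error(v_tm1, r_t, discount_t, v_t)
```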

Although some introductory example agents are provided, RLax is not intended as a framework for building and deploying full RL agent systems. One example of a fully-featured agent framework that builds upon RLax components is Acme.

Find out more on GitHub


Chex

Testing is critical to software reliability and research code is no exception. Drawing scientific conclusions from research experiments requires being confident in the correctness of your code. Chex is a collection of testing utilities used by library authors to verify the common building blocks are correct and robust and by end-users to check their experimental code.

Chex provides an assortment of utilities including JAX-aware unit testing, assertions of properties of JAX datatypes, mocks and fakes, and multi-device test environments. Chex is used throughout DeepMind’s JAX Ecosystem and by external projects such as Coax and MineRL.
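One kind of check that JAX-aware testing makes routine is confirming that a function behaves the same with and without JIT-compilation. The sketch below shows the idea using only JAX and NumPy; it does not use Chex, and the helper `check_with_and_without_jit` is a hypothetical name introduced for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Cross-entropy between one-hot labels and softmax(logits), per example."""
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.sum(labels * log_probs, axis=-1)


def check_with_and_without_jit(fn, *args):
    """Runs fn op-by-op and under jit, and asserts that the results agree."""
    uncompiled = fn(*args)
    compiled = jax.jit(fn)(*args)
    assert uncompiled.shape == compiled.shape
    np.testing.assert_allclose(uncompiled, compiled, rtol=1e-5, atol=1e-6)
    return uncompiled


logits = jnp.array([[0.0, 0.0], [2.0, 0.0]])
labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])

losses = check_with_and_without_jit(softmax_cross_entropy, logits, labels)
```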

Find out more on GitHub


Jraph

Graph neural networks (GNNs) are an exciting area of research with many promising applications. See, for instance, our recent work on traffic prediction in Google Maps and our work on physics simulation. Jraph (pronounced "giraffe") is a lightweight library to support working with GNNs in JAX.

Jraph provides a standardised data structure for graphs, a set of utilities for working with graphs, and a 'zoo' of easily forkable and extensible graph neural network models. Other key features include: batching of GraphTuples that efficiently leverages hardware accelerators, JIT-compilation support of variable-shaped graphs via padding and masking, and losses defined over input partitions. Like Optax and our other libraries, Jraph places no constraints on the user's choice of a neural network library.
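The padding-and-masking idea can be sketched in plain JAX. JIT-compiled functions are specialised to input shapes, so graphs of different sizes are padded to a common size and the padded entries are masked out. The example below is an illustration only and does not use Jraph; the function `sum_incoming_messages` and the array layout are assumptions made for this sketch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_incoming_messages(node_feats, senders, receivers, edge_mask):
    """Sums sender features into each receiver node, ignoring padded edges."""
    # Zero out the messages carried by padding edges.
    messages = node_feats[senders] * edge_mask[:, None]
    # Aggregate messages per receiver; the number of nodes is a static shape.
    return jax.ops.segment_sum(
        messages, receivers, num_segments=node_feats.shape[0])


# A graph with 3 real nodes and 2 real edges, padded to 4 nodes and 3 edges.
# The padding edge connects the padding node (index 3) to itself.
node_feats = jnp.array([[1.0], [2.0], [3.0], [0.0]])
senders = jnp.array([0, 1, 3])
receivers = jnp.array([1, 2, 3])
edge_mask = jnp.array([1.0, 1.0, 0.0])   # 1 for real edges, 0 for padding

aggregated = sum_incoming_messages(node_feats, senders, receivers, edge_mask)
```

Any graph that fits within the padded sizes reuses the same compiled function, which avoids recompilation for every new graph shape.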

Learn more about using the library from our rich collection of examples.

Find out more on GitHub

Our JAX Ecosystem is constantly evolving and we encourage the ML research community to explore our libraries and the potential of JAX to accelerate their own research.


Citing the DeepMind JAX Ecosystem

If you find the DeepMind JAX Ecosystem useful for your work, please use this citation (hosted on GitHub).