Testing made fun, in JAX!
Chex is a library of utilities for helping to write reliable JAX code. This includes utilities to help:
- Instrument your code (e.g. assertions)
- Debug (e.g. transforming pmaps into vmaps within a context manager)
- Test JAX code across many variants (e.g. jitted vs non-jitted)