# Track Awesome Jax Updates Daily

JAX - A curated list of resources https://github.com/google/jax

😺 n2cholas/awesome-jax · ⭐ 1K · 🏷️ Computer Science

## Sep 25, 2023

Libraries

- BlackJAX (⭐561) - Library of samplers for JAX.

## Sep 24, 2023

Libraries / New Libraries

- QDax (⭐222) - Quality Diversity optimization in Jax.

## Sep 12, 2023

Libraries / New Libraries

- MaxText (⭐362) - A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs.

- Pax (⭐267) - A Jax-based machine learning framework for training large scale models.

- Praxis (⭐112) - The layer library for Pax with a goal to be usable by other JAX-based ML projects.

## Sep 11, 2023

Libraries / New Libraries

- purejaxrl (⭐382) - Vectorisable, end-to-end RL algorithms in JAX.

- Lorax (⭐95) - Automatically apply LoRA to JAX models (Flax, Haiku, etc.)

- SCICO (⭐64) - Scientific computational imaging in JAX.

- BrainPy (⭐356) - Brain Dynamics Programming in Python.

- OTT-JAX (⭐369) - Optimal transport tools in JAX.

Tutorials and Blog Posts

- Achieving 4000x Speedups with PureJaxRL - A blog post on how JAX can massively speed up RL training through vectorisation.

## Mar 20, 2023

Libraries / New Libraries

- Kernex (⭐52) - Differentiable stencil decorators in JAX.

## Mar 06, 2023

Libraries / New Libraries

- safejax (⭐22) - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 `safetensors`.

## Jan 23, 2023

Libraries / New Libraries

- jax-tqdm (⭐36) - Add a tqdm progress bar to JAX scans and loops.

## Jan 15, 2023

Models and Projects / Flax

- GNNs for Solving Combinatorial Optimization Problems (⭐20) - A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.

## Jan 02, 2023

Libraries / New Libraries

- SPU (⭐156) - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).

## Dec 27, 2022

Libraries

- Fortuna (⭐717) - AWS library for Uncertainty Quantification in Deep Learning.

## Nov 16, 2022

Libraries / New Libraries

- econpizza (⭐36) - Solve macroeconomic models with heterogeneous agents using JAX.

## Oct 17, 2022

Models and Projects / NumPyro

- lqg (⭐23) - Official implementation of Bayesian inverse optimal control for linear-quadratic Gaussian problems from the paper
*Putting perception into action with inverse optimal control for continuous psychophysics*

## Sep 05, 2022

Libraries / New Libraries

- JAXFit (⭐27) - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).

## Sep 03, 2022

Libraries / New Libraries

- Eqxvision (⭐77) - Equinox version of Torchvision.

## Sep 02, 2022

Libraries / New Libraries

- GPJax (⭐309) - Gaussian processes in JAX.

- Jumanji (⭐383) - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.

## Aug 31, 2022

Models and Projects / Flax

- Sharpened Cosine Similarity in JAX by Raphael Pisoni - A JAX/Flax implementation of the Sharpened Cosine Similarity layer.

Tutorials and Blog Posts

- Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit - A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.

- Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar - A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.

- Deep Learning tutorials with JAX+Flax by Phillip Lippe - A series of notebooks explaining various deep learning concepts, from basics (e.g., intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.

## Aug 26, 2022

Books

- Jax in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.

## Aug 09, 2022

Libraries / New Libraries

- jwave (⭐80) - A library for differentiable acoustic simulations

## Jun 28, 2022

Libraries / New Libraries

- Mctx (⭐1.9k) - Monte Carlo tree search algorithms in native JAX.

- KFAC-JAX (⭐154) - Second Order Optimization with Approximate Curvature for NNs.

- TF2JAX (⭐82) - Convert functions/graphs to JAX functions.

## Jun 25, 2022

Libraries / New Libraries

- gymnax (⭐395) - Reinforcement Learning Environments with the well-known gym API.

## Mar 15, 2022

Libraries / New Libraries

- tinygp (⭐244) - The *tiniest* of Gaussian process libraries in JAX.

## Mar 14, 2022

Libraries / New Libraries

- Einshape (⭐78) - DSL-based reshaping library for JAX and other frameworks.

- ALX (⭐31k) - Open-source library for distributed matrix factorization using Alternating Least Squares, more info in
*ALX: Large Scale Matrix Factorization on TPUs*.

- Diffrax (⭐982) - Numerical differential equation solvers in JAX.

Models and Projects / JAX

- Symbolic Functionals (⭐31k) - Demonstration from
*Evolving symbolic density functionals*.

- TriMap (⭐31k) - Official JAX implementation of
*TriMap: Large-scale Dimensionality Reduction Using Triplets*.

Models and Projects / Flax

- mip-NeRF (⭐795) - Official implementation of
*Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields*.

- RegNeRF (⭐31k) - Official implementation of
*RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs*.

- ARDM (⭐31k) - Official implementation of
*Autoregressive Diffusion Models*.

- D3PM (⭐31k) - Official implementation of
*Structured Denoising Diffusion Models in Discrete State-Spaces*.

- Gumbel-max Causal Mechanisms (⭐31k) - Code for
*Learning Generalized Gumbel-max Causal Mechanisms*, with extra code in GuyLor/gumbel_max_causal_gadgets_part2 (⭐2).

- Latent Programmer (⭐31k) - Code for the ICML 2021 paper
*Latent Programmer: Discrete Latent Codes for Program Synthesis*.

- SNeRG (⭐31k) - Official implementation of
*Baking Neural Radiance Fields for Real-Time View Synthesis*.

- Spin-weighted Spherical CNNs (⭐31k) - Adaptation of
*Spin-Weighted Spherical CNNs*.

- VDVAE (⭐31k) - Adaptation of
*Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images*, original code at openai/vdvae (⭐406).

- MUSIQ (⭐31k) - Checkpoints and model inference code for the ICCV 2021 paper
*MUSIQ: Multi-scale Image Quality Transformer*

- AQuaDem (⭐31k) - Official implementation of
*Continuous Control with Action Quantization from Demonstrations*.

- Combiner (⭐31k) - Official implementation of
*Combiner: Full Attention Transformer with Sparse Computation Cost*.

- Dreamfields (⭐31k) - Official implementation of the ICLR 2022 paper
*Progressive Distillation for Fast Sampling of Diffusion Models*.

- GIFT (⭐31k) - Official implementation of
*Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent*.

- Light Field Neural Rendering (⭐31k) - Official implementation of
*Light Field Neural Rendering*.

## Mar 04, 2022

Libraries / New Libraries

- Neural Network Libraries
    - FedJAX (⭐238) - Federated learning in JAX, built on Optax and Haiku.
    - Equivariant MLP (⭐227) - Construct equivariant neural network layers.
    - jax-resnet (⭐95) - Implementations and checkpoints for ResNet variants in Flax.
    - Parallax (⭐157) - Immutable Torch Modules for JAX.

## Feb 24, 2022

Libraries / New Libraries

- SymJAX (⭐109) - Symbolic CPU/GPU/TPU programming.

- mcx (⭐323) - Express & compile probabilistic programs for performant inference.

## Feb 22, 2022

Libraries / New Libraries

- evosax (⭐349) - JAX-Based Evolution Strategies

## Feb 18, 2022

Libraries / New Libraries

- EvoJAX (⭐692) - Hardware-Accelerated Neuroevolution

## Feb 12, 2022

Tutorials and Blog Posts

- Get started with JAX by Aleksa Gordić (⭐453) - A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.

## Feb 11, 2022

Libraries / New Libraries

- PGMax (⭐64) - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.

## Jan 24, 2022

Libraries / New Libraries

- tree-math (⭐151) - Convert functions that operate on arrays into functions that operate on PyTrees.

- jax-models (⭐127) - Implementations of research papers originally without code or code written with frameworks other than JAX.

## Jan 03, 2022

Tutorials and Blog Posts

- Introduction to JAX by Kevin Murphy - Colab that introduces various aspects of the language and applies them to simple ML problems.

## Dec 31, 2021

Models and Projects / Haiku

- Two Player Auction Learning (⭐5) - JAX implementation of the paper
*Auction learning as a two-player game*.

## Nov 13, 2021

Libraries / New Libraries

- JaxDF (⭐98) - Framework for differentiable simulators with arbitrary discretizations.

## Sep 12, 2021

Libraries / New Libraries

- bayex (⭐50) - Bayesian Optimization powered by JAX.

## Aug 10, 2021

Libraries

- Neural Network Libraries
    - Flax (⭐4.7k) - Centered on flexibility and clarity.
    - Haiku (⭐2.6k) - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    - Objax (⭐745) - Has an object oriented design similar to PyTorch.
    - Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    - Trax (⭐7.7k) - "Batteries included" deep learning library focused on providing solutions for common workloads.
    - Jraph (⭐1.2k) - Lightweight graph neural network library.
    - Neural Tangents (⭐2.1k) - High-level API for specifying neural networks of both finite and *infinite* width.
    - HuggingFace (⭐112k) - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    - Equinox (⭐1.4k) - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
    - Scenic (⭐2.5k) - A Jax Library for Computer Vision Research and Beyond.

## Jul 30, 2021

Models and Projects / JAX

- Amortized Bayesian Optimization (⭐31k) - Code related to
*Amortized Bayesian Optimization over Discrete Spaces*.

- Accurate Quantized Training (⭐31k) - Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.

- BNN-HMC (⭐31k) - Implementation for the paper
*What Are Bayesian Neural Network Posteriors Really Like?*.

- JAX-DFT (⭐31k) - One-dimensional density functional theory (DFT) in JAX, with implementation of
*Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics*.

- Robust Loss (⭐31k) - Reference code for the paper
*A General and Adaptive Robust Loss Function*.

Models and Projects / Flax

- Performer (⭐31k) - Flax implementation of the Performer (linear transformer via FAVOR+) architecture.

- JaxNeRF (⭐31k) - Implementation of
*NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis* with multi-device GPU/TPU support.

- Big Transfer (BiT) (⭐1.5k) - Implementation of
*Big Transfer (BiT): General Visual Representation Learning*.

- JAX RL (⭐505) - Implementations of reinforcement learning algorithms.

- gMLP - Implementation of
*Pay Attention to MLPs*.

- MLP Mixer - Minimal implementation of
*MLP-Mixer: An all-MLP Architecture for Vision*.

- Distributed Shampoo (⭐31k) - Implementation of
*Second Order Optimization Made Practical*.

- NesT (⭐181) - Official implementation of
*Aggregating Nested Transformers*.

- XMC-GAN (⭐99) - Official implementation of
*Cross-Modal Contrastive Learning for Text-to-Image Generation*.

- FNet (⭐31k) - Official implementation of
*FNet: Mixing Tokens with Fourier Transforms*.

- GFSA (⭐31k) - Official implementation of
*Learning Graph Structure With A Finite-State Automaton Layer*.

- IPA-GNN (⭐31k) - Official implementation of
*Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks*.

- Flax Models (⭐31k) - Collection of models and methods implemented in Flax.

- Protein LM (⭐31k) - Implements BERT and autoregressive models for proteins, as described in
*Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences* and *ProGen: Language Modeling for Protein Generation*.

- Slot Attention (⭐31k) - Reference implementation for
*Differentiable Patch Selection for Image Recognition*.

- Vision Transformer (⭐7.9k) - Official implementation of
*An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*.

- FID computation (⭐18) - Port of mseitzer/pytorch-fid (⭐2.6k) to Flax.

Models and Projects / Haiku

- AlphaFold (⭐11k) - Implementation of the inference pipeline of AlphaFold v2.0, presented in
*Highly accurate protein structure prediction with AlphaFold*.

- Adversarial Robustness (⭐12k) - Reference code for
*Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples* and *Fixing Data Augmentation to Improve Adversarial Robustness*.

- Bootstrap Your Own Latent (⭐12k) - Implementation for the paper
*Bootstrap your own latent: A new approach to self-supervised Learning*.

- Gated Linear Networks (⭐12k) - GLNs are a family of backpropagation-free neural networks.

- Glassy Dynamics (⭐12k) - Open source implementation of the paper
*Unveiling the predictive power of static structure in glassy systems*.

- MMV (⭐12k) - Code for the models in
*Self-Supervised MultiModal Versatile Networks*.

- Normalizer-Free Networks (⭐12k) - Official Haiku implementation of
*NFNets*.

- NuX (⭐81) - Normalizing flows with JAX.

- OGB-LSC (⭐12k) - This repository contains DeepMind's entry to the PCQM4M-LSC (quantum chemistry) and MAG240M-LSC (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).

- Persistent Evolution Strategies (⭐31k) - Code used for the paper
*Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies*.

- WikiGraphs (⭐12k) - Baseline code to reproduce results in
*WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset*.

Models and Projects / Trax

- Reformer (⭐7.7k) - Implementation of the Reformer (efficient transformer) architecture.

Videos

- JAX, Flax & Transformers 🤗 (⭐112k) - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.

## Jul 19, 2021

Libraries / New Libraries

- JAXopt (⭐775) - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

- PIX (⭐326) - PIX is an image processing library in JAX, for JAX.

## Jul 14, 2021

Libraries

- NetKet (⭐449) - Machine Learning toolbox for Quantum Physics.

## Jun 13, 2021

Libraries / New Libraries

- exojax (⭐35) - Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX.

## Jun 12, 2021

Libraries / New Libraries

- CR.Sparse (⭐62) - XLA accelerated algorithms for sparse representations and compressive sensing.

## Jun 09, 2021

Tutorials and Blog Posts

- Writing an MCMC sampler in JAX by Jeremie Coullon - Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.

- How to add a progress bar to JAX scans and loops by Jeremie Coullon - Tutorial on how to add a progress bar to compiled loops in JAX using the `host_callback` module.

## Jun 07, 2021

Libraries / New Libraries

- flaxmodels (⭐208) - Pretrained models for Jax/Flax.

## Jun 06, 2021

Libraries / New Libraries

- jaxlie (⭐181) - Lie theory library for rigid body transformations and optimization.

- BRAX (⭐1.8k) - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.

## May 30, 2021

Libraries / New Libraries

- delta PV (⭐50) - A photovoltaic simulator with automatic differentiation.

## May 21, 2021

Libraries

- Coax (⭐150) - Turn RL papers into code, the easy way.

## Apr 26, 2021

Libraries

- Distrax (⭐451) - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.

- cvxpylayers (⭐1.6k) - Construct differentiable convex optimization layers.

- TensorLy (⭐1.4k) - Tensor learning made simple.

## Mar 21, 2021

Libraries / New Libraries

- Optimal Transport Tools (⭐215) - Toolbox that bundles utilities to solve optimal transport problems.

## Mar 07, 2021

Libraries

- RLax (⭐1.1k) - Library for implementing reinforcement learning agents.

## Mar 03, 2021

Libraries / New Libraries

- FlaxVision (⭐41) - Flax version of TorchVision.

- Oryx (⭐4k) - Probabilistic programming language based on program transformations.

Videos

- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with Jax fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice.

Tutorials and Blog Posts

- Deterministic ADVI in JAX by Martin Ingram - Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.

- Evolved channel selection by Mat Kelcey - Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.

## Feb 22, 2021

Libraries / New Libraries

- imax (⭐31) - Image augmentations and transformations.

## Feb 08, 2021

Tutorials and Blog Posts

- Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz - Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.

## Feb 07, 2021

Tutorials and Blog Posts

- Evolving Neural Networks in JAX by Robert Tjarko Lange - Explores how JAX can power the next generation of scalable neuroevolution algorithms.

## Feb 03, 2021

Tutorials and Blog Posts

- Out of distribution (OOD) detection by Mat Kelcey - Implements different methods for OOD detection.

- Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey (⭐280) - Tutorial demonstrating the infrastructure required to provide custom ops in JAX.

## Jan 29, 2021

Libraries

- NumPyro (⭐1.9k) - Probabilistic programming based on the Pyro library.

- Chex (⭐618) - Utilities to write and test reliable JAX code.

- Optax (⭐1.2k) - Gradient processing and optimization library.

- JAX, M.D. (⭐984) - Accelerated, differentiable molecular dynamics.

Libraries / New Libraries

- jax-unirep (⭐95) - Library implementing the UniRep model for protein machine learning applications.

- jax-cosmo (⭐141) - Differentiable cosmology library.

Videos

- Introduction to JAX - Simple neural network from scratch in JAX.

Papers

**Compiling machine learning programs via high-level tracing**. Roy Frostig, Matthew James Johnson, Chris Leary. *MLSys 2018*. - White paper describing an early version of JAX, detailing how computation is traced and compiled.

Tutorials and Blog Posts

- Plugging Into JAX by Nick Doiron - Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.

## Jan 23, 2021

Models and Projects / JAX

- Fourier Feature Networks (⭐1.1k) - Official implementation of
*Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains*.

Videos

- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.

- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX's core design, how it's powering new research, and how you can start using it.

- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.

- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.

- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in
*Deep Implicit Layers*.

Papers

**JAX, M.D.: A Framework for Differentiable Physics**. Samuel S. Schoenholz, Ekin D. Cubuk. *NeurIPS 2020*. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.

**Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization**. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. *arXiv 2020*. - Uses JAX's JIT and VMAP to achieve faster differentially private training than existing libraries.

Tutorials and Blog Posts

- Using JAX to accelerate our research by David Budden and Matteo Hessel - Describes the state of JAX and the JAX ecosystem at DeepMind.

- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - Neural network building blocks from scratch with the basic JAX operators.

- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 (⭐21) - Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.

- Meta-Learning in 50 Lines of JAX by Eric Jang - Introduction to both JAX and Meta-Learning.

- Normalizing Flows in 100 Lines of JAX by Eric Jang - Concise implementation of RealNVP.

- Differentiable Path Tracing on the GPU/TPU by Eric Jang - Tutorial on implementing path tracing.

- Ensemble networks by Mat Kelcey - Ensemble nets are a method of representing an ensemble of models as one single logical model.

- Understanding Autodiff with JAX by Srihari Radhakrishna - Understand how autodiff works using JAX.

- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke - Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.

## Jan 05, 2021

Libraries / New Libraries

- mpi4jax (⭐301) - Combine MPI operations with your Jax code on CPUs and GPUs.

## Dec 28, 2020

Libraries / New Libraries

- jax-flows (⭐246) - Normalizing flows in JAX.

- sklearn-jax-kernels (⭐38) - `scikit-learn` kernel matrices using JAX.

- efax (⭐43) - Exponential Families in JAX.

Models and Projects / JAX

- kalman-jax (⭐86) - Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.

- jaxns (⭐102) - Nested sampling in JAX.

## Dec 21, 2020

Videos

- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in
*Program Transformations for Machine Learning* workshop.

## Dec 20, 2020

Community