Open-Awesome

© 2026 Open-Awesome. Curated for the developer elite.


JAX

Framework
113 projects · 537.7k total stars · 87.9k total forks · 3 languages

Open-source projects built with JAX

There are currently 113 open-source projects built with JAX, with a combined total of 537.7k GitHub stars. The most common language among these projects is Python.

Showing 116 open-source projects · page 2 of 4

Community-curated · Updated weekly · 100% open source

Jraph
google-deepmind/jraph

A lightweight library for building and training graph neural networks using JAX.

1.5k stars · 105 forks · Python · 2 years ago

RLax
deepmind/rlax

A JAX-based library providing reinforcement learning building blocks for implementing agents, supporting both on-policy and off-policy learning.

1.4k stars · 101 forks · Python · 2 months ago

JAX, M.D.
google/jax-md

A differentiable, hardware-accelerated molecular dynamics simulation framework built on JAX for computational physics and materials science.

1.4k stars · 241 forks · Jupyter Notebook · 6 days ago

BlackJAX
blackjax-devs/blackjax

A fast, modular Bayesian inference library for JAX, providing composable samplers for CPU and GPU.

1.1k stars · 143 forks · Python · 8 days ago

purejaxrl
luchris429/purejaxrl

High-performance, end-to-end reinforcement learning implementations written entirely in JAX for massive parallelization on GPUs.

1.1k stars · 87 forks · Python · 1 year ago

SKRL
Toni-SM/skrl

A modular Python library for reinforcement learning with support for PyTorch, JAX, NVIDIA Warp, and multiple environment interfaces.

1.1k stars · 144 forks · Python · 28 days ago

KerasCV
keras-team/keras-cv

A library of modular computer vision components built on Keras 3, supporting TensorFlow, JAX, and PyTorch backends.

1.1k stars · 324 forks · Python · 3 months ago

JAXopt
google/jaxopt

Hardware-accelerated, batchable, and differentiable optimization algorithms implemented in JAX for machine learning research.

1.0k stars · 72 forks · Python · 4 days ago

KerasNLP
keras-team/keras-nlp

A pretrained modeling library for Keras 3 offering simple, flexible, and fast access to models for text, image, and audio tasks.

984 stars · 341 forks · Python · 5 days ago

Dynamax
probml/dynamax

A Python library for probabilistic state space modeling and inference, built on JAX.

970 stars · 113 forks · Python · 5 months ago

EvoJAX
google/evojax

A scalable, hardware-accelerated neuroevolution toolkit built on JAX for parallel training across TPUs and GPUs.

948 stars · 110 forks · Jupyter Notebook · 1 year ago

Chex
deepmind/chex

A library of utilities for writing and testing reliable JAX code, including assertions, debugging tools, and test variants.

944 stars · 69 forks · Python · 1 month ago

mip-NeRF
google/mipnerf

Official JAX implementation of Mip-NeRF, a multiscale neural radiance field model for anti-aliased novel view synthesis.

939 stars · 112 forks · Python · 3 years ago

gymnax
RobertTLange/gymnax

A JAX-based library providing accelerated reinforcement learning environments that are fully compatible with the classic Gym API.

900 stars · 97 forks · Python · 2 months ago

Nucleotide Transformer
instadeepai/nucleotide-transformer

A collection of transformer-based foundation models for genomics and transcriptomics, enabling tasks like sequence analysis, functional prediction, and conversational DNA exploration.

880 stars · 95 forks · Jupyter Notebook · 3 months ago

Jumanji
instadeepai/jumanji

A diverse suite of scalable reinforcement learning environments written in JAX for hardware-accelerated research.

840 stars · 97 forks · Python · 5 days ago

Objax
google/objax

An object-oriented machine learning framework built on JAX, designed for simplicity and readability in research.

773 stars · 72 forks · Python · 2 years ago

evosax
RobertTLange/evosax

A comprehensive, high-performance library implementing 30+ Evolution Strategies in JAX for scalable optimization on modern hardware.

762 stars · 62 forks · Python · 2 months ago

JAX RL
ikostrikov/jax-rl

JAX (Flax) implementations of reinforcement learning algorithms for continuous action spaces, designed for research.

757 stars · 75 forks · Jupyter Notebook · 3 years ago

OTT-JAX
ott-jax/ott

A JAX-powered library for solving large-scale optimal transport problems, including matching, barycenters, and neural approximations.

741 stars · 135 forks · Python · 10 days ago

Levanter
stanford-crfm/levanter

A JAX-based framework for training large language models with a focus on legibility, scalability, and reproducibility.

709 stars · 120 forks · Python · 4 months ago

BrainPy
brainpy/BrainPy

A flexible, efficient, and extensible JIT-compiled framework for computational neuroscience and brain-inspired computation.

688 stars · 108 forks · Python · 15 days ago

NetKet
netket/netket

A Python library built on JAX for studying many-body quantum systems using neural networks and machine learning.

680 stars · 215 forks · Python · 4 days ago

Distrax
deepmind/distrax

A JAX-native library of probability distributions and bijectors, reimplementing a subset of TensorFlow Probability with emphasis on readability and extensibility.

635 stars · 40 forks · Python · 26 days ago

GPJax
thomaspinder/GPJax

A low-level Gaussian process framework in JAX and Flax, designed for maximum flexibility and close alignment with mathematical notation.

627 stars · 74 forks · Python · 1 day ago

scikit-fem
kinnala/scikit-fem

A pure Python library for finite element assembly, transforming bilinear forms into sparse matrices and linear forms into vectors.

625 stars · 105 forks · Python · 3 days ago

Pgx
sotetsuk/pgx

A collection of GPU-accelerated parallel game simulators for reinforcement learning, built with JAX.

615 stars · 49 forks · Python · 1 year ago

Optimistix
patrick-kidger/optimistix

A JAX library for nonlinear optimization including root finding, minimization, fixed points, and least squares.

588 stars · 51 forks · Python · 26 days ago

Pax
google/paxml

A JAX-based machine learning framework for configuring and training large-scale models with high efficiency on TPUs and GPUs.

554 stars · 72 forks · Python · 4 days ago

XLB
Autodesk/XLB

A differentiable, massively parallel Lattice Boltzmann library in Python for physics-based machine learning and fluid dynamics simulations.

477 stars · 81 forks · Python · 10 days ago

PIX
deepmind/dm_pix

An image processing library built on JAX, designed to be optimized and parallelized with JAX transformations.

439 stars · 28 forks · Python · 6 days ago

Tequila
aspuru-guzik-group/tequila

A high-level Python framework for formulating, optimizing, and executing variational quantum algorithms on simulators and real hardware.

432 stars · 130 forks · Python · 11 days ago

Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey
dfm/extending-jax

A tutorial demonstrating how to extend JAX with custom C++ and CUDA operations for high-performance computing.

403 stars · 23 forks · Python · 1 year ago

EasyDeL
erfanzar/EasyDeL

A JAX-based framework for streamlined training, fine-tuning, and high-performance serving of large language and multimodal models.

365 stars · 50 forks · Python · 1 day ago

QDax
adaptive-intelligent-robotics/QDax

A hardware-accelerated Python library for running Quality-Diversity and neuroevolution algorithms in minutes instead of days.

356 stars · 55 forks · Python · 7 months ago

tinygp
dfm/tinygp

An extremely lightweight Gaussian process library for Python built on JAX, with GPU acceleration and automatic differentiation.

347 stars · 34 forks · Python · 7 days ago