Open-Awesome

© 2026 Open-Awesome. Curated for the developer elite.


JAX

Framework
113 projects · 537.7k total stars · 87.9k total forks · 3 languages

Open-source projects built with JAX

There are currently 113 open-source projects built with JAX, with a combined total of 537.7k GitHub stars. The most common language among these projects is Python.

Showing 108 open-source projects · page 3 of 3

Community-curated · Updated weekly · 100% open source

Found a gem we're missing?

Open-Awesome is built by the community, for the community. Submit a project, suggest an awesome list, or help improve the catalog on GitHub.

jaxlie
brentyi/jaxlie

A JAX library implementing Lie groups for rigid body transformations in computer vision and robotics.

332 stars · 18 forks · Python
1 year ago
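
To make "Lie groups for rigid body transformations" concrete, the sketch below represents SE(2) elements as 3x3 homogeneous matrices in plain JAX and implements the group product, the inverse, and the action on points. It is not jaxlie's API; `se2`, `compose`, `inverse`, and `apply` are hypothetical helper names.

```python
import jax.numpy as jnp


def se2(theta, tx, ty):
    """Homogeneous 3x3 matrix: rotate by `theta` radians, then translate by (tx, ty)."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, tx],
                      [s, c, ty],
                      [0.0, 0.0, 1.0]])


def compose(a, b):
    """Group product: the transform that applies `b` first, then `a`."""
    return a @ b


def inverse(t):
    """Closed-form SE(2) inverse: (R, p)^-1 = (R^T, -R^T p)."""
    r_t = t[:2, :2].T
    p = t[:2, 2]
    top = jnp.concatenate([r_t, (-r_t @ p)[:, None]], axis=1)
    return jnp.concatenate([top, jnp.array([[0.0, 0.0, 1.0]])], axis=0)


def apply(t, points):
    """Transform an (N, 2) array of points."""
    return points @ t[:2, :2].T + t[:2, 2]


T = se2(jnp.pi / 2, 1.0, 0.0)
p = jnp.array([[1.0, 0.0]])
print(apply(T, p))  # approximately [[1. 1.]]: (1, 0) rotates to (0, 1), then shifts by (1, 0)
```

Lie group libraries such as jaxlie typically also provide exponential and logarithm maps between the group and its tangent space. This sketch omits them.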
mcx
rlouf/mcx

A JAX-powered probabilistic programming library focused on performant sampling methods for Bayesian inference on CPU, GPU, and TPU.

332 stars · 16 forks · Python
2 years ago
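
For readers new to sampling-based inference, here is the simplest member of the family: random-walk Metropolis written as a `lax.scan` in plain JAX. It does not use mcx and makes no claim about which samplers mcx implements. The target `log_density` and the function `rw_metropolis` are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    """Hypothetical unnormalized target: a unit-variance Gaussian centred at 3."""
    return -0.5 * (x - 3.0) ** 2


def rw_metropolis(key, x0, n_steps, step_size=1.0):
    """Random-walk Metropolis: propose x' = x + step_size * eps, accept with prob min(1, p(x')/p(x))."""

    def step(x, key):
        key_prop, key_acc = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop)
        log_ratio = log_density(proposal) - log_density(x)
        # Accept when log(u) < log p(x') - log p(x) for u ~ Uniform(0, 1).
        accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new

    keys = jax.random.split(key, n_steps)
    _, samples = jax.lax.scan(step, jnp.asarray(x0, dtype=jnp.float32), keys)
    return samples


samples = rw_metropolis(jax.random.PRNGKey(0), 0.0, 5000)
posterior = samples[1000:]  # discard burn-in
print(posterior.mean(), posterior.std())
```

Because the chain is one `lax.scan`, the whole sampler compiles to a single XLA loop. Wrapping it in `jax.vmap` over keys runs many chains in parallel on CPU, GPU, or TPU.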
KFAC-JAX
deepmind/kfac-jax

A JAX library for second-order optimization of neural networks using the K-FAC curvature approximation algorithm.

324 stars · 29 forks · Python
24 days ago
FDTDX
ymahlau/fdtdx

An efficient open-source Python package for simulating and designing 3D photonic nanostructures with GPU-accelerated finite-difference time-domain (FDTD) methods and automatic differentiation.

301 stars · 58 forks · Python
1 day ago
dynamiqs
dynamiqs/dynamiqs

A Python library for GPU-accelerated and differentiable quantum systems simulation built with JAX.

292 stars · 49 forks · Python
1 day ago
Equivariant MLP
mfinzi/equivariant-MLP

A JAX library for automatically generating equivariant neural network layers for arbitrary symmetry groups via constraint solving.

284 stars · 26 forks · Jupyter Notebook
3 years ago
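
The "constraint solving" idea can be shown in miniature. Equivariant linear maps W are exactly the solutions of the linear equations ρ(g) W = W ρ(g). For a finite group it is enough to impose these equations for a set of generators, because a map that commutes with the generators commutes with every product of them. The sketch below stacks one constraint block per generator and reads the solution space off an SVD. It is a simplified illustration, not EMLP's solver or API. `equivariant_basis` is a hypothetical name, and the sketch only covers the case where input and output share the same permutation representation.

```python
import jax.numpy as jnp


def equivariant_basis(generators, tol=1e-4):
    """Orthonormal basis (as n x n matrices) of all W with P @ W == W @ P for every generator P.

    With row-major flattening, vec(P @ W) = kron(P, I) @ vec(W) and vec(W @ P) = kron(I, P.T) @ vec(W),
    so each generator adds the linear constraint (kron(P, I) - kron(I, P.T)) @ vec(W) = 0.
    """
    n = generators[0].shape[0]
    eye = jnp.eye(n)
    constraints = jnp.concatenate(
        [jnp.kron(p, eye) - jnp.kron(eye, p.T) for p in generators], axis=0
    )
    # Rows >= columns here, so vt is (n*n, n*n) and matches s in length.
    # Right singular vectors with (near-)zero singular values span the solution space.
    _, s, vt = jnp.linalg.svd(constraints)
    return vt[s < tol].reshape(-1, n, n)


# S_3 acting on R^3 by permuting coordinates, given by two generators.
swap = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
cycle = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
basis = equivariant_basis([swap, cycle])
print(basis.shape)  # (2, 3, 3): the span of the identity and the all-ones matrix
```

An equivariant layer then learns only the coefficients on `basis`: two numbers here instead of nine.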
FedJAX
google/fedjax

A JAX-based library for federated learning simulations that emphasizes ease-of-use in research.

272 stars · 42 forks · Python
1 month ago
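
In a federated learning simulation, the server-side step is often federated averaging (FedAvg). Client models are combined with weights proportional to how much data each client trained on. The sketch below shows that aggregation step over parameter pytrees in plain JAX. It is not FedJAX's API, and `federated_average` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def federated_average(client_params, client_num_examples):
    """Average client parameter pytrees, weighting each client by its number of examples."""
    weights = jnp.asarray(client_num_examples, dtype=jnp.float32)
    weights = weights / weights.sum()
    # Stack matching leaves across clients, then contract the client axis with the weights.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *client_params)
    return jax.tree_util.tree_map(lambda leaf: jnp.tensordot(weights, leaf, axes=1), stacked)


clients = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(1.0)},
]
global_params = federated_average(clients, [1, 3])  # w -> [2.5, 3.5], b -> 0.75
```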
flaxmodels
matthias-wright/flaxmodels

A collection of pretrained deep learning models (StyleGAN2, GPT2, VGG, ResNet) for the JAX/Flax ecosystem.

265 stars · 28 forks · Python
1 year ago
jax-cosmo
DifferentiableUniverseInitiative/jax_cosmo

A differentiable cosmology library built with JAX for automatic differentiation of cosmological calculations.

240 stars · 48 forks · Python
11 months ago
jaxns
Joshuaalbert/jaxns

A JAX-based probabilistic programming framework using nested sampling for fast Bayesian inference and evidence computation.

236 stars · 20 forks · Python
22 days ago
flowjax
danielward27/flowjax

A JAX library for distributions, bijections, and normalizing flows implemented as Equinox modules.

229 stars · 23 forks · Python
3 months ago
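
Normalizing flows rest on the change-of-variables formula: log p_x(x) = log p_z(f⁻¹(x)) + log|det ∂f⁻¹/∂x|. The sketch below applies it to a single elementwise affine bijection with a standard-normal base distribution, in plain JAX. It does not use flowjax or Equinox. `affine_forward`, `affine_inverse`, and `flow_log_prob` are hypothetical names. Real flows stack many learned bijections, and each one contributes its own log-determinant term.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def affine_forward(z, loc, log_scale):
    """Bijection x = loc + exp(log_scale) * z, returning x and log|det dx/dz|."""
    x = loc + jnp.exp(log_scale) * z
    return x, jnp.sum(log_scale)


def affine_inverse(x, loc, log_scale):
    """Inverse bijection z = (x - loc) * exp(-log_scale), returning z and log|det dz/dx|."""
    z = (x - loc) * jnp.exp(-log_scale)
    return z, -jnp.sum(log_scale)


def flow_log_prob(x, loc, log_scale):
    """log p(x) = log N(z; 0, I) + log|det dz/dx|, where z is the inverse image of x."""
    z, log_det_inv = affine_inverse(x, loc, log_scale)
    return jnp.sum(norm.logpdf(z)) + log_det_inv


loc = jnp.array([1.0, -2.0])
log_scale = jnp.array([0.5, -0.3])
z = jax.random.normal(jax.random.PRNGKey(0), (2,))
x, _ = affine_forward(z, loc, log_scale)  # a sample from the flow
print(flow_log_prob(x, loc, log_scale))
```

For a single affine layer, the result must match a diagonal Gaussian with mean `loc` and standard deviation `exp(log_scale)`. This makes it a convenient sanity check.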
torchax
google/torchax

A PyTorch frontend for JAX that enables running PyTorch code on TPUs and provides seamless PyTorch-JAX interoperability.

226 stars · 31 forks · Python
17 days ago
Optimal Transport Tools
google-research/ott

A Python toolbox for solving optimal transport problems with JAX-powered computational efficiency.

214 stars · 17 forks · Python
4 years ago
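
The workhorse of many optimal transport toolboxes is Sinkhorn's algorithm for entropy-regularized transport. It alternately rescales the rows and columns of the kernel K = exp(−C/ε) until the transport plan has the requested marginals. The sketch below is a plain-JAX version using `lax.fori_loop`. It is not OTT's API, and `sinkhorn` and its defaults are hypothetical. Production solvers usually work in the log domain for stability at small ε, which this sketch does not.

```python
import jax
import jax.numpy as jnp


def sinkhorn(a, b, cost, epsilon=0.1, num_iters=500):
    """Entropic optimal transport via Sinkhorn's matrix-scaling iterations.

    a: (n,) source weights and b: (m,) target weights, each summing to 1; cost: (n, m).
    Returns the plan P = diag(u) K diag(v) with K = exp(-cost / epsilon).
    """
    K = jnp.exp(-cost / epsilon)

    def body(_, uv):
        u, v = uv
        u = a / (K @ v)    # match row sums to a
        v = b / (K.T @ u)  # match column sums to b
        return u, v

    u, v = jax.lax.fori_loop(0, num_iters, body, (jnp.ones_like(a), jnp.ones_like(b)))
    return u[:, None] * K * v[None, :]


x = jnp.linspace(0.0, 1.0, 5)
y = jnp.linspace(0.0, 1.0, 4)
cost = (x[:, None] - y[None, :]) ** 2  # squared-distance cost between 1D points
a = jnp.full(5, 1.0 / 5)
b = jnp.full(4, 1.0 / 4)
plan = sinkhorn(a, b, cost)
print(jnp.sum(plan * cost))  # entropic transport cost
```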
tree-math
google/tree-math

A Python library that enables mathematical operations on JAX pytrees, allowing numerical algorithms to work directly with complex nested data structures.

210 stars · 9 forks · Python
1 year ago
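
To see why pytree arithmetic matters, note that iterative solvers such as conjugate gradient need only a few vector primitives, for example a scaled addition and an inner product. The sketch below defines those two primitives directly over pytrees with `jax.tree_util`, so a nested dict of parameters can be treated as one flat vector. It is not tree-math's interface; `tree_axpy` and `tree_vdot` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def tree_axpy(alpha, x, y):
    """Compute alpha * x + y leaf by leaf for two pytrees with the same structure."""
    return jax.tree_util.tree_map(lambda xi, yi: alpha * xi + yi, x, y)


def tree_vdot(x, y):
    """Inner product of two pytrees, treating all of their leaves as one flat vector."""
    per_leaf = jax.tree_util.tree_map(lambda xi, yi: jnp.sum(xi * yi), x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"w": jnp.array([0.5, 0.5]), "b": jnp.array(1.0)}
scaled = tree_axpy(2.0, x, y)  # w -> [2.5, 4.5], b -> 7.0
print(tree_vdot(x, y))         # 4.5
```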
jwave
ucl-bug/jwave

A JAX-based research framework for differentiable and parallelizable acoustic simulations, running on CPU, GPU, and TPU.

209 stars · 33 forks · Python
2 months ago
NesT
google-research/nested-transformer

A vision transformer architecture that aggregates nested local transformers on image blocks for better accuracy, data efficiency, and convergence.

204 stars · 27 forks · Jupyter Notebook
3 months ago
Praxis
google/praxis

A layer library for JAX-based machine learning projects, optimized for large-scale ML.

196 stars · 43 forks · Python
1 month ago
NAVIX
epignatelli/navix

A JAX-powered reimplementation of MiniGrid offering over 1000x speedup for reinforcement learning experiments.

170 stars · 21 forks · Python
7 months ago
kvax
nebius/kvax

A FlashAttention 2 implementation for JAX with block-wise document mask optimization and context parallelism for efficient long-sequence training.

167 stars · 9 forks · Python
6 months ago
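
FlashAttention-style kernels avoid materializing the full attention matrix. They process keys and values block by block and keep a running maximum and denominator for the softmax. The sketch below shows only that online-softmax recurrence in plain JAX with `lax.scan`. It is not kvax's kernel or API, it is not a fused kernel, and it omits document masks and context parallelism. `blockwise_attention` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_size):
    """Softmax attention computed one key/value block at a time with an online softmax.

    q: (Lq, d); k, v: (Lk, d), with Lk divisible by block_size. No masking.
    """
    lq, d = q.shape
    n_blocks = k.shape[0] // block_size
    k_blocks = k.reshape(n_blocks, block_size, d)
    v_blocks = v.reshape(n_blocks, block_size, d)
    scale = 1.0 / jnp.sqrt(jnp.float32(d))

    def step(carry, kv_block):
        m, l, acc = carry  # running max, running softmax denominator, unnormalized output
        k_j, v_j = kv_block
        s = (q @ k_j.T) * scale                # (Lq, block_size) scores for this block
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        correction = jnp.exp(m - m_new)        # rescales earlier blocks to the new max
        l_new = l * correction + p.sum(axis=-1)
        acc_new = acc * correction[:, None] + p @ v_j
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((lq,), -jnp.inf, dtype=q.dtype),
        jnp.zeros((lq,), dtype=q.dtype),
        jnp.zeros((lq, d), dtype=q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / l[:, None]


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (8, 16))
k = jax.random.normal(key_k, (32, 16))
v = jax.random.normal(key_v, (32, 16))
out = blockwise_attention(q, k, v, block_size=8)
```

The result should match ordinary attention, `jax.nn.softmax(q @ k.T / sqrt(d)) @ v`, up to floating-point error. Only one (Lq, block_size) score tile exists per step.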
SCICO
lanl/scico

A Python package built on JAX for solving inverse problems in scientific imaging using optimization and prior models.

162 stars · 24 forks · Python
1 day ago
jax-models
DarshanDeshpande/jax-models

Unofficial JAX/Flax implementations of deep learning research papers for vision transformers and other architectures.

162 stars · 10 forks · Python
4 years ago
Parallax
srush/parallax

A pure, immutable module system for JAX that replaces PyTorch-style imperative coding with declarative parameter trees.

154 stars · 4 forks · Python
6 years ago
Lorax
davisyoshida/lorax

A JAX transform that implements LoRA (Low-Rank Adaptation) for efficient fine-tuning of large models with minimal memory overhead.

144 stars · 6 forks · Python
2 years ago
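
LoRA freezes a pretrained weight W and trains only a low-rank update B A. The layer computes x (W + B A)ᵀ, and initializing B to zero means the model starts out unchanged. The sketch below shows that reparameterization for a single dense layer in plain JAX. It is not Lorax's transform, which applies LoRA to an existing model. `init_lora`, `lora_dense`, and the single `scale` factor are hypothetical simplifications.

```python
import jax
import jax.numpy as jnp


def init_lora(key, d_in, d_out, rank):
    """LoRA factors: A is small and random, B starts at zero so the update is zero at init."""
    return {
        "a": 0.01 * jax.random.normal(key, (rank, d_in)),
        "b": jnp.zeros((d_out, rank)),
    }


def lora_dense(w_frozen, lora, x, scale=1.0):
    """y = x W^T + scale * (x A^T) B^T, i.e. x (W + scale * B A)^T without forming B A."""
    return x @ w_frozen.T + scale * (x @ lora["a"].T) @ lora["b"].T


def loss(lora, w_frozen, x, y):
    return jnp.mean((lora_dense(w_frozen, lora, x) - y) ** 2)


key_w, key_lora, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
d_in, d_out, rank = 64, 32, 4
w = jax.random.normal(key_w, (d_out, d_in))
lora = init_lora(key_lora, d_in, d_out, rank)
x = jax.random.normal(key_x, (8, d_in))
y = jnp.zeros((8, d_out))

# Differentiate only with respect to the LoRA factors (argument 0); W stays frozen.
grads = jax.grad(loss)(lora, w, x, y)
n_trainable = rank * (d_in + d_out)  # 384 trainable values vs. 2048 in the full matrix
```

Only `grads` for the small factors need optimizer state, which is where the memory savings come from.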
Spyx
kmheckel/spyx

A compact spiking neural network library built on JAX and Haiku, offering high-performance training via surrogate gradient descent and neuroevolution.

135 stars · 14 forks · Jupyter Notebook
1 month ago
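
Surrogate gradients work around the fact that a spike (a step function) has zero derivative almost everywhere. The forward pass keeps the hard threshold, and the backward pass substitutes a smooth stand-in. The sketch below implements this with `jax.custom_jvp` using a fast-sigmoid derivative, one common choice. It is not Spyx's API, and Spyx's own surrogate functions may differ. `spike` and `SURROGATE_SLOPE` are hypothetical names.

```python
import jax
import jax.numpy as jnp

SURROGATE_SLOPE = 10.0  # hypothetical sharpness of the surrogate derivative


@jax.custom_jvp
def spike(v):
    """Heaviside spike: 1.0 where the membrane potential `v` is above threshold 0, else 0.0."""
    return (jnp.asarray(v) > 0).astype(jnp.float32)


@spike.defjvp
def spike_jvp(primals, tangents):
    (v,), (v_dot,) = primals, tangents
    # Exact step in the forward pass; fast-sigmoid derivative 1 / (1 + k|v|)^2 for gradients.
    surrogate = 1.0 / (1.0 + SURROGATE_SLOPE * jnp.abs(v)) ** 2
    return spike(v), (surrogate * v_dot).astype(jnp.float32)


v = jnp.array([-0.5, 0.0, 0.1, 0.5])
print(spike(v))                               # [0. 0. 1. 1.]
print(jax.grad(lambda u: spike(u).sum())(v))  # nonzero everywhere, largest at v = 0
```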
JaxDF
ucl-bug/jaxdf

A JAX-based framework for building differentiable numerical simulators with arbitrary discretizations for physical systems.

134 stars · 13 forks · Python
28 days ago
SymJAX
SymJAX/SymJAX

A symbolic programming library built on JAX for concise, explicit, and optimized machine learning computations.

131 stars · 5 forks · Python
3 years ago
jax-tqdm
jeremiecoullon/jax-tqdm

A library that adds tqdm progress bars to JAX scans and loops through decorators, enabling side-effect-free progress tracking.

130 stars · 8 forks · Python
13 days ago
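
Progress reporting from inside compiled JAX loops usually relies on calling back to host Python, and `jax.debug.callback` is one such mechanism. The sketch below wraps a `lax.scan` body in a decorator that reports the step index this way. It is not jax-tqdm's implementation or API: it prints plain text instead of drawing a tqdm bar. `with_progress` is a hypothetical name. Filtering by `print_every` happens on the host here, so the callback itself still fires on every step.

```python
import jax
import jax.numpy as jnp


def with_progress(n_steps, print_every=10):
    """Hypothetical decorator for a lax.scan body whose per-step input is (step_index, x)."""

    def decorator(body):
        def report(i):
            # Runs on the host, outside the compiled computation.
            i = int(i)
            if (i + 1) % print_every == 0 or i + 1 == n_steps:
                print(f"step {i + 1}/{n_steps}")

        def wrapped(carry, i_and_x):
            i, x = i_and_x
            jax.debug.callback(report, i)
            return body(carry, x)

        return wrapped

    return decorator


@with_progress(n_steps=50, print_every=10)
def running_sum(carry, x):
    carry = carry + x
    return carry, carry


xs = jnp.arange(50, dtype=jnp.float32)
total, partial_sums = jax.lax.scan(running_sum, jnp.float32(0.0), (jnp.arange(50), xs))
```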
TF2JAX
deepmind/tf2jax

An experimental library that converts TensorFlow functions and graphs into JAX functions for reuse and fine-tuning within JAX codebases.

124 stars · 20 forks · Python
1 month ago
jax-resnet
n2cholas/jax-resnet

Flax implementations and pretrained checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX.

120 stars · 8 forks · Python
4 years ago
Eqxvision
paganpasta/eqxvision

A Python package providing popular computer vision model architectures built with Equinox for JAX.

112 stars · 15 forks · Python
1 year ago
bayex
alonfnt/bayex

A lightweight Bayesian optimization library built on JAX for efficient optimization of expensive-to-evaluate functions.

108 stars · 5 forks · Python
1 year ago
jax-unirep
ElArkk/jax-unirep

A performant JAX reimplementation of the UniRep model for generating protein sequence representations.

107 stars · 31 forks · TeX
1 year ago
kalman-jax
AaltoML/kalman-jax

A JAX-based framework for approximate inference in Markov Gaussian processes using iterated Kalman smoothing.

103 stars · 13 forks · Jupyter Notebook
2 years ago
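
The building block behind Kalman smoothing is the linear-Gaussian filtering recursion. Each step predicts the next state, then blends the prediction with the new observation through the Kalman gain. The sketch below runs that recursion for a scalar random walk with `lax.scan` in plain JAX. It covers filtering only, with no smoothing pass or iteration, and it is not kalman-jax's API. `kalman_filter` and its noise parameters are hypothetical.

```python
import jax
import jax.numpy as jnp


def kalman_filter(ys, q=0.1, r=1.0, m0=0.0, p0=10.0):
    """Scalar Kalman filter for x_t = x_{t-1} + N(0, q), observed as y_t = x_t + N(0, r)."""

    def step(carry, y):
        m, p = carry
        # Predict: random-walk dynamics keep the mean and add process noise q to the variance.
        p_pred = p + q
        # Update: the Kalman gain weighs the prediction against the observation.
        gain = p_pred / (p_pred + r)
        m_new = m + gain * (y - m)
        p_new = (1.0 - gain) * p_pred
        return (m_new, p_new), (m_new, p_new)

    init = (jnp.float32(m0), jnp.float32(p0))
    _, (means, variances) = jax.lax.scan(step, init, ys)
    return means, variances


ys = jnp.ones(100)  # hypothetical noiseless observations of a constant signal
means, variances = kalman_filter(ys)
```

With q = 0.1 and r = 1, the posterior variance settles at the fixed point of the update, about 0.27. A smoother would add a backward pass over the stored means and variances.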
XMC-GAN
google-research/xmcgan_image_generation

Official JAX implementation of XMC-GAN for text-to-image generation using cross-modal contrastive learning.

99 stars · 14 forks · Python
16 days ago
NuX
Information-Fusion-Lab-Umass/NuX

A JAX library for building, training, and evaluating normalizing flows for probabilistic modeling.

87 stars · 4 forks · Python
2 years ago
Rockpool
synsense/rockpool

A Python library for building, training, and deploying spiking neural networks with support for multiple simulation backends and neuromorphic hardware.

84 stars · 15 forks · Python
2 months ago