A Python library for GPU-accelerated and differentiable quantum systems simulation built with JAX.
A JAX library for automatically generating equivariant neural network layers for arbitrary symmetry groups via constraint solving.
A JAX-based library for federated learning simulations that emphasizes ease of use in research.
A collection of pretrained deep learning models (StyleGAN2, GPT2, VGG, ResNet) for the JAX/Flax ecosystem.
A differentiable cosmology library built with JAX for automatic differentiation of cosmological calculations.
A JAX-based probabilistic programming framework using nested sampling for fast Bayesian inference and evidence computation.
A JAX library for distributions, bijections, and normalizing flows implemented as Equinox modules.
A PyTorch frontend for JAX that enables running PyTorch code on TPUs and provides seamless PyTorch-JAX interoperability.
A Python toolbox for solving optimal transport problems with JAX-powered computational efficiency.
A Python library that enables mathematical operations on JAX pytrees, allowing numerical algorithms to work directly with complex nested data structures.
A JAX-based research framework for differentiable and parallelizable acoustic simulations, running on CPU, GPU, and TPU.
A vision transformer architecture that aggregates nested local transformers on image blocks for better accuracy, data efficiency, and convergence.
A layer library for JAX-based machine learning projects, optimized for large-scale ML.
A JAX-powered reimplementation of MiniGrid offering over 1000x speedup for reinforcement learning experiments.
A FlashAttention 2 implementation for JAX with block-wise document mask optimization and context parallelism for efficient long-sequence training.
A Python package built on JAX for solving inverse problems in scientific imaging using optimization and prior models.
Unofficial JAX/Flax implementations of deep learning research papers for vision transformers and other architectures.
A pure, immutable module system for JAX that replaces PyTorch-style imperative coding with declarative parameter trees.
A JAX transform that implements LoRA (Low-Rank Adaptation) for efficient fine-tuning of large models with minimal memory overhead.
A compact spiking neural network library built on JAX and Haiku, offering high-performance training via surrogate gradient descent and neuroevolution.
A JAX-based framework for building differentiable numerical simulators with arbitrary discretizations for physical systems.
A symbolic programming library built on JAX for concise, explicit, and optimized machine learning computations.
Add tqdm progress bars to JAX scans and loops using decorators, enabling side-effect-free progress tracking.
An experimental library that converts TensorFlow functions and graphs into JAX functions for reuse and fine-tuning within JAX codebases.
Flax implementations and pretrained checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX.
A Python package providing popular computer vision model architectures built with Equinox for JAX.
A DSL-based library for unified tensor reshaping, squeezing, expanding, and transposing in JAX, TensorFlow, and NumPy.
A lightweight Bayesian optimization library built on JAX for efficient optimization of expensive-to-evaluate functions.
A performant JAX reimplementation of the UniRep model for generating protein sequence representations.
A JAX-based framework for approximate inference in Markov Gaussian processes using iterated Kalman smoothing.
Official JAX implementation of XMC-GAN for text-to-image generation using cross-modal contrastive learning.
A JAX library for building, training, and evaluating normalizing flows for probabilistic modeling.
A Python library for building, training, and deploying spiking neural networks with support for multiple simulation backends and neuromorphic hardware.
A JAX-based differentiable spectral modeling library for exoplanets, brown dwarfs, and M dwarfs.
Kernex extends JAX with kmap and kscan for differentiable stencil computations, enabling efficient array transformations.
A JAX + Flax implementation of physics-inspired graph neural networks for solving combinatorial optimization problems like Max-Cut and Maximum Independent Set.