The "Awesome JAX" project is a curated resource list for researchers and developers using JAX, a high-performance numerical computing library that combines automatic differentiation with XLA compilation and is widely used for machine learning. The list collects libraries, tools, tutorials, research papers, and community resources for applying JAX to machine learning tasks. It serves both beginners learning the fundamentals and experienced researchers looking for advanced techniques and optimizations.
The "Awesome Tutorials" project is a curated resource list for learners and educators seeking high-quality tutorials across a range of subjects and technologies. It covers programming languages, web development, data science, machine learning, and more, at skill levels from beginner to advanced, with practical examples throughout.
The "Awesome Core ML Models" project is a curated collection of machine learning models for Apple's Core ML framework, which lets developers integrate machine learning into iOS and macOS applications. The list includes pre-trained models for tasks such as image classification, object detection, and natural language processing, along with links to their repositories and documentation. It is useful both for developers adding machine learning to an app for the first time and for experienced developers looking for more advanced capabilities, and it helps them quickly find a model that fits their application.
The "Awesome AI in Finance" project is a curated collection of resources on the intersection of artificial intelligence and finance. It includes tools, libraries, research papers, case studies, and tutorials that show how machine learning and AI can be applied to financial problems such as algorithmic trading, risk assessment, fraud detection, and customer service automation. It is aimed at finance professionals, data scientists, and researchers, from beginners exploring the field to experienced practitioners seeking advanced techniques.
The "Awesome ML with Ruby" project is a curated collection of resources for developers who want to apply machine learning in the Ruby programming language. It covers machine learning libraries, frameworks, tutorials, and tools for integrating machine learning into Ruby applications, and is aimed at Ruby developers, data scientists, and machine learning enthusiasts.
A neural network library for JAX designed for flexibility, enabling researchers to experiment with new training forms by modifying training loops.
A JAX-based neural network library that provides a simple, object-oriented programming model for building and training models.
An object-oriented machine learning framework built on JAX, designed for simplicity and readability in research.
An end-to-end deep learning library focused on clear code, speed, and research, built by Google Brain.
A lightweight library for building and training graph neural networks in JAX, providing graph data structures, utilities, and model implementations.
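To make the idea of graph data structures and message passing concrete, here is a minimal sketch of one message-passing step over an edge list in plain JAX. The edge-list layout (`senders`, `receivers`) and the helper name `message_passing_step` are assumptions for illustration; the library's own graph container and model APIs are not reproduced here.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not the library's graph API.
def message_passing_step(node_feats, senders, receivers, weight):
    """One round of sum-aggregation message passing over an edge list."""
    # Each edge carries its sender's features, linearly transformed.
    messages = node_feats[senders] @ weight
    # Sum incoming messages per receiving node; segment_sum handles any in-degree.
    aggregated = jax.ops.segment_sum(
        messages, receivers, num_segments=node_feats.shape[0]
    )
    # Residual update followed by a ReLU nonlinearity.
    return jax.nn.relu(node_feats + aggregated)

# A 3-node directed graph with edges 0 -> 1, 1 -> 2, and 0 -> 2.
node_feats = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
senders = jnp.array([0, 1, 0])
receivers = jnp.array([1, 2, 2])
weight = jnp.eye(2)

updated = message_passing_step(node_feats, senders, receivers, weight)
```

Using `jax.ops.segment_sum` keeps the aggregation a single vectorized, jit-friendly operation even when nodes have different numbers of incoming edges.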
A high-level neural network API for specifying and analyzing infinite-width neural networks as Gaussian Processes in Python.
A model-definition framework for state-of-the-art machine learning models across text, vision, audio, and multimodal tasks.
A JAX library for neural networks and scientific computing with PyTorch-like syntax and full ecosystem compatibility.
A JAX library for rapid prototyping of large-scale attention-based vision models across images, video, audio, and multimodal data.
A JAX research toolkit for building, editing, and visualizing neural networks as legible, functional pytree data structures.
A JAX-based framework for training large language models with a focus on legibility, scalability, and reproducibility.
A JAX/Flax-based framework for easy and scalable pre-training, fine-tuning, evaluation, and serving of large language models.
A lightweight probabilistic programming library using NumPy and JAX for autograd and JIT compilation to GPU/TPU/CPU.
A library of utilities for writing and testing reliable JAX code, including assertions, debugging tools, and test variants.
A gradient processing and optimization library for JAX, designed for research with composable building blocks.
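Composable gradient processing can be sketched in plain JAX: each building block maps a gradient pytree to an update pytree, and a chain applies the blocks in order. The helpers below (`clip_by_global_norm`, `scale`, `chain`) are a hypothetical, stateless mini-version written for illustration, not the library's interface; real optimizer libraries typically also thread optimizer state (for example, moment estimates) through each transformation, which this sketch omits.

```python
import jax
import jax.numpy as jnp

# Hypothetical, stateless building blocks for illustration only.
def clip_by_global_norm(max_norm):
    # Rescale the whole gradient pytree if its global L2 norm exceeds max_norm.
    def transform(grads):
        leaves = jax.tree_util.tree_leaves(grads)
        norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
        factor = jnp.minimum(1.0, max_norm / (norm + 1e-12))
        return jax.tree_util.tree_map(lambda g: g * factor, grads)
    return transform

def scale(step_size):
    # Multiply every gradient by a constant (negative for gradient descent).
    def transform(grads):
        return jax.tree_util.tree_map(lambda g: step_size * g, grads)
    return transform

def chain(*transforms):
    # Apply the transformations left to right.
    def transform(grads):
        for t in transforms:
            grads = t(grads)
        return grads
    return transform

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

optimizer = chain(clip_by_global_norm(1.0), scale(-0.1))

# Fit y = 2x with a one-feature linear model.
x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}

for _ in range(200):
    grads = jax.grad(loss)(params, x, y)
    updates = optimizer(grads)
    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```

Because each block only sees and returns a pytree, swapping the clipping threshold or adding another block changes the optimizer without touching the training loop.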
A JAX-based library providing reinforcement learning building blocks for implementing agents, supporting both on-policy and off-policy learning.
A differentiable, hardware-accelerated molecular dynamics simulation framework built on JAX for computational physics and materials science.
A JAX-native library of probability distributions and bijectors, reimplementing a subset of TensorFlow Probability with emphasis on readability and extensibility.
A Python library for constructing differentiable convex optimization layers in PyTorch, JAX, and MLX using CVXPY.
A Python library built on JAX for studying many-body quantum systems using neural networks and machine learning.
A fast, modular Bayesian inference library for JAX, providing composable samplers for CPU and GPU.
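One common way to write an MCMC kernel in JAX is as a pure function from a state and a PRNG key to a new state, which can then be driven by `jax.lax.scan` or batched over chains with `jax.vmap`. The sketch below is a random-walk Metropolis kernel for a standard normal target written in that style; the function names, target, and step size are illustrative assumptions, not the library's sampler interface.

```python
import jax
import jax.numpy as jnp

def log_density(x):
    # Unnormalized log-density of a standard normal target.
    return -0.5 * jnp.sum(x ** 2)

# Hypothetical kernel for illustration; not the library's sampler API.
def rwm_step(state, key, step_size=1.5):
    """One random-walk Metropolis step: propose, then accept or reject."""
    x, logp = state
    key_prop, key_accept = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    logp_prop = log_density(proposal)
    # Accept with probability min(1, p(proposal) / p(x)).
    accept = jnp.log(jax.random.uniform(key_accept)) < logp_prop - logp
    x = jnp.where(accept, proposal, x)
    logp = jnp.where(accept, logp_prop, logp)
    return (x, logp), x

x0 = jnp.zeros(2)
keys = jax.random.split(jax.random.PRNGKey(0), 10000)
_, samples = jax.lax.scan(rwm_step, (x0, log_density(x0)), keys)
```

Keeping the kernel pure means the same function can be composed with other kernels or run over many chains at once without changes.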
A Python library for probabilistic state space modeling and inference, built on JAX.
A JAX-based library for federated learning simulations that emphasizes ease-of-use in research.
A JAX library for automatically generating equivariant neural network layers for arbitrary symmetry groups via constraint solving.
Flax implementations and pretrained checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX.
A pure, immutable module system for JAX that replaces PyTorch-style imperative coding with declarative parameter trees.
A JAX library for nonlinear optimization including root finding, minimization, fixed points, and least squares.
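Root finding, the first problem class listed, can be illustrated with a bare Newton iteration built on `jax.grad`. This sketch assumes a scalar, differentiable function, a reasonable starting point, and a fixed iteration count with no convergence check; it is not the library's own solver interface.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not the library's solver API.
def newton_root(f, x0, num_iters=20):
    """Find x with f(x) = 0 using Newton's update x <- x - f(x) / f'(x)."""
    df = jax.grad(f)

    def step(x, _):
        return x - f(x) / df(x), None

    # lax.scan runs a fixed number of iterations and stays jit-compatible.
    x, _ = jax.lax.scan(step, x0, None, length=num_iters)
    return x

# Solve x**2 - 2 = 0 from x = 1.0, i.e. approximate sqrt(2).
root = newton_root(lambda x: x ** 2 - 2.0, jnp.array(1.0, dtype=jnp.float32))
```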
Hardware-accelerated, batchable, and differentiable optimization algorithms implemented in JAX for machine learning research.
A performant JAX reimplementation of the UniRep model for generating protein sequence representations.
A JAX library for distributions, bijections, and normalizing flows implemented as Equinox modules.
A differentiable cosmology library built with JAX for automatic differentiation of cosmological calculations.
A library for probabilistic reasoning and statistical analysis integrated with TensorFlow and JAX.
A Python toolbox that uses JAX to solve optimal transport problems efficiently.
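Entropy-regularized optimal transport is commonly solved with Sinkhorn iterations, which alternately rescale a Gibbs kernel `exp(-C / epsilon)` so that the transport plan matches both marginals. The plain-JAX sketch below assumes small dense problems, a fixed iteration count, and no log-domain stabilization; it is not the toolbox's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not the toolbox's API.
def sinkhorn(a, b, cost, epsilon=1.0, num_iters=200):
    """Entropic OT plan between histograms a and b for a dense cost matrix."""
    kernel = jnp.exp(-cost / epsilon)  # Gibbs kernel K = exp(-C / epsilon)

    def step(carry, _):
        u, v = carry
        u = a / (kernel @ v)       # rescale rows toward marginal a
        v = b / (kernel.T @ u)     # rescale columns toward marginal b
        return (u, v), None

    init = (jnp.ones_like(a), jnp.ones_like(b))
    (u, v), _ = jax.lax.scan(step, init, None, length=num_iters)
    # Transport plan P = diag(u) K diag(v).
    return u[:, None] * kernel * v[None, :]

# Uniform histograms on three points of a line, squared-distance cost.
x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.0, 1.0, 2.0])
cost = (x[:, None] - y[None, :]) ** 2
a = jnp.ones(3) / 3
b = jnp.ones(3) / 3
plan = sinkhorn(a, b, cost)
```

Smaller `epsilon` values bring the plan closer to unregularized optimal transport but make the kernel entries underflow sooner, which is one reason practical solvers often offer a log-domain variant.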
A photovoltaic simulator with automatic differentiation for solar cell modeling and optimization, built on JAX.
A JAX library implementing Lie groups for rigid body transformations in computer vision and robotics.
A fast, differentiable physics engine built with JAX for massively parallel rigid body simulation on accelerator hardware.
A collection of pretrained deep learning models (StyleGAN2, GPT-2, VGG, ResNet) for the JAX/Flax ecosystem.
A JAX-based differentiable spectral modeling library for exoplanets, brown dwarfs, and M dwarfs.
An image processing library built on JAX, designed to be optimized and parallelized with JAX transformations.
A lightweight Bayesian optimization library built on JAX for efficient optimization of expensive-to-evaluate functions.
A JAX-based framework for building differentiable numerical simulators with arbitrary discretizations for physical systems.
A Python library that enables mathematical operations on JAX pytrees, allowing numerical algorithms to work directly with complex nested data structures.
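The core idea, treating a nested structure of arrays as if it were one flat vector, can be sketched with `jax.tree_util`. The helpers `tree_add`, `tree_scale`, and `tree_dot` are hypothetical names written for this illustration, not the library's interface.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical helpers for illustration; not the library's API.
def tree_add(x, y):
    # Leaf-wise addition of two pytrees with identical structure.
    return tree_util.tree_map(jnp.add, x, y)

def tree_scale(alpha, x):
    # Multiply every leaf by a scalar.
    return tree_util.tree_map(lambda leaf: alpha * leaf, x)

def tree_dot(x, y):
    # Inner product as if all leaves were flattened into one long vector.
    products = tree_util.tree_map(lambda p, q: jnp.sum(p * q), x, y)
    return sum(tree_util.tree_leaves(products))

params = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([1.0, -1.0])}
direction = {"w": jnp.ones((2, 2)), "b": jnp.array([0.5, 0.5])}

# An "axpy" update, params + 0.1 * direction, written once for any nesting.
stepped = tree_add(params, tree_scale(0.1, direction))
inner = tree_dot(params, direction)
```

Written this way, a numerical routine such as conjugate gradients only needs these few operations and never has to know how the parameters are nested.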
Unofficial JAX/Flax implementations of deep learning research papers for vision transformers and other architectures.
A JAX-based library for loopy belief propagation on discrete factor graphs, enabling efficient probabilistic inference.
A scalable, hardware-accelerated neuroevolution toolkit built on JAX for parallel training across TPUs/GPUs.
A comprehensive, high-performance library implementing 30+ Evolution Strategies in JAX for scalable optimization on modern hardware.
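Many evolution strategies follow the same loop: sample a population of perturbations around the current search mean, evaluate them in parallel, and move the mean toward better candidates. Below is a minimal sketch of a simple Gaussian strategy in the spirit of OpenAI-ES (antithetic sampling, standardized fitness, fixed noise scale), using `jax.vmap` for the parallel evaluation; the objective, hyperparameters, and function names are illustrative assumptions rather than the library's API.

```python
import jax
import jax.numpy as jnp

def sphere(x):
    # Toy objective to minimize; its optimum is at the origin.
    return jnp.sum(x ** 2)

# Hypothetical helper for illustration; not the library's strategy API.
def es_step(mean, key, sigma=0.1, lr=0.02, pop_size=64):
    """One update of a simple Gaussian evolution strategy with antithetic sampling."""
    half = jax.random.normal(key, (pop_size // 2, mean.shape[0]))
    noise = jnp.concatenate([half, -half])       # antithetic (mirrored) pairs
    candidates = mean + sigma * noise
    fitness = jax.vmap(sphere)(candidates)       # evaluate the population in parallel
    # Standardize fitness so the step size does not depend on the objective's scale.
    fitness = (fitness - fitness.mean()) / (fitness.std() + 1e-8)
    grad_estimate = noise.T @ fitness / (pop_size * sigma)
    return mean - lr * grad_estimate             # descend, because we minimize

mean = jnp.full((5,), 3.0)
key = jax.random.PRNGKey(0)
for _ in range(300):
    key, subkey = jax.random.split(key)
    mean = es_step(mean, subkey)
```

The mirrored pairs share the same second-order term of the objective, so their difference isolates the first-order signal and reduces the variance of the gradient estimate.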
A symbolic programming library built on JAX for concise, explicit, and optimized machine learning computations.
A JAX-powered probabilistic programming library focused on performant sampling methods for Bayesian inference on CPU, GPU, and TPU.
A DSL-based library for unified tensor reshaping, squeezing, expanding, and transposing in JAX, TensorFlow, and NumPy.
A collection of research code and datasets released by Google Research under open licenses.
A JAX-based library providing numerical differential equation solvers for ODEs, SDEs, and CDEs with autodifferentiation and GPU support.
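To show what a numerical ODE solver does under the hood, here is a fixed-step classical Runge-Kutta (RK4) integrator written with `jax.lax.scan`, so it can be jit-compiled and differentiated end to end. It is a simplified stand-in that assumes a fixed step size and a plain array state; it is not the library's solver API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not the library's solver API.
def rk4_solve(f, y0, t0, t1, num_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed-step classical RK4."""
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        return (t + dt, y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    init = (jnp.asarray(t0, dtype=jnp.float32), y0)
    (_, y_final), _ = jax.lax.scan(step, init, None, length=num_steps)
    return y_final

# Exponential decay dy/dt = -k * y, with exact solution y(t) = exp(-k * t) for y(0) = 1.
def final_value(k):
    return rk4_solve(lambda t, y: -k * y, jnp.array(1.0, dtype=jnp.float32), 0.0, 2.0, 50)

y_end = final_value(0.5)
# The solver is ordinary JAX code, so gradients flow through every step.
dy_dk = jax.grad(final_value)(0.5)
```

Differentiating through the unrolled solver steps like this is sometimes called discretize-then-optimize; here `dy_dk` approximates the exact sensitivity `-2 * exp(-1)` of `y(2) = exp(-2k)` at `k = 0.5`.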
An extremely lightweight Gaussian Process library for Python built on JAX with GPU acceleration and automatic differentiation.
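The central computation in exact Gaussian process regression is the posterior mean `K_s^T (K + noise * I)^{-1} y`, where `K` is the kernel matrix on the training inputs and `K_s` holds kernel values between training and test inputs. A minimal sketch with a squared-exponential kernel and a Cholesky solve follows; the zero-mean prior, kernel, hyperparameter values, and function names are illustrative assumptions, not the library's interface.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel on 1-D inputs.
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale ** 2)

# Hypothetical helper for illustration; assumes a zero-mean prior and fixed hyperparameters.
def gp_posterior_mean(x_train, y_train, x_test, noise=1e-2):
    """Posterior mean K_s^T (K + noise * I)^{-1} y of a zero-mean GP."""
    k = rbf_kernel(x_train, x_train) + noise * jnp.eye(x_train.shape[0])
    k_s = rbf_kernel(x_train, x_test)
    # A Cholesky solve is more stable than forming the inverse explicitly.
    alpha = cho_solve(cho_factor(k, lower=True), y_train)
    return k_s.T @ alpha

x_train = jnp.linspace(0.0, 2 * jnp.pi, 20)
y_train = jnp.sin(x_train)
x_test = jnp.array([0.5 * jnp.pi, jnp.pi])
mean = gp_posterior_mean(x_train, y_train, x_test)
```

Because every step is a JAX operation, the same function could be differentiated with respect to `lengthscale` or `noise` to tune hyperparameters.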
A JAX-based library providing accelerated reinforcement learning environments that are fully compatible with the classic Gym API.
A JAX-native library implementing Monte Carlo tree search algorithms like AlphaZero and MuZero for reinforcement learning research.
A JAX library for second-order optimization of neural networks using the K-FAC curvature approximation algorithm.
An experimental library that converts TensorFlow functions and graphs into JAX functions for reuse and fine-tuning within JAX codebases.
A JAX-based research framework for differentiable and parallelizable acoustic simulations, running on CPU, GPU, and TPU.
A low-level Gaussian process framework in JAX and Flax, designed for maximum flexibility and close alignment with mathematical notation.
A diverse suite of scalable reinforcement learning environments written in JAX for hardware-accelerated research.
A Python package providing popular computer vision model architectures built with Equinox for JAX.
A provable, measurable secure computation device that enables privacy-preserving tensor operations using multi-party computation (MPC).
A library that adds tqdm progress bars to JAX scans and loops via decorators, enabling side-effect-free progress tracking.
Kernex extends JAX with kmap and kscan for differentiable stencil computations, enabling efficient array transformations.
A high-performance, scalable LLM library and reference implementation written in pure Python/JAX for training on TPUs and GPUs.
A JAX-based machine learning framework for configuring and training large-scale models with high efficiency on TPUs and GPUs.
A layer library for JAX-based machine learning projects, optimized for large-scale ML.
High-performance, end-to-end reinforcement learning implementations fully written in JAX for massive parallelization on GPUs.
A JAX transform that implements LoRA (Low-Rank Adaptation) for efficient fine-tuning of large models with minimal memory overhead.
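LoRA keeps a pretrained weight matrix `W` frozen and learns a low-rank update, so the adapted layer computes `x @ W + (alpha / r) * (x @ down) @ up` with `down` of shape `(d_in, r)` and `up` of shape `(r, d_out)`. The sketch below shows that computation in plain JAX, with gradients taken only with respect to the adapter; the names `down` and `up`, the zero initialization of `up`, and the `alpha / r` scaling are common conventions assumed here, not the transform's actual API.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not the transform's actual API.
def init_lora(key, d_in, d_out, rank):
    # Random down-projection, zero up-projection: the adapted layer starts
    # out computing exactly the same function as the frozen layer.
    down = jax.random.normal(key, (d_in, rank)) * 0.01
    up = jnp.zeros((rank, d_out))
    return {"down": down, "up": up}

def lora_linear(frozen_w, lora, x, alpha=8.0):
    rank = lora["down"].shape[1]
    # Route x through the rank-r bottleneck instead of forming the d_in x d_out update.
    delta = (x @ lora["down"]) @ lora["up"]
    return x @ frozen_w + (alpha / rank) * delta

key_w, key_lora, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
d_in, d_out, rank = 64, 32, 4
frozen_w = jax.random.normal(key_w, (d_in, d_out))
lora = init_lora(key_lora, d_in, d_out, rank)
x = jax.random.normal(key_x, (8, d_in))
target = jnp.zeros((8, d_out))

def loss(lora_params):
    # The frozen weight is closed over as a constant; only the adapter is differentiated.
    return jnp.mean((lora_linear(frozen_w, lora_params, x) - target) ** 2)

grads = jax.grad(loss)(lora)
trainable = sum(p.size for p in jax.tree_util.tree_leaves(lora))
```

With these sizes the adapter has 384 trainable parameters (64 times 4 plus 4 times 32) versus 2,048 in the frozen 64-by-32 matrix.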
A Python package built on JAX for solving inverse problems in scientific imaging using optimization and prior models.
A compact spiking neural network library built on JAX and Haiku, offering high-performance training via surrogate gradient descent and neuroevolution.
A flexible, efficient, and extensible JIT-compiled framework for computational neuroscience and brain-inspired computation.
A JAX-powered library for solving large-scale optimal transport problems, including matching, barycenters, and neural approximations.
A hardware-accelerated Python library for running Quality-Diversity and neuroevolution algorithms in minutes instead of days.
A collection of CI pipelines, Docker images, and optimized examples to simplify JAX development on NVIDIA GPUs.
A collection of GPU-accelerated parallel game simulators for reinforcement learning, built with JAX.
A JAX-based framework for streamlined training, fine-tuning, and high-performance serving of large language and multimodal models.
A differentiable, massively parallel Lattice Boltzmann library in Python for physics-based machine learning and fluid dynamics simulations.
A Python library for GPU-accelerated and differentiable quantum systems simulation built with JAX.
A JAX-powered reimplementation of MiniGrid offering over 1000x speedup for reinforcement learning experiments.
An efficient open-source Python package for 3D photonic nanostructure simulation and design using GPU-accelerated FDTD with automatic differentiation.
A FlashAttention 2 implementation for JAX with block-wise document mask optimization and context parallelism for efficient long-sequence training.
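The central idea behind FlashAttention-style kernels is an online softmax: attention is accumulated over key/value blocks while a running row maximum and normalizer rescale the partial sums, so the full score matrix never has to be materialized at once. The sketch below reproduces only that numerical idea for a single head in plain JAX, with no masking, context parallelism, or fused kernel; the block size, shapes, and function names are assumptions, and the result is compared against a direct softmax implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration; not the library's kernel or API.
def blockwise_attention(q, k, v, block_size=4):
    """Single-head attention over key/value blocks with an online softmax."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    num_blocks = k.shape[0] // block_size  # assumes seq_len divisible by block_size
    k_blocks = k.reshape(num_blocks, block_size, -1)
    v_blocks = v.reshape(num_blocks, block_size, -1)

    def step(carry, kv):
        out, row_max, row_sum = carry
        k_blk, v_blk = kv
        scores = (q @ k_blk.T) * scale                    # (n_q, block_size)
        new_max = jnp.maximum(row_max, scores.max(axis=-1))
        correction = jnp.exp(row_max - new_max)           # rescale earlier partial sums
        probs = jnp.exp(scores - new_max[:, None])
        out = out * correction[:, None] + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(axis=-1)
        return (out, new_max, row_sum), None

    init = (
        jnp.zeros((q.shape[0], v.shape[-1]), dtype=q.dtype),
        jnp.full((q.shape[0],), -jnp.inf, dtype=q.dtype),
        jnp.zeros((q.shape[0],), dtype=q.dtype),
    )
    (out, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return out / row_sum[:, None]

def reference_attention(q, k, v):
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 16))
k = jax.random.normal(kk, (16, 16))
v = jax.random.normal(kv, (16, 16))
fast = blockwise_attention(q, k, v)
slow = reference_attention(q, k, v)
```

On the first block the running maximum is `-inf`, so the correction factor is `exp(-inf) = 0` and the empty accumulators are discarded cleanly rather than producing NaNs.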
A PyTorch frontend for JAX that enables running PyTorch code on TPUs and provides seamless PyTorch-JAX interoperability.
A technique using Fourier feature mappings to enable neural networks to learn high-frequency functions in low-dimensional domains.
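The mapping itself is short: an input coordinate vector `v` is projected by a random Gaussian matrix `B` and encoded as `[cos(2πBv), sin(2πBv)]` before it reaches an ordinary MLP. Below is a minimal sketch of that feature map in JAX; the scale `sigma = 10.0`, the 128 frequencies, and the helper name are illustrative assumptions, since a suitable scale depends on the problem.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper for illustration.
def fourier_features(v, b_matrix):
    """Map inputs v of shape (n, d) to [cos(2*pi*v B^T), sin(2*pi*v B^T)]."""
    projected = 2.0 * jnp.pi * v @ b_matrix.T
    return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1)

key = jax.random.PRNGKey(0)
num_features, input_dim, sigma = 128, 2, 10.0
# Entries of B are drawn from N(0, sigma^2); a larger sigma admits higher frequencies.
b_matrix = sigma * jax.random.normal(key, (num_features, input_dim))

# Encode a 4x4 grid of 2-D pixel coordinates in [0, 1).
grid = jnp.arange(4) / 4.0
coords = jnp.stack(jnp.meshgrid(grid, grid), axis=-1).reshape(-1, 2)
features = fourier_features(coords, b_matrix)
```

The resulting 256-dimensional features would replace the raw 2-D coordinates as the network input.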
A JAX-based framework for approximate inference in Markov Gaussian processes using iterated Kalman smoothing.
A JAX-based probabilistic programming framework using nested sampling for fast Bayesian inference and evidence computation.
Official JAX implementation of Mip-NeRF, a multiscale neural radiance field model for anti-aliased novel view synthesis.
Official repository for Big Transfer (BiT) models, providing pre-trained visual representations for efficient transfer learning across computer vision tasks.
JAX (Flax) implementations of reinforcement learning algorithms for continuous action spaces, designed for research.
A vision transformer architecture that aggregates nested local transformers on image blocks for better accuracy, data efficiency, and convergence.
Official JAX implementation of XMC-GAN for text-to-image generation using cross-modal contrastive learning.
Official JAX/Flax implementation of Vision Transformer (ViT) and MLP-Mixer for image recognition, with pre-trained models.
A JAX + Flax implementation of physics-inspired graph neural networks for solving combinatorial optimization problems like Max-Cut and Maximum Independent Set.
Open-source implementation of AlphaFold 2, a deep learning system for highly accurate protein structure prediction.
A collection of implementations and illustrative code accompanying DeepMind's published research papers across AI and machine learning.
A collection of implementations and illustrative code accompanying DeepMind's published research papers across AI and scientific domains.
A JAX library for building, training, and evaluating normalizing flows for probabilistic modeling.