A JAX library for neural networks and scientific computing with PyTorch-like syntax and full ecosystem compatibility.
Equinox is a JAX library that provides elegant, easy-to-use tools for building neural networks and doing scientific computing. It extends core JAX with PyTorch-like syntax for model definition, filtered transformation APIs, PyTree manipulation utilities, and advanced features such as runtime errors. Unlike some other JAX libraries, Equinox isn't a framework: everything you write remains fully compatible with the broader JAX ecosystem.
Suited to machine learning researchers and scientific computing practitioners who work with JAX and want PyTorch-like syntax without giving up JAX's composability and transformation model. It is particularly valuable for those who need fine-grained control over their models and transformations.
Equinox combines the familiar class-based model syntax of PyTorch with JAX's functional transformations and ecosystem compatibility. Because it is not a framework, users never get locked in, and features such as filtered transformations and PyTree utilities provide capabilities that simpler alternatives lack.
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Models are defined with familiar class-based syntax by subclassing eqx.Module, which eases the move from PyTorch to JAX while keeping full control, as in the Linear layer example from the Equinox README.
Provides filtered versions of JAX transformations (eqx.filter_jit, eqx.filter_grad, eqx.filter_vmap) that decide which parts of a PyTree are traced or differentiated and which are treated as static, so models containing non-array fields can be transformed without boilerplate.
Every Equinox model is just a PyTree, so it interoperates fully with any other JAX library or custom code; this avoids framework lock-in and is central to Equinox's design philosophy.
Supports runtime errors (for example eqx.error_if) that keep working under JAX transformations such as jit, which helps when debugging complex functional programs; the README lists this among its advanced features.
Does not provide built-in training loops, loss functions, or pre-trained models, requiring users to manually implement these or rely on external libraries like Optax, increasing initial setup effort.
While syntax is PyTorch-like, users must understand JAX's functional paradigm, PyTrees, and transformations to use Equinox effectively, which can be challenging without prior JAX experience.
Compared to full frameworks, Equinox has fewer out-of-the-box tools for common tasks like model checkpointing or distributed training, necessitating integration with other JAX libraries for a complete workflow.