A tutorial demonstrating how to extend JAX with custom C++ and CUDA operations for high-performance computing.
Extending JAX is a tutorial project with a step-by-step guide and an accompanying codebase for wrapping existing high-performance C++/CUDA libraries as JAX operations that support JIT compilation and automatic differentiation. Its running example is a specialized algorithm from astrophysics: solving Kepler's equation.
Researchers and developers working in scientific computing, particularly in fields like astrophysics or machine learning, who need to integrate optimized C++/CUDA code into JAX workflows.
It fills a gap in official documentation by providing a complete, working example of extending JAX with custom ops, including both CPU and GPU implementations, build tooling, and testing strategies.
Extending JAX with custom C++ and CUDA code
Provides a complete workflow from XLA custom call definitions to JAX primitive integration, with commented code in the repository covering both CPU and GPU implementations.
Uses Kepler's equation as a practical case study, demonstrating how to wrap an iterative solver and differentiate it implicitly within JAX's autodiff system.
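The implicit-differentiation rule follows from differentiating Kepler's equation, E - e sin E = M, where E is the eccentric anomaly, e the eccentricity and M the mean anomaly. This gives dE (1 - e cos E) = dM + sin E de, so dE/dM = 1 / (1 - e cos E) and dE/de = sin E / (1 - e cos E). A minimal pure-Python sketch of that idea follows. The Newton solver stands in for the tutorial's C++/CUDA kernel and is not the repository's code; the function names `solve_kepler` and `kepler_partials` are hypothetical.

```python
import math


def solve_kepler(mean_anom: float, ecc: float,
                 tol: float = 1e-12, max_iter: int = 50) -> float:
    """Return the eccentric anomaly E solving E - ecc*sin(E) = mean_anom."""
    if not 0.0 <= ecc < 1.0:
        raise ValueError("expected an elliptical orbit with 0 <= ecc < 1")
    # Wrap M into [0, 2*pi) so that pi is a sensible starting guess for high ecc.
    m = math.fmod(mean_anom, 2.0 * math.pi)
    if m < 0.0:
        m += 2.0 * math.pi
    offset = mean_anom - m  # whole orbits, added back at the end (sin is periodic)
    E = m if ecc < 0.8 else math.pi
    for _ in range(max_iter):
        # Newton step on f(E) = E - ecc*sin(E) - m, where f'(E) = 1 - ecc*cos(E) > 0
        step = (E - ecc * math.sin(E) - m) / (1.0 - ecc * math.cos(E))
        E -= step
        if abs(step) < tol:
            break
    return E + offset


def kepler_partials(E: float, ecc: float) -> tuple:
    """Implicit-function partials (dE/dM, dE/decc), evaluated at a converged E."""
    denom = 1.0 - ecc * math.cos(E)
    return 1.0 / denom, math.sin(E) / denom


# Usage: solve once, then evaluate the derivatives from the solution alone.
E = solve_kepler(1.0, 0.3)
dE_dM, dE_de = kepler_partials(E, 0.3)
```

The partials depend only on the converged E, not on the iteration history. In JAX, a rule like this would go into a JVP definition (for example `jax.custom_jvp`, or a primitive's JVP registration), so autodiff never has to trace through the solver loop.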
Shows how to write custom calls for both CPU and GPU, including how to return multiple outputs and how the XLA custom-call interfaces differ between the two targets, as seen in the separate CPU and GPU code examples.
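Here is a rough illustration of one such interface difference. In the legacy XLA custom-call API, a CPU handler receives an output pointer (a tuple of pointers when there are several outputs) and an array of input pointers. A GPU handler instead receives a CUDA stream, a flat buffer array, and an opaque byte string with its length, so scalar parameters for the kernel must be serialized into that byte string. The sketch shows that packing step in pure Python with `struct`, assuming a hypothetical descriptor that holds a single int64 element count. The tutorial does this on the C++ side, and its actual layout may differ.

```python
import struct

# Hypothetical descriptor layout: one little-endian int64 element count,
# mirroring a C++ struct such as { std::int64_t size; }.
_DESCRIPTOR_FORMAT = "<q"


def pack_descriptor(size: int) -> bytes:
    """Serialize op parameters into the opaque bytes handed to a GPU custom call."""
    return struct.pack(_DESCRIPTOR_FORMAT, size)


def unpack_descriptor(opaque: bytes) -> int:
    """Recover the parameters, as a CUDA-side handler would from (opaque, opaque_len)."""
    if len(opaque) != struct.calcsize(_DESCRIPTOR_FORMAT):
        raise ValueError("descriptor has unexpected length")
    (size,) = struct.unpack(_DESCRIPTOR_FORMAT, opaque)
    return size
```

Checking the length before unpacking mirrors the defensive check a C++ handler would make on `opaque_len` before reinterpreting the bytes as a struct.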
Includes ready-to-use pyproject.toml and CMakeLists.txt files with scikit-build-core, reducing setup time for building and distributing JAX extensions.
The README warns that the approach relies on undocumented JAX APIs that are likely to change, so the code may break with future JAX releases and require manual updates.
Requires proficiency in C++, CUDA, XLA, and several build tools, which makes it hard to follow for developers without systems programming experience.
Focuses solely on Kepler's equation, so developers must extrapolate to other kinds of ops, which can raise different challenges such as supporting multiple data types or writing more complex JVP rules.