A collection of GPU-accelerated parallel game simulators for reinforcement learning, built with JAX.
Pgx is a library of hardware-accelerated game simulators for reinforcement learning, built natively in JAX. It provides fast, parallelized environments for classic board games and other discrete games, enabling efficient training of RL agents at scale. The project addresses the need for high-performance simulators in discrete action spaces, complementing continuous-space engines like Brax.
Reinforcement learning researchers and practitioners who need fast, parallel environments for training agents on board games or discrete games. It's particularly useful for those using JAX-based ML stacks and seeking GPU/TPU acceleration.
Pgx offers significantly faster simulation speeds compared to CPU-based alternatives by leveraging JAX's parallelization capabilities. Its JAX-native design ensures seamless integration with JAX-based RL algorithms and automatic differentiation, while maintaining compatibility with popular RL APIs like PettingZoo.
♟️ Vectorized RL game environments in JAX
Uses JAX to run many environment instances in parallel on GPUs and TPUs, enabling high-throughput simulation for large-scale RL training, as the project's benchmark comparisons show.
All environment step functions are JIT-compilable and support automatic vectorization, ensuring seamless integration with JAX-based ML pipelines, as emphasized in the API description and quick start examples.
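To show what "JIT-compilable and automatically vectorizable" means in practice, the sketch below uses a toy environment. `State`, `init`, `step`, `TARGET`, and `MAX_STEPS` are hypothetical illustrations, not Pgx code. The functions are pure and operate only on JAX arrays, so `jax.vmap` can batch them over many environments and `jax.jit` can compile the batched version. The `jax.jit(jax.vmap(...))` composition is standard JAX; Pgx's real environment objects and state fields differ.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

TARGET = 5      # hypothetical goal position
MAX_STEPS = 10  # hypothetical episode length cap


class State(NamedTuple):
    # NamedTuples are JAX pytrees, so vmap adds a batch axis to every field.
    position: jax.Array
    steps: jax.Array
    terminated: jax.Array


def init(key: jax.Array) -> State:
    # Each environment gets its own key, hence its own random start in [0, TARGET).
    pos = jax.random.randint(key, (), 0, TARGET)
    return State(pos, jnp.int32(0), jnp.bool_(False))


def step(state: State, action: jax.Array) -> State:
    # Action 1 moves right, anything else moves left. jnp.where replaces
    # Python if/else so the function can be traced by jit and vmap.
    delta = jnp.where(action == 1, 1, -1)
    pos = jnp.where(state.terminated, state.position, state.position + delta)
    steps = state.steps + jnp.where(state.terminated, 0, 1)
    done = state.terminated | (pos == TARGET) | (steps >= MAX_STEPS)
    return State(pos, steps, done)


# Vectorize over a leading batch axis, then compile the batched function.
batch_init = jax.jit(jax.vmap(init))
batch_step = jax.jit(jax.vmap(step))

BATCH = 1024
keys = jax.random.split(jax.random.PRNGKey(0), BATCH)
state0 = batch_init(keys)
actions = jnp.ones(BATCH, dtype=jnp.int32)  # every environment moves right
state1 = batch_step(state0, actions)
print(state1.position[:8], state1.terminated[:8])
```

Because `step` never branches in Python on traced values, a single compiled function advances the whole batch at once. This is the property that lets simulators written this way scale on accelerators.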
Includes a wide range of discrete games, from classic board games like Chess and Go to puzzle games like 2048, providing a versatile testbed for RL research (see the supported games table).
Environments can be converted to the PettingZoo AEC API, allowing interoperability with existing RL libraries; the project's Colab notebooks demonstrate this export.
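To illustrate what an AEC-style interface looks like on top of a functional JAX game, the sketch below wraps a hypothetical two-player Nim game in an object that exposes one acting agent at a time through `agent_selection`, `last()`, and `step()`, mirroring the shape of PettingZoo's AEC loop. All names here (`NimState`, `nim_init`, `nim_step`, `AECStyleWrapper`) are illustrative. The sketch does not import PettingZoo and is not Pgx's exporter; the actual conversion is the one shown in the project's notebooks.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class NimState(NamedTuple):
    # Hypothetical game: players alternately remove 1 or 2 stones;
    # whoever takes the last stone wins.
    stones: jax.Array          # int32 scalar
    current_player: jax.Array  # int32 scalar, 0 or 1
    rewards: jax.Array         # float32[2], reward for each player this step
    terminated: jax.Array      # bool scalar


def nim_init(stones: int = 5) -> NimState:
    return NimState(jnp.int32(stones), jnp.int32(0),
                    jnp.zeros(2, jnp.float32), jnp.bool_(False))


@jax.jit
def nim_step(state: NimState, action: jax.Array) -> NimState:
    # Action 0 removes one stone, action 1 removes two (clipped at zero).
    remaining = jnp.maximum(state.stones - (action + 1), 0)
    done = remaining == 0
    # The mover gets +1 for taking the last stone; the opponent gets -1.
    mover = state.current_player
    win = jnp.where(mover == 0, jnp.array([1.0, -1.0]), jnp.array([-1.0, 1.0]))
    rewards = jnp.where(done, win, jnp.zeros(2))
    return NimState(remaining, 1 - mover, rewards, done)


class AECStyleWrapper:
    """Turn-based view of the functional game: one agent acts per step()."""

    def __init__(self, stones: int = 5):
        self.agents = ["player_0", "player_1"]
        self._stones = stones
        self.reset()

    def reset(self) -> None:
        self._state = nim_init(self._stones)

    @property
    def agent_selection(self) -> str:
        # The agent whose turn it is, derived from the functional state.
        return self.agents[int(self._state.current_player)]

    def last(self):
        # (observation, reward, termination, truncation, info) for the
        # currently selected agent, following the AEC tuple layout.
        idx = int(self._state.current_player)
        obs = int(self._state.stones)
        reward = float(self._state.rewards[idx])
        return obs, reward, bool(self._state.terminated), False, {}

    def step(self, action: int) -> None:
        # The wrapper holds the state; the jitted step stays pure.
        self._state = nim_step(self._state, action)


env = AECStyleWrapper(stones=5)
while not env.last()[2]:
    env.step(0)  # trivial policy: always remove one stone
print(env.agent_selection, env.last())
```

The design choice is that the JAX step function stays pure and batchable, while the wrapper adds the mutable, one-agent-at-a-time bookkeeping that AEC-style consumers expect.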
The README explicitly admits that some environments, including Go and Chess, do not perform well on TPUs, restricting hardware flexibility and requiring GPU use for optimal performance.
Requires hardware-specific installations of JAX and jaxlib, which can be cumbersome and error-prone for users not already invested in the JAX ecosystem, as the usage instructions note.
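Because a jaxlib build that does not match the hardware typically falls back to the CPU, a quick sanity check after installation is to ask JAX which backend and devices it sees. `jax.default_backend()` and `jax.devices()` are standard JAX calls; the GPU warning is an illustrative assumption, motivated by the README's note that Go and Chess run best on GPUs rather than TPUs.

```python
import jax

backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices for the default backend
print(f"backend={backend} devices={devices}")

if backend != "gpu":
    # Illustrative check: a CUDA-enabled jaxlib was expected but not detected,
    # so environments would run on this backend instead of a GPU.
    print("warning: no GPU backend detected; check the jax/jaxlib install")
```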
The transition from API v1 to v2 changed function signatures for stochastic environments, which affects reproducibility and requires code updates (see the versioning notes); this can disrupt existing projects.
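The reproducibility concern can be made concrete with a toy example. Assuming the signature change threads an explicit PRNG key through the step function of stochastic environments, such a step looks roughly like the hypothetical `stochastic_step` below, a 2048-flavored toy rather than Pgx's implementation. Identical inputs and key reproduce identical outputs, and code written against a key-free signature has to be updated to pass keys in.

```python
import jax
import jax.numpy as jnp


@jax.jit
def stochastic_step(board: jax.Array, action: jax.Array, key: jax.Array) -> jax.Array:
    """Hypothetical 2048-style step on a flat 16-cell board.

    All randomness comes from the caller-supplied `key`, so the same
    (board, action, key) always yields the same next board.
    """
    # Toy deterministic "move": shift the flat board one cell left or right.
    moved = jnp.where(action == 0, jnp.roll(board, -1), jnp.roll(board, 1))
    # Split the key so tile position and tile value use independent randomness.
    k_pos, k_val = jax.random.split(key)
    empty = moved == 0
    # Uniform choice among empty cells: occupied cells get -inf logits.
    pos = jax.random.categorical(k_pos, jnp.where(empty, 0.0, -jnp.inf))
    value = jnp.where(jax.random.uniform(k_val) < 0.9, 2, 4)
    # Place a tile only if some cell is empty; otherwise keep the board as is.
    return moved.at[pos].set(jnp.where(empty.any(), value, moved[pos]))


board = jnp.zeros(16, dtype=jnp.int32).at[0].set(2).at[3].set(4)
key = jax.random.PRNGKey(0)
next_a = stochastic_step(board, jnp.int32(0), key)
next_b = stochastic_step(board, jnp.int32(0), key)  # same key, same result
print(next_a.reshape(4, 4))
```

Batched rollouts would pass one key per environment, for example `jax.vmap(stochastic_step)(boards, actions, jax.random.split(key, n))`, which keeps every environment's randomness reproducible from a single seed.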