A vision transformer architecture that aggregates nested local transformers on image blocks for better accuracy, data efficiency, and convergence.
Nested Hierarchical Transformer (NesT) is a vision transformer architecture for image classification. It splits an image into non-overlapping blocks, runs a local transformer on each block, and aggregates the block outputs hierarchically. This addresses limitations of standard vision transformers by improving accuracy, data efficiency, and convergence, particularly on benchmarks like ImageNet. The method is also designed to work well on smaller datasets, where it aims to match the performance of convolutional neural networks.
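The block nesting described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the NesT implementation: the function names `blockify` and `hierarchy_block_counts` are hypothetical, and the default sizes (224-pixel image, 4-pixel patches, 14x14 patches per block, 2x2 block merging with the spatial map halved at each level) are assumptions chosen for the example.

```python
# Illustrative sketch only: function names and default sizes are assumptions,
# not taken from the NesT codebase.

def blockify(grid, block_size):
    """Split a square 2D grid (list of rows) into non-overlapping square blocks.

    Returns the blocks in row-major order. Each block is a flat list of the
    grid entries it covers, i.e. the sequence one local transformer would see.
    """
    side = len(grid)
    if any(len(row) != side for row in grid):
        raise ValueError("grid must be square")
    if side % block_size != 0:
        raise ValueError("grid side must be divisible by block_size")
    blocks = []
    for top in range(0, side, block_size):
        for left in range(0, side, block_size):
            blocks.append([
                grid[r][c]
                for r in range(top, top + block_size)
                for c in range(left, left + block_size)
            ])
    return blocks


def hierarchy_block_counts(image_size=224, patch_size=4, block_size=14):
    """Count the blocks at each level when every level merges 2x2 neighbours.

    Assumes the spatial map is halved at each level, so the block size
    (and therefore the local sequence length) stays fixed.
    """
    patches_per_side = image_size // patch_size
    blocks_per_side = patches_per_side // block_size
    counts = []
    while blocks_per_side >= 1:
        counts.append(blocks_per_side * blocks_per_side)
        if blocks_per_side == 1:
            break
        blocks_per_side //= 2
    return counts


if __name__ == "__main__":
    # A 4x4 grid of patch indices split into four 2x2 blocks.
    grid = [[r * 4 + c for c in range(4)] for r in range(4)]
    print(blockify(grid, 2))
    # Block counts from the bottom level to the top level.
    print(hierarchy_block_counts())
```

Under these assumed sizes the bottom level holds 16 blocks, the next 4, and the top 1, so self-attention is always computed within a block and information crosses block boundaries only at the aggregation steps.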
Researchers and practitioners in computer vision and deep learning who are working on image classification, vision transformer improvements, or efficient model architectures for limited data scenarios.
Developers choose NesT for its simple yet effective hierarchical design, which boosts vision transformer performance without complex modifications. Its pre-trained models and compatibility with frameworks like JAX and PyTorch (via timm) make it accessible for both research and practical applications.
Nested Hierarchical Transformer https://arxiv.org/pdf/2105.12723.pdf
Reaches high accuracy with less training data than standard vision transformers, on ImageNet as well as on smaller datasets, closing the gap with convolutional networks.
Converges faster and more stably during training; the pre-trained NesT-B model reaches 83.8% top-1 accuracy on ImageNet.
Designed to match convolutional neural network accuracy even with limited data, making it effective for research or applications where data is scarce.
Includes checkpoints for NesT-B, NesT-S, and NesT-T variants with reported ImageNet accuracies, facilitating quick evaluation and fine-tuning without full retraining.
Requires a TPU setup with specific IP addresses configured and the JAX backend for optimal performance; the codebase does not support multi-node GPU training beyond 8 GPUs, which limits scalability.
The primary implementation is in JAX, with PyTorch support only through third-party libraries like timm, which may lack full feature parity or official updates.
The repository explicitly states that it is 'not an officially supported Google product,' which comes with sparse documentation, minimal support for deployment, and potential breaking changes.