A DIY project to teach myself JAX. Currently a work in progress; the plan is to match the original BERT's performance with a clean architecture built from scratch and a low (personal) training budget.
- Training on English Wikipedia and Books3; the original BookCorpus is not a good dataset (see the papers below)
- Finetuning on GLUE tasks (MNLI currently)
- Tokenizer trained from scratch
Raw and finetuned checkpoints are available via Git LFS.
Current MNLI scores (m/mm): 80.7 / 80.0, as reported by the GLUE dashboard.
The training overview can be found in the training log.
- No NSP loss, as it does not boost performance (see papers below)
- Training for only 1 epoch, so no dropout and very little regularization in general
- Pre-LayerNorm transformer for training stability (see the block sketch after this list)
- No masking; only contiguous chunks of fixed length are used, which increases per-step efficiency and simplifies the code
- Gradient accumulation to fit larger effective batch sizes, with a linear batch-size schedule (rationale: gradient noise scale, see the papers below; see the optimizer sketch after this list)
- One Cycle learning rate schedule for efficient training under a constrained budget (included in the optimizer sketch below)
- Mixed precision with loss scaling (see the loss-scaling sketch below)
- Tokenizer trained with the HuggingFace WordPiece training algorithm, which is actually just BPE; the inference rules are still WordPiece's (see the tokenizer sketch below)
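The pre-LayerNorm block, roughly: a minimal Equinox sketch with hypothetical module and field names (the real dimensions, dropout handling, and attention masking live in the repo). Normalizing before each sub-layer, rather than after, is the stability argument from "On Layer Normalization in the Transformer Architecture".

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class PreLNBlock(eqx.Module):
    ln1: eqx.nn.LayerNorm
    ln2: eqx.nn.LayerNorm
    attn: eqx.nn.MultiheadAttention
    mlp_in: eqx.nn.Linear
    mlp_out: eqx.nn.Linear

    def __init__(self, d_model, n_heads, d_ff, *, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.ln1 = eqx.nn.LayerNorm(d_model)
        self.ln2 = eqx.nn.LayerNorm(d_model)
        self.attn = eqx.nn.MultiheadAttention(n_heads, d_model, key=k1)
        self.mlp_in = eqx.nn.Linear(d_model, d_ff, key=k2)
        self.mlp_out = eqx.nn.Linear(d_ff, d_model, key=k3)

    def __call__(self, x):  # x: [seq_len, d_model]
        # Pre-LN: normalize *before* each sub-layer, then add the residual.
        h = jax.vmap(self.ln1)(x)
        x = x + self.attn(h, h, h)
        h = jax.vmap(self.ln2)(x)
        x = x + jax.vmap(lambda t: self.mlp_out(jax.nn.gelu(self.mlp_in(t))))(h)
        return x
```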
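Gradient accumulation with a linear batch-size ramp and the one-cycle learning rate can be wired up in Optax roughly as below. All numbers are illustrative, not the project's actual settings.

```python
import jax.numpy as jnp
import optax

total_updates = 100_000           # optimizer updates, not micro-batch steps (illustrative)
start_accum, end_accum = 1, 16    # micro-batches accumulated per update (illustrative)


def accum_schedule(update_step):
    # Linearly ramp the number of accumulated micro-batches, so the effective
    # batch size grows over training (gradient noise scale rationale).
    frac = jnp.clip(update_step / total_updates, 0.0, 1.0)
    return jnp.round(start_accum + frac * (end_accum - start_accum)).astype(jnp.int32)


# One-cycle learning rate: warm up to a peak, then anneal back down.
lr_schedule = optax.cosine_onecycle_schedule(
    transition_steps=total_updates,
    peak_value=1e-3,
    pct_start=0.1,
)

base_opt = optax.adamw(learning_rate=lr_schedule, weight_decay=0.01)
# MultiSteps averages gradients over k micro-batches before applying an update,
# so effective batch size = micro-batch size * k.
optimizer = optax.MultiSteps(base_opt, every_k_schedule=accum_schedule)
```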
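A minimal sketch of loss scaling in plain JAX, assuming a static scale; the repo may handle this differently (e.g. dynamic scaling), and `compute_loss` here is a toy stand-in for the masked-LM loss.

```python
import jax
import jax.numpy as jnp

LOSS_SCALE = 2.0 ** 15  # keeps small half-precision gradients from flushing to zero


def compute_loss(params, batch):
    # Toy stand-in for the real masked-LM loss.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


def scaled_loss(params_f32, batch):
    # Run the forward pass in float16 while the optimizer keeps float32 masters.
    params_f16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params_f32)
    return compute_loss(params_f16, batch).astype(jnp.float32) * LOSS_SCALE


@jax.jit
def grad_step(params, batch):
    scaled_grads = jax.grad(scaled_loss)(params, batch)
    # Unscale back to true float32 gradients before the optimizer update.
    return jax.tree_util.tree_map(
        lambda g: g.astype(jnp.float32) / LOSS_SCALE, scaled_grads
    )
```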
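Training the tokenizer with the HuggingFace `tokenizers` library looks roughly like this; the vocabulary size and file paths are placeholders.

```python
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers

# WordPiece model for inference; the trainer's merge procedure is BPE-like.
tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

trainer = trainers.WordPieceTrainer(
    vocab_size=30_522,  # placeholder vocabulary size
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
)
tokenizer.train(files=["wiki.txt", "books.txt"], trainer=trainer)  # placeholder paths
tokenizer.save("tokenizer.json")
```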
JAX, Optax, Equinox. HuggingFace datasets and tokenizers, OmegaConf for config, Weights & Biases for monitoring, NumPy memmap for disk access, pytest for extensive unit testing.
- Attention Is All You Need
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- RoBERTa: A Robustly Optimized BERT Pretraining Approach
- How to Train BERT with an Academic Budget
- Cramming: Training a Language Model on a Single GPU in One Day
- On Layer Normalization in the Transformer Architecture
- Mixed Precision Training
- How AI training scales
- Scaling Laws for Neural Language Models
- The Pile: An 800GB Dataset of Diverse Text for Language Modeling
- What the BookCorpus?
- GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding
- A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference