# Minimal BERT in Jax

A DIY project to teach myself JAX. Currently a work in progress: the plan is to reach the original BERT's performance with a clean architecture built from scratch, on a low (personal) training budget.

- Pretraining on English Wikipedia and Books3; the original BookCorpus is not a good dataset (see below)
- Finetuning on GLUE tasks (currently MNLI)
- Tokenizer trained from scratch (see the sketch below)
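
As a rough illustration, a WordPiece tokenizer could be trained from scratch with the HuggingFace `tokenizers` library along these lines. The file names and vocabulary size below are placeholders, not the repo's actual settings:

```python
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers

# WordPiece model with the usual BERT unknown token.
tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

# BERT-style text cleanup and whitespace/punctuation splitting.
tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

trainer = trainers.WordPieceTrainer(
    vocab_size=30_522,  # placeholder; BERT-base uses a ~30k vocabulary
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
)

# Corpus files are placeholders for the Wikipedia/Books3 text dumps.
tokenizer.train(["wiki.txt", "books3.txt"], trainer)
tokenizer.save("tokenizer.json")
```

Note that the HuggingFace WordPiece *trainer* builds the vocabulary with a BPE-like merge procedure, while encoding at inference time still follows WordPiece's greedy longest-match rules (see the note under architectural changes).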

## Checkpoints & training logs

Pretrained (raw) and finetuned checkpoints are available via Git LFS.

Current MNLI scores (m/mm): 80.7 / 80.0, as reported by the GLUE dashboard.

The training overview can be found in the training log.

## Architectural changes

- No next-sentence-prediction (NSP) loss, as it does not improve performance (see papers below)
- Training for only one epoch, so no dropout and very little regularization in general
- Pre-LayerNorm transformer for training stability (a minimal block is sketched after this list)
- No padding or attention masks: only contiguous chunks of fixed length are used, which increases per-step efficiency and simplifies the code
- Gradient accumulation to fit larger effective batch sizes, with a linear batch-size schedule (rationale: gradient noise scale, see below)
- One-cycle learning rate schedule for efficient training under a constrained budget (both schedules are sketched with optax after this list)
- Mixed precision with loss scaling
- Tokenizer trained with the HuggingFace WordPiece training algorithm, which is effectively just BPE; the inference (encoding) rules are still WordPiece
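
For illustration, here is a minimal sketch of a pre-LayerNorm block in Equinox. The module name, sizes, and the absence of dropout are assumptions for the sketch, not the repo's actual implementation:

```python
import jax
import equinox as eqx

class PreLNBlock(eqx.Module):
    """Pre-LayerNorm transformer block: LayerNorm is applied *before*
    attention and the MLP, and the residual path stays un-normalized."""
    ln1: eqx.nn.LayerNorm
    ln2: eqx.nn.LayerNorm
    attn: eqx.nn.MultiheadAttention
    mlp_in: eqx.nn.Linear
    mlp_out: eqx.nn.Linear

    def __init__(self, d_model: int, n_heads: int, d_ff: int, *, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.ln1 = eqx.nn.LayerNorm(d_model)
        self.ln2 = eqx.nn.LayerNorm(d_model)
        self.attn = eqx.nn.MultiheadAttention(n_heads, d_model, key=k1)
        self.mlp_in = eqx.nn.Linear(d_model, d_ff, key=k2)
        self.mlp_out = eqx.nn.Linear(d_ff, d_model, key=k3)

    def __call__(self, x):  # x: (seq_len, d_model), no attention mask needed
        h = jax.vmap(self.ln1)(x)          # normalize first (pre-LN)
        x = x + self.attn(h, h, h)         # self-attention + residual
        h = jax.vmap(self.ln2)(x)
        h = jax.vmap(self.mlp_in)(h)
        h = jax.nn.gelu(h)
        x = x + jax.vmap(self.mlp_out)(h)  # MLP + residual
        return x
```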

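As a rough sketch of how the one-cycle schedule and the schedule-driven gradient accumulation could be wired together with optax (step counts, peak learning rate, and the accumulation ramp below are placeholders, not the values from the actual runs):

```python
import optax

total_steps = 100_000          # placeholder optimizer-step budget

# One-cycle learning rate: warm up to a peak, then anneal.
lr_schedule = optax.cosine_onecycle_schedule(
    transition_steps=total_steps,
    peak_value=1e-3,           # placeholder peak learning rate
    pct_start=0.1,             # fraction of steps spent warming up
)

# Linear batch-size schedule, implemented as a growing number of
# accumulation micro-steps per optimizer update.
def accumulation_schedule(step):
    # ramp from 1x to 8x accumulation over training (placeholder ramp)
    return 1 + (7 * step) // total_steps

optimizer = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=lr_schedule, weight_decay=0.01),
    ),
    every_k_schedule=accumulation_schedule,
)
```

`optax.MultiSteps` accumulates micro-batch gradients and only applies an update every `every_k_schedule(step)` micro-steps, so ramping that count up over training grows the effective batch size without increasing device memory use.
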
## Tech stack

Jax, optax, and Equinox; HuggingFace datasets and tokenizers; OmegaConf for config; Weights & Biases for monitoring; NumPy memmap for disk access (sketched below); pytest for extensive unit testing.
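
A rough sketch of the memmap-based data access, assuming pretokenized ids are stored as one flat binary file; the file name, dtype, and sequence length are placeholder assumptions:

```python
import numpy as np

SEQ_LEN = 128                      # placeholder fixed sequence length

# Flat array of token ids written ahead of time (file name is a placeholder).
tokens = np.memmap("train_tokens.bin", dtype=np.uint16, mode="r")
num_chunks = len(tokens) // SEQ_LEN

def get_batch(rng: np.random.Generator, batch_size: int) -> np.ndarray:
    """Sample a batch of contiguous, fixed-length chunks from disk."""
    starts = rng.integers(0, num_chunks, size=batch_size) * SEQ_LEN
    batch = np.stack([tokens[s : s + SEQ_LEN] for s in starts])
    return batch.astype(np.int32)  # cast for the model's embedding lookup
```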

## Main reference papers