An implementation of a GPT model in plain JAX/jax.numpy, trained on wine reviews
- Self-attention layer implemented from scratch in JAX
- Mixed-precision support
- Multi-GPU support (data parallel)
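A from-scratch self-attention layer in plain JAX might look like the following minimal sketch (single head, causal mask; all names and shapes here are illustrative, not the repo's actual code):

```python
import jax
import jax.numpy as jnp

def self_attention(x, w_qkv, w_out):
    """Single-head causal self-attention in plain JAX.

    x:     (seq_len, d_model) token embeddings
    w_qkv: (d_model, 3 * d_model) combined Q/K/V projection
    w_out: (d_model, d_model) output projection
    """
    seq_len, d_model = x.shape
    q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)
    # Scaled dot-product scores; the causal mask lets each position
    # attend only to itself and earlier positions.
    scores = (q @ k.T) / jnp.sqrt(d_model)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return (weights @ v) @ w_out

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x = jax.random.normal(k1, (8, 16))
w_qkv = jax.random.normal(k2, (16, 48)) * 0.02
w_out = jax.random.normal(k3, (16, 16)) * 0.02
out = self_attention(x, w_qkv, w_out)
print(out.shape)  # (8, 16)
```

A real multi-head version would reshape Q/K/V into per-head chunks and vmap over heads, but the masking and softmax logic stays the same.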
- Install JAX
- Download the dataset and unzip it into datasets/
All entry points for running the code are in scripts/:
train.py runs full-precision training on a single GPU
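The data-parallel multi-GPU path presumably shards each batch across devices; in JAX that is conventionally done with jax.pmap plus pmean-averaged gradients. A toy sketch of the pattern (the loss, learning rate, and shapes are illustrative, not the repo's code):

```python
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy "loss": mean squared activation of a linear layer.
    return jnp.mean((batch @ params) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across devices so every replica applies the
    # same update -- the core of data-parallel training.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 0.01 * grads

n = jax.local_device_count()
params = jnp.full((n, 4, 1), 0.5)  # parameters replicated per device
batch = jnp.ones((n, 8, 4))        # one batch shard per device
new_params = train_step(params, batch)
print(new_params.shape)
```

With a single device this degenerates to ordinary SGD; with multiple GPUs each device sees its own batch shard and the pmean keeps the replicas in sync.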