A from-scratch implementation of a PixArt/SD3-inspired Rectified Flow text-to-image model, trained with a staged curriculum (Stage 1 through Stage 5).
Related / earlier JAX+Equinox+Grain version: https://github.com/hachoj/PixelArt-Alpha-Equinox-Grain
- Rectified Flow training in latent space (SD3 VAE latents; $\alpha$-weighted Euler sampling with `torchdiffeq`)
- PixArt-style reparameterization utility (`reparameterize.py`) to improve efficiency and stage transitions
- FP8 training path (MXFP8 + NVIDIA Transformer Engine) with fused wgrad accumulation
- Staged curriculum with separate training entrypoints: `train_stage1.py` … `train_stage5.py`
- Hydra configuration system under `configs/` for reproducible runs
- High-throughput caption/sharding example in `data/shard_caption_example.py` (vLLM + WebDataset)
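The core training objective above can be sketched in a few lines. This is a minimal, generic rectified-flow sketch, not the repo's actual code: the model signature `model(x, t, cond)` and the plain (unweighted) loss are illustrative assumptions, and the repo's $\alpha$-weighted sampling is simplified here to a uniform-step Euler loop.

```python
import torch

def rf_loss(model, x0, cond):
    """Rectified-flow objective: regress the constant velocity (noise - x0)
    along the straight path x_t = (1 - t) * x0 + t * noise."""
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device).view(b, 1, 1, 1)  # uniform timesteps
    noise = torch.randn_like(x0)
    x_t = (1.0 - t) * x0 + t * noise                      # linear interpolant
    v_target = noise - x0                                 # d x_t / d t
    v_pred = model(x_t, t.view(b), cond)
    return torch.mean((v_pred - v_target) ** 2)

@torch.no_grad()
def euler_sample(model, shape, cond, steps=50, device="cpu"):
    """Integrate dx/dt = v(x, t) from t=1 (noise) down to t=0 (data)."""
    x = torch.randn(shape, device=device)
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        v = model(x, t_cur.expand(shape[0]), cond)
        x = x + (t_next - t_cur) * v                      # Euler step (dt < 0)
    return x
```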
Built from scratch:
- DiT/MMDiT blocks (`models/ccDiT`, `models/mmDiT`) + attention/MLP modules
- DDP training loops + EMA + logging hooks (`train_stage*.py`)
- Latent shard dataloaders (`data/data.py`)
- Qualitative sampling utilities + prompt controls (`app/`)
Used libraries for:
- Transformer Engine FP8 primitives (`transformer_engine`)
- Stable Diffusion 3 VAE (`diffusers`, see `models/vae.py`)
- ODE integration (`torchdiffeq`)
- Text encoders/tokenizers (`transformers`, defined via config)
- UI (`gradio`)
- Optional data-generation stack (`vllm`, `webdataset`, `Pillow`, `torchvision`)
A few representative generations and comparisons live in figures/:
This project assumes:
- Python: 3.10+
- CUDA + GPU: required for training. The FP8 path targets H100/H200/B200-class GPUs; MXFP8 specifically requires Blackwell (B200) or newer architectures.
There is no pinned requirements.txt because Transformer Engine and vLLM must match your CUDA/driver stack.
Minimal training + sampling deps (install torch/torchvision per your CUDA build):

```shell
pip install hydra-core omegaconf numpy einops jaxtyping torchdiffeq transformers diffusers wandb gradio
# Install NVIDIA Transformer Engine separately for your system (MXFP8 + te.Linear).
```

Optional for dataset creation:

```shell
pip install pillow torchvision webdataset vllm qwen-vl-utils
```

This repo does not include weights. The Gradio app expects a local checkpoint file via `--model-path`.
Checkpoint format: stage checkpoints contain:
- Model weights
- EMA model weights
- Optimizer state
app/gradio_app.py will load the appropriate weights internally (see the script for details).
For access to model weights, email chojnowski.h@ufl.edu.
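A loading sketch for such a checkpoint follows. The key names (`"model"`, `"ema"`, `"optimizer"`) are assumptions for illustration; check `app/gradio_app.py` for the actual keys.

```python
import torch

def load_stage_checkpoint(path, model, ema_model=None, optimizer=None):
    """Restore model weights, and optionally EMA weights and optimizer state,
    from a stage checkpoint saved as a dict of state_dicts."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    if ema_model is not None and "ema" in ckpt:
        ema_model.load_state_dict(ckpt["ema"])        # EMA weights for sampling
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])  # resume training
    return ckpt
```

For inference you generally want the EMA weights, which is why they are stored alongside the raw model weights.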
All train_stage*.py scripts are DDP-first and must be launched with torchrun (they read RANK/LOCAL_RANK/WORLD_SIZE).
Example (single node):

```shell
# Stage 1 (class-conditioned; ccDiT)
torchrun --standalone --nproc_per_node=8 train_stage1.py data=stage1 train=stage1 model=ccdit

# Stage 2–5 (text-conditioned; mmDiT)
torchrun --standalone --nproc_per_node=8 train_stage2.py data=stage2 train=stage2 model=mmdit
```

Notes:
- `configs/config.yaml` defaults to Stage 5; override `data=... train=... model=...` on the CLI.
- Data directories like `data/stage*/` are expected to exist locally and contain your latent shards; they are intentionally not tracked.
Run the Stage 5 sampler:

```shell
python app/gradio_app.py --model-path /path/to/stage5_checkpoint.pt
```

By default, the app launches locally. If you want a public share link, opt in explicitly:

```shell
python app/gradio_app.py --model-path /path/to/stage5_checkpoint.pt --share
```

data/shard_caption_example.py shows the structure used to:
- stream WebDataset tar shards
- caption with a vLLM-served VLM (generates both long and short captions)
- tokenize captions for the text encoder
- encode images into VAE latents
- save fixed-size `.pt` shards (latents + token ids + attention masks)
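The resulting shards can be read back with a minimal `Dataset`, sketched here under assumed key names (`"latents"`, `"input_ids"`, `"attention_mask"`); the repo's real loader lives in `data/data.py`.

```python
import torch
from torch.utils.data import Dataset

class LatentShardDataset(Dataset):
    """Minimal reader for fixed-size .pt shards of VAE latents plus
    pre-tokenized captions. Loads every shard eagerly for simplicity."""
    def __init__(self, shard_paths):
        self.samples = []
        for p in shard_paths:
            shard = torch.load(p, map_location="cpu")
            n = shard["latents"].shape[0]
            self.samples += [(shard, i) for i in range(n)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        shard, i = self.samples[idx]
        return (shard["latents"][i],
                shard["input_ids"][i],
                shard["attention_mask"][i])
```

A real loader would memory-map or lazily open shards rather than hold them all in RAM.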
You must provide your own dataset URL template:

```shell
python data/shard_caption_example.py \
    --base-url 'https://your-host/path/train_{i:04d}.tar' \
    --num-shards 100 \
    --out-dir data/stageX/example_dataset
```

Repository layout:
- `models/` — model components (attention, MLP, DiT/MMDiT, VAE wrapper)
- `reparameterize.py` — reparameterization logic
- `train_stage*.py` — stage-specific training entry points
- `configs/` — Hydra configs (train/data/model/optim/vae/wandb)
- `data/` — datasets, sharding scripts, stage manifests
- `app/` — Gradio sampler + prompt utilities
- `scripts/` — SLURM scripts and helpers (sharding + training)
- `figures/` — sample grids used in this README
BibTeX lives in bib.bib. Primary starting points:
- Chen et al., “PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis”, 2023 (arXiv:2310.00426).
- Esser et al., “Scaling Rectified Flow Transformers for High-Resolution Image Synthesis”, ICML 2024.
See bib.bib for the full list used during development.
MIT. See LICENSE.

.png)
.png)
.png)

.png)

.png)
.png)