hachoj/Pixart-Alpha-From-Scratch

PixArt/SD3-style Rectified Flow from scratch

A from-scratch implementation of a PixArt/SD3-inspired Rectified Flow text-to-image model, with a staged curriculum (Stage 1 $\rightarrow$ Stage 5), PixArt-style efficient reparameterization, and an FP8-first training stack (MXFP8 + Transformer Engine). Includes an interactive Gradio sampler for qualitative inspection and checkpoint verification.

Related / earlier JAX+Equinox+Grain version: https://github.com/hachoj/PixelArt-Alpha-Equinox-Grain


Highlights

  • Rectified Flow training in latent space (SD3 VAE latents; $\alpha$-weighted Euler sampling with torchdiffeq)
  • PixArt-style reparameterization utility (reparameterize.py) to improve efficiency and stage transitions
  • FP8 training path (MXFP8 + NVIDIA Transformer Engine) with fused wgrad accumulation
  • Staged curriculum with separate training entrypoints: train_stage1.py through train_stage5.py
  • Hydra configuration system under configs/ for reproducible runs
  • High-throughput caption/sharding example in data/shard_caption_example.py (vLLM + WebDataset)
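
The rectified-flow objective behind the first highlight can be sketched in a few lines. This is a minimal illustration, not the repo's actual training loop; the model signature and latent shapes are assumptions:

```python
import torch

def rectified_flow_loss(model, x0, text_cond):
    """One rectified-flow training step on latents (illustrative sketch).

    x0: clean VAE latents, shape (B, C, H, W).
    model(x_t, t, cond) is assumed to predict velocity.
    """
    b = x0.shape[0]
    x1 = torch.randn_like(x0)                 # noise endpoint
    t = torch.rand(b, device=x0.device)       # uniform timesteps in [0, 1]
    t_ = t.view(b, 1, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * x1           # linear interpolation path
    v_target = x1 - x0                        # constant velocity along the path
    v_pred = model(x_t, t, text_cond)
    return torch.mean((v_pred - v_target) ** 2)
```

The straight-line path is what makes the model amenable to few-step Euler sampling at inference time.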

Built from scratch:

  • DiT/MMDiT blocks (models/ccDiT, models/mmDiT) + attention/MLP modules
  • DDP training loops + EMA + logging hooks (train_stage*.py)
  • Latent shard dataloaders (data/data.py)
  • Qualitative sampling utilities + prompt controls (app/)

Libraries used for:

  • Transformer Engine FP8 primitives (transformer_engine)
  • Stable Diffusion 3 VAE (diffusers, see models/vae.py)
  • ODE integration (torchdiffeq)
  • Text encoders/tokenizers (transformers, defined via config)
  • UI (gradio)
  • Optional data-generation stack (vllm, webdataset, Pillow, torchvision)

Results

A few representative generations and comparisons live in figures/:


Quickstart

1) Environment

This project assumes:

  • Python: 3.10+
  • CUDA + GPU: required for training. The FP8 path targets H100/H200/B200-class GPUs; note that the MXFP8 recipe requires Blackwell (B200) or newer architectures, while Hopper (H100/H200) supports standard FP8 only.

2) Install

There is no pinned requirements.txt because Transformer Engine and vLLM must match your CUDA/driver stack.

Minimal training + sampling deps (install torch/torchvision per your CUDA build):

pip install hydra-core omegaconf numpy einops jaxtyping torchdiffeq transformers diffusers wandb gradio
# Install NVIDIA Transformer Engine separately for your system (MXFP8 + te.Linear).

Optional for dataset creation:

pip install pillow torchvision webdataset vllm qwen-vl-utils

Weights / checkpoints

This repo does not include weights. The Gradio app expects a local checkpoint file via --model-path.

Checkpoint format

Stage checkpoints contain:

  • Model weights
  • EMA model weights
  • Optimizer state

app/gradio_app.py will load the appropriate weights internally (see the script for details).
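
A checkpoint with this layout might be loaded like so. The key names ("model", "ema", "optimizer") are assumptions for illustration; check app/gradio_app.py for the actual keys:

```python
import torch

def load_stage_checkpoint(path, use_ema=True):
    """Load a stage checkpoint and pick EMA or raw weights.

    Key names here are hypothetical; inspect your file with ckpt.keys().
    """
    ckpt = torch.load(path, map_location="cpu")
    state_dict = ckpt["ema"] if use_ema and "ema" in ckpt else ckpt["model"]
    return state_dict, ckpt.get("optimizer")
```

EMA weights are typically the right choice for sampling; the raw weights and optimizer state matter only when resuming training.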

For access to model weights, email chojnowski.h@ufl.edu.


Training (Stages 1–5)

All train_stage*.py scripts are DDP-first and must be launched with torchrun (they read RANK/LOCAL_RANK/WORLD_SIZE).

Example (single node):

# Stage 1 (class-conditioned; ccDiT)
torchrun --standalone --nproc_per_node=8 train_stage1.py data=stage1 train=stage1 model=ccdit

# Stage 2–5 (text-conditioned; mmDiT)
torchrun --standalone --nproc_per_node=8 train_stage2.py data=stage2 train=stage2 model=mmdit

Notes:

  • configs/config.yaml defaults to Stage 5; override data=... train=... model=... on the CLI.
  • Data directories like data/stage*/ are expected to exist locally and contain your latent shards; they are intentionally not tracked.
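
The torchrun contract mentioned above boils down to reading a few environment variables at startup. A minimal sketch of that handshake (the actual train_stage*.py scripts may structure this differently):

```python
import os
import torch

def ddp_env():
    """Read the per-worker environment that torchrun sets."""
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size

def ddp_setup():
    """Initialize the process group and pin this worker to its GPU."""
    rank, local_rank, world_size = ddp_env()
    if world_size > 1:
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
```

Running a script like this directly with `python` (rather than `torchrun`) leaves the variables unset, which is why the training scripts require the `torchrun` launcher.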

Sampling / Demo (Gradio)

Run the Stage 5 sampler:

python app/gradio_app.py --model-path /path/to/stage5_checkpoint.pt

By default, the app launches locally. If you want a public share link, opt in explicitly:

python app/gradio_app.py --model-path /path/to/stage5_checkpoint.pt --share
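
Under the hood, sampling integrates the learned velocity field from noise to data. The repo uses torchdiffeq's odeint; a plain fixed-step Euler sketch of the same idea, with an assumed model signature:

```python
import torch

@torch.no_grad()
def euler_sample(model, cond, shape, steps=50, device="cpu"):
    """Integrate dx/dt = v(x, t, cond) from t=1 (noise) to t=0 (data)."""
    x = torch.randn(shape, device=device)         # start from pure noise at t=1
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t = ts[i].expand(shape[0])                # broadcast timestep to the batch
        dt = ts[i + 1] - ts[i]                    # negative step: integrating backward
        x = x + dt * model(x, t, cond)
    return x
```

Because rectified flow learns near-straight trajectories, even this naive integrator produces usable samples with relatively few steps.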

Data generation example

data/shard_caption_example.py shows the structure used to:

  • stream WebDataset tar shards
  • caption with a vLLM-served VLM (generates both long and short captions)
  • tokenize captions for the text encoder
  • encode images into VAE latents
  • save fixed-size .pt shards (latents + token ids + attention masks)
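
The fixed-size shard format in the last bullet can be pictured as a dict of stacked tensors. Field names here are illustrative; see data/shard_caption_example.py for the actual schema:

```python
import torch

def save_shard(path, latents, token_ids, attention_masks):
    """Write one fixed-size .pt shard of precomputed training examples.

    latents:         (N, C, H, W) float tensor of VAE latents
    token_ids:       (N, L) long tensor of tokenized captions
    attention_masks: (N, L) long/bool tensor for the text encoder
    """
    assert latents.shape[0] == token_ids.shape[0] == attention_masks.shape[0]
    torch.save(
        {"latents": latents, "token_ids": token_ids, "attention_mask": attention_masks},
        path,
    )
```

Precomputing latents and token ids this way keeps the VAE and tokenizer out of the training loop entirely.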

You must provide your own dataset URL template:

python data/shard_caption_example.py \
  --base-url 'https://your-host/path/train_{i:04d}.tar' \
  --num-shards 100 \
  --out-dir data/stageX/example_dataset

Repository layout

  • models/ — model components (attention, MLP, DiT/MMDiT, VAE wrapper)
  • reparameterize.py — reparameterization logic
  • train_stage*.py — stage-specific training entry points
  • configs/ — Hydra configs (train/data/model/optim/vae/wandb)
  • data/ — datasets, sharding scripts, stage manifests
  • app/ — Gradio sampler + prompt utilities
  • scripts/ — SLURM scripts and helpers (sharding + training)
  • figures/ — sample grids used in this README

References

BibTeX lives in bib.bib. Primary starting points:

  • Chen et al., “PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis”, 2023 (arXiv:2310.00426).
  • Esser et al., “Scaling Rectified Flow Transformers for High-Resolution Image Synthesis”, ICML 2024.

See bib.bib for the full list used during development.


License

MIT. See LICENSE.
