A PyTorch-based computer vision project implementing both Vision Transformers (ViT) and Convolutional Neural Networks (CNNs) for image classification on the CIFAR-10 dataset.
Training progress visualization showing the convergence of our Vision Transformer over 10 epochs
Visualization of learned attention maps after training
- Analysis
- Code
- License
- Acknowledgments
We trained a baseline convolutional neural network (CNN) and several Vision Transformers (ViTs) with varying patch sizes on the CIFAR-10 dataset. All models were trained from scratch using the same optimizer, learning rate schedule, and number of epochs to ensure comparability.
| Model | Test Loss | Test Accuracy |
|---|---|---|
| CNN | 0.962 | 67.0% |
| ViT (patch=2) | 1.272 | 63.1% |
| ViT (patch=4) | 1.394 | 60.1% |
| ViT (patch=6) | 1.675 | 56.6% |
| ViT (patch=8) | 1.872 | 55.3% |
We also computed bootstrapped accuracy differences relative to the CNN baseline (see `figures/model_comparison_bootstrap.png`).
The CNN consistently outperforms all ViTs, and performance decreases steadily with larger patch sizes.
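For reference, here is a minimal sketch of how such a bootstrap comparison can be computed, assuming boolean per-example correctness arrays for two models (the actual analysis lives in `demos/model_comparison.ipynb` and may differ):

```python
import numpy as np

def bootstrap_acc_diff(correct_a, correct_b, n_boot=10_000, seed=0):
    """Bootstrap the test-accuracy difference between two models.

    correct_a, correct_b: boolean arrays with one entry per test example,
    indicating whether each model classified that example correctly.
    """
    rng = np.random.default_rng(seed)
    n = len(correct_a)
    diffs = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)      # resample the test set with replacement
        diffs[i] = correct_a[idx].mean() - correct_b[idx].mean()
    return diffs.mean(), np.percentile(diffs, [2.5, 97.5])  # mean diff and 95% CI
```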
The curves highlight two trends:
- The CNN converges faster and reaches higher validation accuracy.
- Smaller patch sizes (e.g., 2×2) help ViTs retain more local detail and improve performance relative to larger patches, but they still lag behind CNNs.
- Inductive biases matter: On small datasets like CIFAR-10, CNNs have strong locality and translation equivariance built in, giving them an advantage over ViTs.
- ViTs need more data or pretraining: In low-data regimes, ViTs trained from scratch underperform CNNs, but with sufficient scale or augmentations they can surpass CNNs.
- Patch size tradeoff: Smaller patches preserve detail and perform better, while larger patches discard local structure and hurt accuracy.
In short: on CIFAR-10, a simple CNN beats a ViT trained from scratch, but the ViT results align with known trends from the literature.
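The patch-size tradeoff is easy to quantify: smaller patches mean more tokens per image, and self-attention cost grows quadratically with token count. A quick check (note that 6 does not evenly divide 32, so a stride-6 patch convolution drops border pixels):

```python
# Tokens per image for each patch size on 32x32 inputs (stride = patch size)
for p in [2, 4, 6, 8]:
    grid = (32 - p) // p + 1     # output side length of Conv2d(kernel=p, stride=p)
    tokens = grid * grid         # patches per image (plus 1 CLS token)
    print(f"patch={p}: {grid}x{grid} grid -> {tokens} tokens")
# patch=2: 16x16 grid -> 256 tokens
# patch=4: 8x8 grid -> 64 tokens
# patch=6: 5x5 grid -> 25 tokens
# patch=8: 4x4 grid -> 16 tokens
```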
This project provides a complete training pipeline for image classification using PyTorch, featuring:
- CNN Architecture: A custom convolutional neural network designed for CIFAR-10
- Vision Transformer (ViT): A fully functional transformer-based architecture with attention mechanisms
- Training Framework: A robust `Runner` class that handles training, validation, and testing
- Data Management: Automatic CIFAR-10 dataset download and preprocessing
- Visualization: Built-in plotting capabilities with animated training progress GIFs and attention map visualization
- Model Persistence: Save and load functionality for trained models
- Interactive Demos: Jupyter notebooks for tutorials and attention visualization
```
vision-transformer/
├── network.py                         # Neural network architectures (CNN, FeedForward, ViT)
├── runner.py                          # Training orchestration and metrics tracking
├── train.py                           # Main training script with command-line arguments
├── utils.py                           # Utility functions for GIF generation
├── requirements.txt                   # Python dependencies
├── .gitignore                         # Git ignore rules
├── data/                              # CIFAR-10 dataset storage
│   └── cifar-10-batches-py/           # CIFAR-10 data files
├── demos/                             # Interactive tutorials and demonstrations
│   ├── vit_tutorial.ipynb             # Vision Transformer code explanation
│   ├── attention_maps.ipynb           # Attention visualization tutorial
│   └── model_comparison.ipynb         # Model performance comparison analysis
├── figures/                           # Generated plots and visualizations
│   ├── attention_maps.png             # Attention map visualization
│   ├── model_comparison_bootstrap.png # Bootstrap analysis results
│   └── model_comparison_curves.png    # Training curves comparison
├── saved/                             # Trained model checkpoints
│   ├── cnn.pt/.pkl                    # CNN model files
│   └── vit_*.pt/.pkl                  # ViT models with various configurations
├── animation_vit.gif                  # Training progress animation
└── .venv/                             # Virtual environment (excluded from git)
```
- Neural Network Architectures: CNN, FeedForward, and Vision Transformer (ViT) implementations
- Training Framework: Complete training pipeline with metrics tracking and visualization
- Data Management: Automatic CIFAR-10 dataset download and preprocessing
- Visualization: Training progress animations and attention map visualization
- Interactive Demos: Jupyter notebooks for tutorials and model exploration
- Python 3.8+
- PyTorch 2.0+
- Apple Silicon Mac (for MPS acceleration) or CUDA-capable GPU
- Jupyter Notebook (for interactive demos)
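Either accelerator is optional. A minimal sketch of the standard PyTorch pattern for picking the fastest available backend (the exact logic in `train.py` may differ):

```python
import torch

# Prefer Apple Silicon (MPS), then CUDA, then fall back to CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
```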
1. Clone the repository

   ```bash
   git clone <repository-url>
   cd vision-transformer
   ```

2. Create and activate a virtual environment

   ```bash
   python -m venv .venv
   source .venv/bin/activate   # On macOS/Linux
   # .venv\Scripts\activate    # On Windows
   ```

3. Install dependencies

   ```bash
   pip install -r requirements.txt
   ```

4. Install Jupyter for interactive demos

   ```bash
   pip install jupyter
   ```
Run the main training script with command-line arguments:
```bash
# Train CNN (default)
python train.py --model cnn --epochs 20

# Train Vision Transformer with custom parameters
python train.py --model vit --epochs 20 --hidden_size 64 --num_heads 3 --num_blocks 10 --patch_size 2

# Train MLP/FeedForward network
python train.py --model mlp --epochs 20
```

Available Arguments:

- `--model`: Architecture choice (`cnn`, `vit`, or `mlp`)
- `--epochs`: Number of training epochs (default: 20)
- `--lr`: Learning rate (default: 0.0005)
- `--batch_size`: Batch size (default: 128)
- `--hidden_size`: Hidden dimension for ViT (default: 64)
- `--num_heads`: Number of attention heads for ViT (default: 3)
- `--num_blocks`: Number of transformer blocks for ViT (default: 10)
- `--patch_size`: Patch size for ViT (default: 2)
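For reference, a minimal `argparse` sketch consistent with the flags above (the actual `train.py` may structure this differently):

```python
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on CIFAR-10")
    parser.add_argument("--model", choices=["cnn", "vit", "mlp"], default="cnn")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.0005)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--num_heads", type=int, default=3)
    parser.add_argument("--num_blocks", type=int, default=10)
    parser.add_argument("--patch_size", type=int, default=2)
    return parser.parse_args()
```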
Running the training script will:
- Download CIFAR-10 dataset (if not already present)
- Train the specified model for the given number of epochs
- Display training progress with live metrics
- Generate training plots and animated GIF
- Save the trained model with descriptive filename
- Report final test performance
Start Jupyter Notebook server:
```bash
cd demos
jupyter notebook
```

Available Tutorials:
- ViT Tutorial: Learn about Vision Transformer architecture and components
- Attention Maps: Visualize attention weights and understand model decisions
Modify `train.py` or use command-line arguments to:

- Change the number of epochs (default: 20)
- Use different models (CNN, FeedForward, or VisionTransformer)
- Adjust hyperparameters (learning rate, batch size, model architecture)
- Change the validation split ratio (currently an 80/20 train/val split; see the sketch after this list)
- Customize the ViT architecture (hidden size, attention heads, transformer blocks, patch size)
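A minimal sketch of the 80/20 split using `torch.utils.data.random_split` (the exact code in `train.py` may differ):

```python
import torch
from torch.utils.data import random_split
from torchvision import datasets, transforms

transform = transforms.ToTensor()
full_train = datasets.CIFAR10("data", train=True, download=True, transform=transform)

# 80/20 train/validation split; adjust the fraction to change the ratio
n_train = int(0.8 * len(full_train))
train_set, val_set = random_split(
    full_train, [n_train, len(full_train) - n_train],
    generator=torch.Generator().manual_seed(0),  # reproducible split
)
```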
To train with the Vision Transformer architecture:
```python
from network import VisionTransformer
from runner import Runner  # training framework defined in runner.py

# Create ViT model with custom parameters
model = VisionTransformer(
    img_size=32,      # CIFAR-10 image size
    hidden_size=64,   # Embedding dimension
    output_size=10,   # Number of classes
    num_heads=3,      # Number of attention heads
    num_blocks=10,    # Number of transformer blocks
    patch_size=2      # Patch size for image division
)

# Use with the existing training pipeline
# (optimizer, criterion, device, and data loaders set up as in train.py)
runner = Runner(model, optimizer, criterion, device)
runner.train(train_loader, val_loader, epochs=20)
```

After training a model, use the attention maps tutorial:
```python
import torch
from network import VisionTransformer

# Load trained model
model = VisionTransformer.load("saved/VisionTransformer.pt")

# Extract attention weights
with torch.no_grad():
    # Forward pass to get attention weights
    output = model(images, return_attention=True)
    attention_weights = output['attention_weights']

# Visualize attention maps
visualize_attention_maps(images, attention_weights)
```
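The `visualize_attention_maps` helper is presumably defined in `demos/attention_maps.ipynb`; here is a minimal sketch of one way to implement it, assuming `attention_weights` has shape `(batch, heads, tokens, tokens)` with the CLS token at index 0:

```python
import matplotlib.pyplot as plt
import torch.nn.functional as F

def visualize_attention_maps(images, attention_weights, grid=16):
    """Overlay mean CLS-token attention on each image (illustrative helper).

    Assumes attention_weights has shape (batch, heads, tokens, tokens),
    with the CLS token at index 0 and a 16x16 patch grid (patch_size=2).
    """
    # Attention paid by the CLS token to each image patch, averaged over heads
    cls_attn = attention_weights[:, :, 0, 1:].mean(dim=1)   # (batch, 256)
    maps = cls_attn.reshape(-1, 1, grid, grid)              # (batch, 1, 16, 16)
    maps = F.interpolate(maps, size=images.shape[-2:], mode="bilinear")
    for img, attn in zip(images, maps):
        plt.imshow(img.permute(1, 2, 0).cpu().numpy())      # assumes floats in [0, 1]
        plt.imshow(attn.squeeze().cpu().numpy(), alpha=0.5, cmap="viridis")
        plt.axis("off")
        plt.show()
```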
CNN Architecture Details:

```
Input: 3x32x32 (RGB image)
├── Conv2d(3→16, kernel=3x3, padding=1) + ReLU
├── MaxPool2d(2x2)
├── Conv2d(16→32, kernel=3x3, padding=1) + ReLU
├── MaxPool2d(2x2)
├── Flatten: 32×8×8 → 2048
└── Linear(2048→10) → Output
```
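A minimal PyTorch sketch matching this diagram (the actual implementation in `network.py` may differ in details):

```python
import torch.nn as nn

cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(),   # 3x32x32 -> 16x32x32
    nn.MaxPool2d(2),                                          # -> 16x16x16
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),  # -> 32x16x16
    nn.MaxPool2d(2),                                          # -> 32x8x8
    nn.Flatten(),                                             # -> 2048
    nn.Linear(32 * 8 * 8, 10),                                # -> 10 class logits
)
```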
Vision Transformer (ViT) Architecture Details:

```
Input: 3x32x32 (RGB image)
├── Patch Embedding:
│   ├── Conv2d(3→64, kernel=2x2, stride=2) → 16×16 patch grid (256 patches total)
│   ├── Flatten patches → 256 patches × 64 dimensions
│   ├── Add CLS token → 257 tokens × 64 dimensions
│   └── Add positional embeddings
├── Transformer Blocks (10 blocks):
│   ├── LayerNorm + Multi-Head Attention (3 heads)
│   ├── Residual connection
│   ├── LayerNorm + Feed-Forward Network (GELU activation)
│   └── Residual connection
├── Extract CLS token representation
└── Linear(64→10) → Output
```
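A minimal sketch of the patch-embedding step under the defaults above (variable names are illustrative; `network.py` may differ):

```python
import torch
import torch.nn as nn

B, hidden = 8, 64
x = torch.randn(B, 3, 32, 32)                      # batch of CIFAR-10-sized images

patch_embed = nn.Conv2d(3, hidden, kernel_size=2, stride=2)
cls_token = nn.Parameter(torch.zeros(1, 1, hidden))
pos_embed = nn.Parameter(torch.zeros(1, 257, hidden))

tokens = patch_embed(x)                            # (B, 64, 16, 16)
tokens = tokens.flatten(2).transpose(1, 2)         # (B, 256, 64): one token per patch
tokens = torch.cat([cls_token.expand(B, -1, -1), tokens], dim=1)  # (B, 257, 64)
tokens = tokens + pos_embed                        # add learnable positional information
```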
Key ViT Components:
- Patch Embedding: Divides 32×32 images into 2×2 patches (256 patches total)
- CLS Token: Learnable classification token prepended to patch sequence
- Positional Embeddings: Learnable positional information for each patch + CLS token
- Multi-Head Attention: 3 attention heads for different feature aspects
- Transformer Blocks: Stack of 10 self-attention and feed-forward layers with residual connections
- Layer Normalization: Stabilizes training and improves convergence
- GELU Activation: Smooth activation function used in modern transformers
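A minimal pre-norm transformer block consistent with the description above. This is illustrative only: `nn.MultiheadAttention` requires the embedding dimension to be divisible by the head count, so the sketch uses 4 heads where the project's default is 3 (its custom attention may handle this differently), and the 4× feed-forward expansion is an assumption:

```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, hidden=64, heads=4, mlp_ratio=4):  # heads/mlp_ratio assumed
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden)
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden)
        self.ffn = nn.Sequential(
            nn.Linear(hidden, mlp_ratio * hidden),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden, hidden),
        )

    def forward(self, x):
        # LayerNorm + multi-head self-attention, then residual connection
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # LayerNorm + feed-forward network (GELU), then residual connection
        return x + self.ffn(self.norm2(x))
```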
- PyTorch (≥2.0.0): Deep learning framework
- TorchVision (≥0.15.0): Computer vision utilities
- Matplotlib (≥3.5.0): Plotting and visualization
- NumPy (≥1.21.0): Numerical computing
- Pillow (≥8.0.0): Image processing
- Jupyter: Interactive notebook environment
- tqdm: Progress bars (installed separately)
This project is licensed under the MIT License - see below for details:
MIT License
Copyright (c) 2024 Jordan Lei
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
- CIFAR-10 dataset creators
- PyTorch development team
- Apple Silicon MPS support
- Jupyter project contributors
Last Updated: December 2024 · Status: Active Development · Version: 1.1.0



