DeepCIFAR is a deep learning project for image classification using the CIFAR-10 dataset.
It leverages PyTorch and transfer learning with a pretrained ResNet18 to classify images into 10 categories:
airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
This project demonstrates a complete end-to-end pipeline:
- Data preprocessing and augmentation
- Training, validation, and test evaluation
- Fine-tuning a deep CNN (ResNet18)
- Visualization of results (training curves, confusion matrix, sample predictions)
- Inference on custom images
Dataset:
- CIFAR-10: 60,000 color images (32x32 pixels) across 10 classes
- Split: 50,000 training + 10,000 test images
- During training, images were resized to 128–160 px and preprocessed with:
  - Random crop & flip
  - Color jitter
  - Normalization
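A minimal sketch of what the crop, flip, and normalization steps do to pixel data, written in plain Python so it runs without PyTorch. In a torchvision pipeline these steps would typically be transforms such as `RandomCrop` or `RandomResizedCrop`, `RandomHorizontalFlip`, `ColorJitter`, and `Normalize`; which exact transforms this project uses is not stated here. The ImageNet mean/std values, the helper names, and the toy 8x8 image are assumptions for illustration, and color jitter is left out.

```python
import random

# Assumed normalization statistics: the README only says "Normalization",
# so the standard ImageNet mean/std used with pretrained ResNet18 are taken here.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def random_crop(image, size, rng=random):
    """Cut a random size x size window out of an H x W image (a list of pixel rows)."""
    h, w = len(image), len(image[0])
    top = rng.randint(0, h - size)
    left = rng.randint(0, w - size)
    return [row[left:left + size] for row in image[top:top + size]]


def random_hflip(image, p=0.5, rng=random):
    """Mirror the image left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


def normalize(image):
    """Scale 0-255 RGB pixels to [0, 1], then standardize each channel."""
    return [
        [tuple((v / 255.0 - m) / s for v, m, s in zip(px, IMAGENET_MEAN, IMAGENET_STD))
         for px in row]
        for row in image
    ]


def augment(image, crop_size, rng=random):
    """Hypothetical training-time pipeline: crop -> flip -> normalize (no color jitter)."""
    return normalize(random_hflip(random_crop(image, crop_size, rng), rng=rng))


# Toy 8x8 RGB image; a seeded RNG makes the random crop/flip reproducible.
rng = random.Random(0)
img = [[(r * 10, c * 10, 128) for c in range(8)] for r in range(8)]
out = augment(img, 6, rng)
print(len(out), len(out[0]), out[0][0])
```

Random crop and flip only run at training time; at evaluation time the usual practice is to apply just the resize and normalization so predictions are deterministic.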
Model:
- ResNet18 pretrained on ImageNet
- Last fully-connected layer replaced with a 10-class output layer
- Two training modes:
  - Frozen backbone (only the last layer is trained; faster, a good baseline)
  - Fine-tuning the entire network (higher accuracy)
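To see why the frozen mode is so much cheaper, the sketch below counts trainable parameters in each mode. It tallies ResNet18's weights from its standard layout (a 7x7 stem convolution, then four stages of two BasicBlocks with 64/128/256/512 channels, ending in 512 pooled features) and adds a 512-to-10 linear head. It assumes freezing covers every layer except the new head. In PyTorch this is commonly done by setting `requires_grad = False` on the backbone parameters and replacing `model.fc` with `nn.Linear(512, 10)`, but the function names here are hypothetical.

```python
def conv_params(c_in, c_out, k):
    """ResNet convolutions have no bias, so only the k x k kernels count."""
    return c_in * c_out * k * k


def bn_params(c):
    """BatchNorm learns a scale and a shift per channel (running stats are buffers)."""
    return 2 * c


def basic_block(c_in, c_out, downsample):
    """Two 3x3 conv+BN pairs, plus a 1x1 conv+BN shortcut when the shape changes."""
    n = conv_params(c_in, c_out, 3) + bn_params(c_out)
    n += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if downsample:
        n += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return n


def resnet18_backbone_params():
    """Everything except the final fully-connected layer."""
    n = conv_params(3, 64, 7) + bn_params(64)  # stem
    c_in = 64
    for c_out in (64, 128, 256, 512):
        n += basic_block(c_in, c_out, downsample=(c_in != c_out))
        n += basic_block(c_out, c_out, downsample=False)
        c_in = c_out
    return n


def head_params(in_features=512, num_classes=10):
    """Linear layer: one weight per (feature, class) pair plus one bias per class."""
    return in_features * num_classes + num_classes


backbone = resnet18_backbone_params()
head = head_params()
print(f"frozen backbone:  {head:,} trainable parameters")
print(f"full fine-tuning: {backbone + head:,} trainable parameters")
# Sanity check: with the original 1000-class ImageNet head the total is 11,689,512.
print(f"ImageNet ResNet18: {backbone + head_params(512, 1000):,}")
```

By this count the frozen mode trains 5,130 parameters against roughly 11.18 million for full fine-tuning. The forward pass still runs through the whole network in both modes; the savings come from not backpropagating into, or keeping AdamW state for, the frozen layers.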
Training strategy:
- Optimizer: AdamW
- Scheduler: CosineAnnealingLR
- Loss: CrossEntropyLoss with optional label smoothing
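The scheduler and the loss can both be checked by hand. `CosineAnnealingLR` follows the closed form `lr_t = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))`, and `nn.CrossEntropyLoss(label_smoothing=eps)` replaces the one-hot target with `(1 - eps) * one_hot + eps / K` over K classes. The plain-Python sketch below reproduces both. The base learning rate of 1e-3, `T_max` equal to the 6 training epochs (i.e. stepping the scheduler once per epoch), `eta_min = 0`, and `eps = 0.1` are hypothetical values, not the project's confirmed settings.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, t_max=6, eta_min=0.0):
    """Closed-form cosine annealing, as documented for CosineAnnealingLR."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


def log_softmax(logits):
    """Numerically stable log-probabilities (subtract the max before exponentiating)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross-entropy against the smoothed target (1 - eps) * one_hot + eps / K."""
    k = len(logits)
    log_p = log_softmax(logits)
    q = [(1 - eps) * (i == target) + eps / k for i in range(k)]
    return -sum(qi * lpi for qi, lpi in zip(q, log_p))


# Learning rate at the start of each of the 6 hypothetical epochs, plus the end point.
for epoch in range(7):
    print(epoch, f"{cosine_lr(epoch):.2e}")

# A confident, correct prediction costs a little more once smoothing is on.
print(smoothed_cross_entropy([5.0, 0.0, 0.0], target=0, eps=0.0))
print(smoothed_cross_entropy([5.0, 0.0, 0.0], target=0, eps=0.1))
```

With smoothing enabled, the loss no longer rewards pushing the correct-class probability all the way to 1, which is the usual reason to turn it on when fine-tuning a large pretrained network on a small dataset.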
Training configuration:
- Device: Apple Silicon GPU (MPS)
- Epochs: 6
- Batch size: 96
- Image size: 160
- Mixed precision: off (for MPS stability)
Results:
- Validation Accuracy: 96.62%
- Test Accuracy: 96.02%
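For reference, 96.02% test accuracy on the 10,000 CIFAR-10 test images corresponds to 9,602 correct predictions: the sum of the confusion matrix's diagonal divided by the number of images. A minimal sketch of that computation, assuming integer labels in the class order listed at the top; the label lists in the usage lines are made up. Since scikit-learn is already a dependency, `sklearn.metrics.confusion_matrix` (same rows-are-true-labels convention) is the more practical choice in the real pipeline.

```python
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def confusion_matrix(y_true, y_pred, num_classes=10):
    """cm[i][j] counts images whose true class is i and predicted class is j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def accuracy(cm):
    """Correct predictions sit on the diagonal."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def per_class_recall(cm):
    """Fraction of each true class that was predicted correctly."""
    return {
        CLASSES[i]: (cm[i][i] / sum(cm[i]) if sum(cm[i]) else 0.0)
        for i in range(len(cm))
    }


# Made-up labels: two airplanes, two cats, one dog.
y_true = [0, 0, 3, 3, 5]
y_pred = [0, 1, 3, 5, 5]
cm = confusion_matrix(y_true, y_pred)
print(accuracy(cm))
print(per_class_recall(cm)["cat"])
```

Per-class recall is where a confusion matrix earns its keep over a single accuracy number: on CIFAR-10, visually similar pairs such as cat/dog or automobile/truck tend to account for much of the remaining error.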
Run predictions on a custom image:

    python3 inference.py --image samples/cat1.jpg --ckpt results/best.pt --device mps

Example output (top-3 classes with their predicted probabilities):

    [('cat', 0.7955046892166138), ('airplane', 0.029839668422937393), ('automobile', 0.02575363218784332)]
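The example output is a list of (class, probability) pairs for the three most likely classes. A minimal sketch of how such a list can be produced from the model's 10 raw scores (logits), assuming the script applies a softmax and keeps the top 3; the logits below are invented and the helper names are hypothetical. In PyTorch the same step would typically be `torch.softmax(logits, dim=1)` followed by `torch.topk(probs, k=3)`.

```python
import math

# Class order as listed at the top of this README (and in torchvision's CIFAR10).
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Turn raw scores into probabilities; subtracting the max avoids overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class, probability) pairs, highest first."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical logits for one image, in CLASSES order.
logits = [0.5, 0.2, -1.0, 3.0, 0.0, 1.0, -0.5, 0.1, -2.0, 0.3]
print(top_k(logits))
```

The three probabilities shown do not sum to 1, because the remaining seven classes share the rest of the probability mass.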
You can also run inference on multiple images at once (placed in samples/) and generate a grid visualization plus a text report:

    python3 inference_demo.py

- The script will run inference on all images in samples/ plus 3 random CIFAR-10 test images.
- Outputs are saved in results/sample_preds_custom.png and results/sample_preds_custom.txt.
- Note: All custom photos used in the demo (airplane1.jpg, cat1.jpg, dog1.jpg, etc.) were downloaded from Unsplash.
Install dependencies (Mac/CPU/MPS version):

    pip install torch torchvision torchaudio
    pip install matplotlib scikit-learn tqdm

Or simply:

    pip install -r requirements.txt

Future work:
- Extend to CIFAR-100 (100 classes)
- Apply Test-Time Augmentation (TTA) for more robust predictions
- Add a GUI application for real-time inference
- Explore other architectures: ResNet34, EfficientNet
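Test-time augmentation (TTA), listed above as a next step, averages the model's predictions over several transformed copies of the same input. A minimal sketch, assuming horizontal flip as the only extra view and a hypothetical `predict_probs` function standing in for "model + softmax"; with real image tensors the flip would be `torch.flip(x, dims=[-1])`.

```python
def identity(image):
    """The unmodified view."""
    return image


def hflip(image):
    """Mirror an image stored as rows of pixels."""
    return [row[::-1] for row in image]


def tta_predict(predict_probs, image, views=(identity, hflip)):
    """Average class probabilities over augmented views of one image."""
    per_view = [predict_probs(view(image)) for view in views]
    num_classes = len(per_view[0])
    return [sum(probs[c] for probs in per_view) / len(per_view) for c in range(num_classes)]


# Toy stand-in for a trained model (hypothetical): it favors class 0 when the
# left edge of the image is brighter than the right edge.
def toy_predict_probs(image):
    return [0.8, 0.2] if image[0][0] > image[0][-1] else [0.3, 0.7]


print(tta_predict(toy_predict_probs, [[1.0, 0.0]]))
```

Each extra view costs one more forward pass per image, so TTA trades inference speed for (usually small) gains in robustness.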
Developed by Nikan as part of a series of deep learning projects for co-op applications.
