
# Retinal Fundus Classification

*Transfer Learning for Diabetic Retinopathy Grading: Comparing CNN and Vision Transformer Approaches*

Python 3.8+ · PyTorch · MIT License

Built by Riya Shet | MSc Health Data Science, University of Birmingham


## Overview

This project compares three deep learning approaches for grading diabetic retinopathy (DR) severity from retinal fundus photographs, using the APTOS 2019 Blindness Detection dataset (3,662 images, 5 severity levels).

| Model | Approach | Quadratic Weighted Kappa | Accuracy |
|---|---|---|---|
| Baseline CNN | 3-layer CNN, trained from scratch | 0.4697 | 58.6% |
| ResNet-50 | ImageNet pre-trained, fine-tuned | 0.8877 | 78.6% |
| ViT-B/16 | ImageNet pre-trained, fine-tuned | 0.8839 | 81.4% |

Transfer learning nearly doubled the baseline's kappa score. ResNet-50 achieved the highest kappa (0.89), while ViT-B/16 achieved the highest accuracy (81.4%). Both pre-trained models exceeded the 0.85 QWK threshold for "strong agreement".
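Quadratic weighted kappa (QWK) penalises a misgrade by the squared distance between the predicted and true grade, which is why it is the standard metric for ordinal DR severity. The notebook presumably relies on a library implementation such as scikit-learn's `cohen_kappa_score(..., weights="quadratic")`; the sketch below is an equivalent pure-Python version for illustration (function name hypothetical):

```python
def quadratic_weighted_kappa(y_true, y_pred, n_classes=5):
    """QWK = 1 - sum(w * O) / sum(w * E), where w[i][j] = (i-j)^2 / (K-1)^2,
    O is the observed confusion matrix and E the chance-expected one."""
    # Observed confusion matrix
    O = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        O[t][p] += 1
    n = len(y_true)
    hist_true = [sum(row) for row in O]
    hist_pred = [sum(O[i][j] for i in range(n_classes)) for j in range(n_classes)]
    num = den = 0.0
    for i in range(n_classes):
        for j in range(n_classes):
            w = (i - j) ** 2 / (n_classes - 1) ** 2
            expected = hist_true[i] * hist_pred[j] / n
            num += w * O[i][j]
            den += w * expected
    return 1.0 - num / den
```

A grade-4 eye predicted as grade 3 costs far less than one predicted as grade 0, so QWK rewards clinically "close" errors in a way plain accuracy cannot.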

## Key Features

- **Ben Graham preprocessing** -- crop, resize, and colour-normalise fundus images to enhance vascular detail
- **Three architectures** -- Baseline CNN (from scratch), ResNet-50 (transfer learning), ViT-B/16 (transfer learning)
- **Mixed-precision (AMP) training** -- ~2x speedup on T4 GPUs via automatic mixed precision
- **Class-imbalance handling** -- inverse-frequency weighted cross-entropy loss
- **Grad-CAM explainability** -- heat-map visualisations showing where each model focuses its attention, with a ViT-compatible reshape transform
- **Multi-platform support** -- auto-detects Google Colab, Kaggle Notebooks, and local environments
- **Stratified train/val/test split** (70/15/15) with data augmentation via `albumentations`
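The inverse-frequency class weights mentioned above can be computed in a few lines. This pure-Python sketch (function name and mean-1 normalisation are assumptions, not necessarily the notebook's exact formula) returns one weight per class, which would then be passed to PyTorch as `torch.nn.CrossEntropyLoss(weight=torch.tensor(w))`:

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Weight class c by N / (K * count_c), so rarer classes count more.
    Normalising to mean 1 is an assumed convention."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return [n / (k * counts[c]) for c in sorted(counts)]
```

On APTOS, the majority No-DR class would receive a weight below 1 while the rare Severe and Proliferative grades receive weights above 1, so the loss no longer rewards predicting "No DR" for everything.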

## Dataset

The APTOS 2019 Blindness Detection dataset contains 3,662 retinal fundus images labelled by severity:

| Label | Severity | Description |
|---|---|---|
| 0 | No DR | Healthy retina |
| 1 | Mild | Microaneurysms |
| 2 | Moderate | Haemorrhages and exudates |
| 3 | Severe | Widespread vascular damage |
| 4 | Proliferative | Neovascularisation (most dangerous) |

To download, create a Kaggle account, accept the competition rules, then fetch the data via the Kaggle API:

```bash
kaggle competitions download -c aptos2019-blindness-detection
```

Place train.csv and the train_images/ folder in the same directory as the notebook.
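Once `train.csv` is in place, the 70/15/15 stratified split listed under Key Features keeps each severity grade's proportion identical across subsets. The notebook most likely uses `sklearn.model_selection.train_test_split` with `stratify=`; the following is a dependency-free sketch of the same idea (names hypothetical):

```python
import random
from collections import defaultdict

def stratified_split(labels, fracs=(0.7, 0.15, 0.15), seed=42):
    """Split indices per class so each subset preserves the label distribution.
    Rounding remainders fall into the test split."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)
    train, val, test = [], [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_train = round(fracs[0] * len(idxs))
        n_val = round(fracs[1] * len(idxs))
        train.extend(idxs[:n_train])
        val.extend(idxs[n_train:n_train + n_val])
        test.extend(idxs[n_train + n_val:])
    return train, val, test
```

Stratification matters here because the Severe and Proliferative grades are rare: a naive random split could leave the validation set with almost none of them, making QWK estimates unstable.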

## Quick Start

```bash
# Clone the repository
git clone https://github.com/riyashet-hds/retinal-fundus-classification.git
cd retinal-fundus-classification

# Install dependencies
pip install -r requirements.txt

# Download the APTOS dataset (requires the Kaggle API)
kaggle competitions download -c aptos2019-blindness-detection
unzip aptos2019-blindness-detection.zip

# Open the notebook
jupyter notebook retinal_classification.ipynb
```

The notebook auto-detects Google Colab, Kaggle Notebooks, and local environments and adjusts paths accordingly.
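The platform detection described above usually hinges on environment markers. This sketch shows one common heuristic; the conditions are assumptions, not the notebook's actual logic (Colab ships the `google.colab` package, and Kaggle kernels mount data under `/kaggle/input`):

```python
import importlib.util
import os

def detect_environment():
    """Guess the runtime platform: 'colab', 'kaggle', or 'local'."""
    try:
        if importlib.util.find_spec("google.colab") is not None:
            return "colab"
    except ModuleNotFoundError:  # the "google" namespace package may be absent
        pass
    if os.path.isdir("/kaggle/input"):
        return "kaggle"
    return "local"
```

The result would then select the dataset path, e.g. a mounted Drive folder on Colab versus the competition input directory on Kaggle.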

## Project Structure

```text
retinal-fundus-classification/
├── retinal_classification.ipynb   # Full pipeline: preprocessing, training, evaluation, Grad-CAM
├── requirements.txt               # Python dependencies
├── LICENSE                        # MIT License
└── README.md
```

## Results Highlights

- **Transfer-learning gap**: pre-trained models improved QWK by ~0.42 over the from-scratch baseline
- **ResNet-50 vs ViT-B/16**: ResNet-50 achieved the highest kappa (0.89 vs 0.88), while ViT-B/16 achieved the highest accuracy (81.4% vs 78.6%); different error patterns explain the divergence
- **Hardest classes**: Severe DR (F1 = 0.44--0.46) and Proliferative DR (F1 = 0.48--0.55) were most challenging, owing to rarity and visual ambiguity with adjacent grades
- **Fastest convergence**: ViT peaked at epoch 8 (QWK = 0.91) and triggered early stopping at epoch 13 due to overfitting
- **Grad-CAM insights**: ResNet produces single concentrated activation regions; ViT distributes attention across multiple anatomical landmarks
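The ViT-compatible reshape transform behind these Grad-CAM maps folds the transformer's token sequence back onto a 2D patch grid so activations can be drawn as a heat map. Below is a pure-Python illustration of the index arithmetic, assuming ViT-B/16 at 224×224 input (196 patch tokens behind one CLS token); the notebook's actual transform presumably operates on PyTorch tensors:

```python
def vit_reshape_transform(tokens, grid=14):
    """Drop the CLS token and fold the remaining patch tokens into a
    grid x grid spatial map (row-major, matching patch-embedding order)."""
    patches = tokens[1:]  # token 0 is the CLS token, which has no spatial position
    if len(patches) != grid * grid:
        raise ValueError(f"expected {grid * grid} patch tokens, got {len(patches)}")
    return [patches[row * grid:(row + 1) * grid] for row in range(grid)]
```

This is why ViT heat maps can look multi-focal: every patch token attends globally, so high activations need not form one contiguous blob the way late CNN feature maps often do.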

## References

1. Zhou, Y. et al. (2023). A Foundation Model for Generalizable Disease Detection from Retinal Images. *Nature*.
2. Oquab, M. et al. (2024). DINOv2: Learning Robust Visual Features without Supervision. *TMLR*.
3. Graham, B. (2015). Kaggle Diabetic Retinopathy Detection Competition, 1st-Place Solution.
4. Selvaraju, R.R. et al. (2017). Grad-CAM: Visual Explanations from Deep Networks. *ICCV*.
5. APTOS 2019 Blindness Detection. Kaggle Competition.

## License

This project is licensed under the MIT License.
