This project demonstrates transfer learning and fine-tuning of ResNet-18 and VGG-16 models for multi-class medical image classification using blood cell and chest X-ray datasets.
The project combines two medical imaging datasets:
- Blood Cells Dataset: Classifies white blood cells into EOSINOPHIL, LYMPHOCYTE, MONOCYTE, NEUTROPHIL
- Chest X-ray Pneumonia Dataset: Classifies chest X-rays into NORMAL and PNEUMONIA
Total classes: 6 (EOSINOPHIL, LYMPHOCYTE, MONOCYTE, NEUTROPHIL, NORMAL, PNEUMONIA)
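A six-class label space can be built by shifting the X-ray labels past the four blood-cell labels. The sketch below is a minimal illustration, not the project's code. It assumes each source dataset yields `(image, label)` pairs with labels numbered from 0, as torchvision's `ImageFolder` does. `ImageFolder` assigns indices in sorted folder-name order, which matches the class order listed above. `OffsetLabels` and `combine` are hypothetical names.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset, TensorDataset

BLOOD_CLASSES = ["EOSINOPHIL", "LYMPHOCYTE", "MONOCYTE", "NEUTROPHIL"]
XRAY_CLASSES = ["NORMAL", "PNEUMONIA"]
CLASS_NAMES = BLOOD_CLASSES + XRAY_CLASSES  # 6 global classes


class OffsetLabels(Dataset):
    """Hypothetical wrapper: shifts a dataset's integer labels by a fixed offset."""

    def __init__(self, base, offset):
        self.base = base
        self.offset = offset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        return image, int(label) + self.offset


def combine(blood_ds, xray_ds):
    # Blood-cell labels stay 0-3; X-ray labels 0-1 become 4-5.
    return ConcatDataset([
        OffsetLabels(blood_ds, 0),
        OffsetLabels(xray_ds, len(BLOOD_CLASSES)),
    ])


# Usage with tiny in-memory stand-ins for the real image datasets.
blood = TensorDataset(torch.zeros(3, 1), torch.tensor([0, 1, 3]))
xray = TensorDataset(torch.zeros(2, 1), torch.tensor([0, 1]))
combined = combine(blood, xray)
print([combined[i][1] for i in range(len(combined))])  # [0, 1, 3, 4, 5]
```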
- ResNet-18: Fine-tuned with layer4 and the fully connected layer (fc) unfrozen
- VGG-16: Fine-tuned with the last convolutional block and the classifier unfrozen
- Automatic dataset download and preparation
- Model training with the Adam optimizer
- Evaluation with confusion matrices
- Model saving for inference
- Prediction on new images using a saved model
- Comparative visualization of training metrics
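A minimal training and evaluation sketch follows. It assumes cross-entropy loss, a learning rate of 1e-4 (a placeholder, not taken from main.py), and data loaders that yield `(images, labels)` batches. Only parameters with `requires_grad=True` are handed to Adam, so the optimizer works on the unfrozen layers alone. `make_optimizer`, `train_one_epoch` and `evaluate` are hypothetical names; the confusion matrix comes from scikit-learn's `confusion_matrix`.

```python
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix


def make_optimizer(model, lr=1e-4):
    # lr is an assumed placeholder value.
    params = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=lr)


def train_one_epoch(model, loader, optimizer, device="cpu"):
    criterion = nn.CrossEntropyLoss()
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size for an epoch-level average.
        total_loss += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return total_loss / seen, correct / seen


@torch.no_grad()
def evaluate(model, loader, num_classes, device="cpu"):
    model.eval()
    y_true, y_pred = [], []
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    # Fix the label set so the matrix is always num_classes x num_classes.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    return acc, cm
```

The returned matrix can be drawn with Seaborn, for example `seaborn.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names)`.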
- Python 3.7+
- PyTorch
- Torchvision
- Matplotlib
- Seaborn
- Scikit-learn
- Pillow
- Kagglehub
Install dependencies:
pip install torch torchvision matplotlib seaborn scikit-learn pillow kagglehub
- Run the training script:
python main.py
The script will:
- Download and prepare datasets
- Train ResNet-18 and VGG-16 models
- Save trained models as resnet18_model.pth and vgg16_model.pth
- Display training progress and confusion matrices
- Show comparative loss and accuracy plots
- For prediction on new images:
Uncomment and modify the prediction example in main.py:
prediction = predict_image('resnet18_model.pth', models.resnet18, 'path/to/image.jpg', class_names)
print("Predicted class:", prediction)
The models are evaluated on test accuracy, with comparative plots showing:
- Training and test loss curves
- Training and test accuracy curves
Final test accuracies are printed for both models.
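The comparative loss and accuracy plots can be sketched with Matplotlib as follows. The `histories` layout (per-epoch lists under `train_loss`, `test_loss`, `train_acc`, `test_acc`) and the name `plot_comparison` are assumptions for illustration, not taken from main.py.

```python
import matplotlib.pyplot as plt


def plot_comparison(histories):
    """Hypothetical helper: histories maps a model name to a dict of per-epoch
    lists under 'train_loss', 'test_loss', 'train_acc' and 'test_acc'."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    for name, h in histories.items():
        epochs = range(1, len(h["train_loss"]) + 1)
        # Solid lines for training, dashed lines for test.
        ax_loss.plot(epochs, h["train_loss"], label=f"{name} train")
        ax_loss.plot(epochs, h["test_loss"], linestyle="--", label=f"{name} test")
        ax_acc.plot(epochs, h["train_acc"], label=f"{name} train")
        ax_acc.plot(epochs, h["test_acc"], linestyle="--", label=f"{name} test")
    ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
    ax_acc.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
    ax_loss.legend()
    ax_acc.legend()
    fig.tight_layout()
    return fig
```

Call `plt.show()` afterwards to display the figure, or `fig.savefig("comparison.png")` to write it to disk.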
- main.py: Main script for training and evaluation
- resnet.py: Custom ResNet implementation (not used in the main script)
- dataset/: Local dataset storage (created automatically)
- resnet18_model.pth: Saved ResNet-18 model
- vgg16_model.pth: Saved VGG-16 model
- README.md: This file
This project is for educational purposes. Datasets are from Kaggle (public domain).
- Blood Cells Dataset: https://www.kaggle.com/paultimothymooney/blood-cells
- Chest X-ray Dataset: https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia
- PyTorch Documentation: https://pytorch.org/