
HipMRI_Study Segmentation with Improved U-Net (Task 3)

Author

Marcus Baulch (47445464)
COMP3710 - Pattern Recognition and Analysis, The University of Queensland

Overview

This project implements an Improved U-Net for multi-class semantic segmentation of prostate MRI images. The model segments anatomical structures in 2D MRI slices into 6 distinct classes, improving on the standard U-Net through residual connections, batch normalisation, and a combined loss function.

Key Features

  • Residual U-Net Architecture: Enhanced U-Net with ResNet-style skip connections within encoder/decoder blocks
  • Multi-Class Segmentation: 6-class semantic segmentation
  • Combined Loss Function: 60% Dice Loss + 40% Cross-Entropy for balanced optimisation
  • Data Augmentation: Random flips and rotations during training
  • Comprehensive Evaluation: Dice coefficient metrics with visualisation capabilities

Dataset and Preprocessing

Hip MRI Study Dataset

The project uses the HipMRI Study Open Dataset, which contains MRI scans with semantic labels for male pelvises. The data was retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true (see reference at the end).

Dataset Structure:

HipMRI_Study_open/
├── keras_slices_data/
│   ├── keras_slices_train/          # Training images
│   ├── keras_slices_seg_train/      # Training masks
│   ├── keras_slices_validate/       # Validation images
│   ├── keras_slices_seg_validate/   # Validation masks
│   ├── keras_slices_test/           # Test images
│   └── keras_slices_seg_test/       # Test masks
└── semantic_labels_only/            # Original 3D NIfTI files

Preprocessing

  1. Image Loading: NIfTI (.nii.gz) files are loaded with nibabel
  2. Normalisation: Images are standardised using z-score normalisation: (x - mean) / std [1]
  3. One-Hot Encoding: Masks are converted to 6-channel one-hot format [B, 6, H, W]
  4. Data Augmentation (training only) [2], sketched in code after this list:
    • Random horizontal/vertical flips (50% probability each)
    • Random rotation of ±15 degrees
    • Geometric transforms applied consistently to image-mask pairs
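
A minimal sketch of these preprocessing steps, assuming 2D slices stored one per NIfTI file, and using illustrative helper names (load_slice, one_hot, augment) rather than the repository's exact functions:

import random

import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

NUM_CLASSES = 6

def load_slice(path):
    # Load a 2D slice from a NIfTI file and z-score normalise it
    img = nib.load(path).get_fdata().astype(np.float32)
    img = (img - img.mean()) / (img.std() + 1e-8)
    return torch.from_numpy(img).unsqueeze(0)  # [1, H, W]

def one_hot(mask):
    # Convert an integer mask [H, W] into one-hot format [6, H, W]
    return F.one_hot(mask.long(), NUM_CLASSES).permute(2, 0, 1).float()

def augment(image, mask):
    # Apply identical random geometric transforms to the image-mask pair
    if random.random() < 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)
    if random.random() < 0.5:
        image, mask = TF.vflip(image), TF.vflip(mask)
    angle = random.uniform(-15, 15)
    return TF.rotate(image, angle), TF.rotate(mask, angle)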

Train/Validation/Test Split

The dataset uses a predefined split provided by the HipMRI Study Open Dataset:

  • Training set: 11,460 images
  • Validation set: 660 images
  • Test set: 540 images

Justification:

  • The dataset came pre-split, so no manual splitting was required
  • The roughly 90/5/5 split is standard for medical imaging datasets
  • The training set is large enough to learn robust features
  • The validation set (660 samples) is sufficient for reliable model selection during training
  • The test set (540 samples) provides a statistically meaningful final evaluation

Model Architecture

Residual U-Net

The model improves upon the standard U-Net with residual blocks and batch normalisation; a code sketch follows the diagram:

Input (1 channel, grayscale MRI)
    ↓
Encoder Path (with residual blocks):
    ResBlock(1→64) → MaxPool
    ResBlock(64→128) → MaxPool
    ResBlock(128→256) → MaxPool
    ResBlock(256→512) → MaxPool
    ResBlock(512→1024) [Bottleneck]
    ↓
Decoder Path (with skip connections):
    UpConv + Concat → ResBlock(1024→512)
    UpConv + Concat → ResBlock(512→256)
    UpConv + Concat → ResBlock(256→128)
    UpConv + Concat → ResBlock(128→64)
    ↓
Final Conv(64→6)
    ↓
Output (6 channels, class logits)
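
A compact PyTorch sketch of this architecture, assuming a ResidualBlock module like the one detailed in the next subsection (the class below is illustrative, not the repository's exact modules.py):

import torch
import torch.nn as nn

class ImprovedUNet(nn.Module):
    # Residual U-Net: encoder/decoder of ResidualBlocks with skip connections
    def __init__(self, in_channels=1, num_classes=6):
        super().__init__()
        chs = [64, 128, 256, 512, 1024]
        self.pool = nn.MaxPool2d(2)
        # Encoder: ResBlock(1→64) ... ResBlock(512→1024) at the bottleneck
        self.enc = nn.ModuleList(
            [ResidualBlock(in_channels, chs[0])] +
            [ResidualBlock(chs[i], chs[i + 1]) for i in range(4)])
        # Decoder: upsample, concatenate the skip, then a ResidualBlock
        self.up = nn.ModuleList(
            [nn.ConvTranspose2d(chs[i + 1], chs[i], 2, stride=2)
             for i in reversed(range(4))])
        self.dec = nn.ModuleList(
            [ResidualBlock(chs[i + 1], chs[i]) for i in reversed(range(4))])
        self.head = nn.Conv2d(chs[0], num_classes, 1)  # Final Conv(64→6)

    def forward(self, x):
        skips = []
        for block in self.enc[:-1]:
            x = block(x)
            skips.append(x)
            x = self.pool(x)
        x = self.enc[-1](x)  # bottleneck
        for up, block, skip in zip(self.up, self.dec, reversed(skips)):
            x = block(torch.cat([up(x), skip], dim=1))
        return self.head(x)  # [B, 6, H, W] class logits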

Residual Block Details

Each ResidualBlock consists of the following (sketched in code after the diagram):

Input
  ├─ Conv3x3 → BatchNorm → ReLU → Conv3x3 → BatchNorm → (+)
  └─ [1x1 Conv if channels mismatch] ────────────────────→ ReLU → Output
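
A minimal PyTorch sketch matching this diagram (illustrative; the repository's module may differ in detail):

import torch.nn as nn

class ResidualBlock(nn.Module):
    # Conv-BN-ReLU-Conv-BN main path, identity (or 1x1 conv) shortcut,
    # and a final ReLU applied after the addition
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels))
        # 1x1 convolution on the shortcut only when channel counts mismatch
        self.shortcut = (nn.Identity() if in_channels == out_channels
                         else nn.Conv2d(in_channels, out_channels, 1))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.body(x) + self.shortcut(x))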

Training

Configuration

Parameter        Value
---------------  ----------------------------
Batch Size       16
Epochs           20
Learning Rate    1e-4
Optimiser        Adam
Loss Function    60% Dice + 40% Cross-Entropy
Device           CUDA (if available) / CPU

Loss Function

The combined loss leverages the strengths of both components (a sketch follows this list):

  • Dice Loss: Directly optimises the evaluation metric (Dice coefficient)
  • Cross-Entropy: Provides stable pixel-wise classification gradients
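
A minimal sketch of this combined loss, assuming logits of shape [B, 6, H, W] and one-hot targets; the exact Dice formulation in the repository may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CombinedLoss(nn.Module):
    # 0.6 * soft Dice loss + 0.4 * cross-entropy
    def __init__(self, dice_weight=0.6, ce_weight=0.4, smooth=1e-6):
        super().__init__()
        self.dice_weight, self.ce_weight, self.smooth = dice_weight, ce_weight, smooth

    def forward(self, logits, target_onehot):
        # Cross-entropy expects integer class indices, not one-hot masks
        ce = F.cross_entropy(logits, target_onehot.argmax(dim=1))
        # Soft Dice computed per class, then averaged
        probs = torch.softmax(logits, dim=1)
        dims = (0, 2, 3)
        intersection = (probs * target_onehot).sum(dims)
        union = probs.sum(dims) + target_onehot.sum(dims)
        dice = ((2 * intersection + self.smooth) / (union + self.smooth)).mean()
        return self.dice_weight * (1 - dice) + self.ce_weight * ce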

Training Script

python train.py

Outputs (the best-checkpoint logic is sketched after this list):

  • outputs/best_model.pth - Best model checkpoint (highest validation Dice)
  • outputs/prediction_XXX.png - Prediction visualisations (saved PNGs)
  • outputs/training_curves.png - Loss and Dice score plots
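
A sketch of the best-checkpoint logic described above, assuming a standard PyTorch training loop and a hypothetical evaluate helper that returns mean validation Dice:

import torch

def fit(model, train_loader, val_loader, optimiser, criterion, device, num_epochs=20):
    # Train and keep only the weights with the highest validation Dice
    best_dice = 0.0
    for epoch in range(num_epochs):
        model.train()
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimiser.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimiser.step()
        val_dice = evaluate(model, val_loader, device)  # hypothetical helper
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), 'outputs/best_model.pth')
    return best_dice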

Output Visualisations

[Figures: predictions for slices 01-03, each showing Input | Ground truth | Model prediction]

[Figure: Dice and loss training curves (outputs/training_curves.png)]

Evaluation

Metrics

This model was trained for only 5 epochs, as it reaches the required minimum Dice coefficient of 0.75 very quickly.

Dice Coefficient (primary metric):

Dice = (2 × |Prediction ∩ Ground Truth|) / (|Prediction| + |Ground Truth|)

Calculated per class and averaged across all 6 classes for the final score (see the sketch below).
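
A sketch of this computation, assuming hard predictions taken via argmax over the logits (the repository's implementation may differ):

import torch

def per_class_dice(logits, target_onehot, smooth=1e-6):
    # Mean Dice over classes: (2 * |P ∩ G|) / (|P| + |G|) per class
    num_classes = logits.shape[1]
    pred_onehot = torch.nn.functional.one_hot(
        logits.argmax(dim=1), num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    intersection = (pred_onehot * target_onehot).sum(dims)
    totals = pred_onehot.sum(dims) + target_onehot.sum(dims)
    return ((2 * intersection + smooth) / (totals + smooth)).mean()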

Running Evaluation

python predict.py

Features (an illustrative evaluation loop follows this list):

  • Loads best model from outputs/best_model.pth
  • Evaluates on test set
  • Reports mean, std, min, max Dice scores
  • Saves prediction visualisations
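
An illustrative version of this evaluation loop, reusing the per_class_dice sketch from the Metrics section; model, test_loader, and device are assumed to be constructed as elsewhere in this README:

import torch

model.load_state_dict(torch.load('outputs/best_model.pth', map_location=device))
model.eval()
scores = []
with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(device), masks.to(device)
        scores.append(per_class_dice(model(images), masks).item())
scores = torch.tensor(scores)
print(f'Dice  mean: {scores.mean():.4f}  std: {scores.std():.4f}  '
      f'min: {scores.min():.4f}  max: {scores.max():.4f}')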

Results

Performance Metrics

The following is combined console output from training and evaluation:

======================================================================
TRAINING COMPLETED
======================================================================
Best Validation Dice: 0.8654
======================================================================

======================================================================
TEST SET EVALUATION
======================================================================

Test Loss: 0.2150
Test Dice: 0.8777
======================================================================

Training curves saved to: outputs/training_curves.png
FINAL SUMMARY
======================================================================
Best Validation Dice: 0.8654
Test Set Dice:        0.8777
Model saved to:       ./outputs/best_model.pth
Plots saved to:       ./outputs/training_curves.png
======================================================================

The model achieved a mean Dice coefficient of 0.877 averaged over the 6 classes, exceeding the 0.75 Dice coefficient requirement for this task.

Training Curves

Training and validation loss/Dice curves are automatically saved to outputs/training_curves.png after training completes.

Project Structure

COMP3710-Report/
├── train.py                # Training script
├── predict.py              # Evaluation script
├── modules.py              # Residual U-Net architecture
├── dataset.py              # Dataset loader with augmentations
├── utils_visualize.py      # Visualisation utilities
├── check_predictions.py    # Quick prediction checker
├── README.md               # This file
├── LICENSE                 # Project license
└── outputs/                # Training outputs
    ├── best_model.pth
    ├── training_curves.png
    └── prediction_XXX.png  # variable number of prediction visualisations


Requirements

Python Dependencies

torch>=1.9.0
torchvision>=0.10.0
numpy>=1.19.0
nibabel>=3.2.0
matplotlib>=3.3.0
tqdm>=4.60.0
scipy>=1.5.0

Installation

pip install torch torchvision numpy nibabel matplotlib tqdm scipy

Usage

Hardware + Runtime

This project made use of UQ's Rangpur cluster, specifically an A100 GPU. The following SLURM batch script was used to run training:

#!/bin/bash
#SBATCH --partition=a100
#SBATCH --gres=gpu:1
#SBATCH --job-name=hipmri
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8

#SBATCH --output=task3.out
#SBATCH --error=task3.err

conda activate torch
python train.py

Runtime Estimates

Task                     GPU Time      CPU Time
-----------------------  ------------  ------------
Training (20 epochs)     ~10-15 min    ~1-2 hours
Evaluation (test set)    ~10-30 sec    ~1-2 min
Single prediction        <1 sec        ~1 sec

Device Selection

The code automatically detects and uses CUDA if available:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

References

[1] COMP3710 Teaching Team (2025). Retrieved from https://colab.research.google.com/drive/1VOsZSyRhyuHLmgoqGriQk01ub4bKNmZ1?usp=sharing

[2] Dowling, J. & Greer, P. (2014). Labelled weekly MR images of the male pelvis. Retrieved from https://data.csiro.au/collection/csiro:51392v2?redirected=true