DCLA-UNet: Dynamic Cross Layer Attention U-Net for Medical Image Segmentation


πŸ“– Abstract

DCLA-UNet introduces a Dynamic Cross Layer Attention (DCLA) mechanism that dynamically selects the most relevant features from different layers of the network. This attention mechanism enhances the feature representation capability of U-Net for medical image segmentation tasks, particularly for brain tumor segmentation in BraTS datasets.

πŸ› οΈ Overview

DCLA_UNet_Overview

Results

Comparison Results (2D)

Comparison Results (3D)

πŸš€ Quick Start

Prerequisites

  • Python 3.10+
  • CUDA-capable GPU (recommended)
  • 16GB+ RAM

Installation

  1. Clone the repository
git clone git@github.com:Helium-327/DCLA-UNet-3D.git
cd DCLA-UNet-3D
  2. Install dependencies
pip install -r requirements.txt

Quick Training Example

# Train DCLA-UNet on BraTS2020 dataset
python src/main.py --model_name "DCLA_UNet_final" \
                   --slb_project "my_experiment" \
                   --datasets "BraTS2020" \
                   --data_root "/path/to/BraTS2020/raw" \
                   --epochs 100 \
                   --batch_size 2 \
                   --lr 3e-4

πŸ“Š Dataset

Data Source and License

The BraTS (Brain Tumor Segmentation) datasets used in this project are publicly available from the following sources:

| Dataset | Source | License |
|---------|--------|---------|
| BraTS 2019 | CBICA - BraTS 2019 | CC BY 4.0 |
| BraTS 2020 | CBICA - BraTS 2020 | CC BY 4.0 |
| BraTS 2021 | RSNA-ASNR-MICCAI BraTS 2021 | CC BY 4.0 |

Important: Registration on the respective challenge websites is required to download the datasets. Users of these datasets should cite the original BraTS challenge papers (see Citation section).

Supported Datasets

  • BraTS2019: Brain Tumor Segmentation Challenge 2019
  • BraTS2020: Brain Tumor Segmentation Challenge 2020
  • BraTS2021: Brain Tumor Segmentation Challenge 2021

Dataset Structure

The dataset should be organized as follows:

data/
β”œβ”€β”€ BraTS2020/
β”‚   β”œβ”€β”€ raw/
β”‚   β”‚   β”œβ”€β”€ BraTS20_Training_001/
β”‚   β”‚   β”‚   β”œβ”€β”€ BraTS20_Training_001_flair.nii.gz
β”‚   β”‚   β”‚   β”œβ”€β”€ BraTS20_Training_001_t1.nii.gz
β”‚   β”‚   β”‚   β”œβ”€β”€ BraTS20_Training_001_t1ce.nii.gz
β”‚   β”‚   β”‚   β”œβ”€β”€ BraTS20_Training_001_t2.nii.gz
β”‚   β”‚   β”‚   └── BraTS20_Training_001_seg.nii.gz
β”‚   β”‚   └── ...
β”‚   β”œβ”€β”€ train.csv
β”‚   β”œβ”€β”€ val.csv
β”‚   └── test.csv

Dataset Preparation

  1. Download BraTS dataset from the official website
  2. Extract and organize the data according to the structure above
  3. Generate CSV files for train/val/test splits:
python src/main.py --data_split --datasets "BraTS2020" --data_root "/path/to/BraTS2020/raw"
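As an illustration of what a split step like this typically does, the sketch below scans case folders under the raw data root, shuffles them with a fixed seed, and writes train/val/test CSVs. The function name, split ratios, and CSV layout are assumptions for illustration, not the repository's actual implementation.

```python
# Hypothetical sketch of a train/val/test split over BraTS case folders.
# Ratios and CSV format are assumptions, not the repository's exact behavior.
import csv
import random
from pathlib import Path

def split_cases(data_root: str, seed: int = 42,
                ratios=(0.7, 0.1, 0.2)) -> dict:
    """Shuffle case directories and split them into train/val/test lists."""
    cases = sorted(p.name for p in Path(data_root).iterdir() if p.is_dir())
    random.Random(seed).shuffle(cases)  # deterministic shuffle
    n_train = int(len(cases) * ratios[0])
    n_val = int(len(cases) * ratios[1])
    splits = {
        "train": cases[:n_train],
        "val": cases[n_train:n_train + n_val],
        "test": cases[n_train + n_val:],
    }
    # Write one CSV per split next to the raw/ directory
    for name, case_list in splits.items():
        with open(Path(data_root).parent / f"{name}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["case_id"])
            writer.writerows([c] for c in case_list)
    return splits
```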

πŸ—οΈ Architecture

Supported Models

SOTA Models

  • UNet3D: Standard 3D U-Net
  • AttUNet3D: Attention U-Net 3D
  • UNETR: Vision Transformer for medical segmentation
  • UNETR_PP: Enhanced UNETR
  • SegFormer3D: 3D SegFormer
  • Mamba3d: State Space Model for 3D segmentation
  • MogaNet: Multi-order Gated Aggregation Network

DCLA-UNet Variants

  • DCLA_UNet_final: Main DCLA-UNet model
  • BaseLine_S_DCLA_final: Baseline with DCLA
  • BaseLine_S_DCLA_SLK_final: DCLA + Selective Large Kernel
  • BaseLine_S_DCLA_MSF_final: DCLA + Multi-Scale Fusion

Key Features

  • Dynamic Cross Layer Attention (DCLA): Adaptively selects relevant features across different network layers
  • Multi-Scale Fusion (MSF): Integrates features at multiple scales
  • Selective Large Kernel (SLK): Enhances receptive field with efficient computation
  • Mixed Precision Training: Accelerated training with automatic mixed precision
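To make the cross-layer idea concrete, here is a minimal gating module in which a deeper feature map produces a per-channel, per-voxel gate for a shallower skip connection. This is only an illustrative sketch of the general concept; it is not the paper's actual DCLA module, and the class name and projection choice are assumptions.

```python
# Illustrative cross-layer gating (NOT the paper's exact DCLA module):
# a deep feature re-weights a shallow skip feature before fusion.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossLayerGate(nn.Module):
    def __init__(self, skip_ch: int, deep_ch: int):
        super().__init__()
        # Project the deep feature to the skip feature's channel count
        self.proj = nn.Conv3d(deep_ch, skip_ch, kernel_size=1)

    def forward(self, skip: torch.Tensor, deep: torch.Tensor) -> torch.Tensor:
        # Upsample the projected deep feature to the skip feature's spatial size
        gate = F.interpolate(self.proj(deep), size=skip.shape[2:],
                             mode="trilinear", align_corners=False)
        # Sigmoid gate in (0, 1) suppresses irrelevant skip features
        return skip * torch.sigmoid(gate)
```

Because the sigmoid gate lies in (0, 1), the module can only attenuate skip features, never amplify them, which is one common way to let a semantically deeper layer "select" shallow features.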

πŸ”§ Training

Basic Training

python src/main.py --model_name "DCLA_UNet_final" \
                   --slb_project "experiment_name" \
                   --datasets "BraTS2020" \
                   --data_root "/path/to/data" \
                   --epochs 100 \
                   --batch_size 2 \
                   --lr 3e-4 \
                   --wd 2e-5

Advanced Training Options

python src/main.py --model_name "DCLA_UNet_final" \
                   --slb_project "advanced_experiment" \
                   --datasets "BraTS2021" \
                   --data_root "/path/to/data" \
                   --epochs 200 \
                   --batch_size 4 \
                   --lr 3e-4 \
                   --wd 2e-5 \
                   --cosine_eta_min 1e-6 \
                   --cosine_T_max 100 \
                   --early_stop_patience 60 \
                   --slb \
                   --tb

Batch Training Script

Use the provided shell script for training multiple models:

# Edit run.sh to specify models and parameters
./run.sh

Resume Training

python src/main.py --resume "/path/to/checkpoint.pth" \
                   --model_name "DCLA_UNet_final" \
                   --slb_project "resumed_experiment"

Training Parameters

| Parameter | Description | Default | Recommended |
|-----------|-------------|---------|-------------|
| `--lr` | Learning rate | 3e-4 | 1e-4 to 5e-4 |
| `--wd` | Weight decay | 1e-5 | 1e-5 to 2e-5 |
| `--batch_size` | Batch size | 1 | 2-4 (depends on GPU memory) |
| `--epochs` | Training epochs | 100 | 100-200 |
| `--early_stop_patience` | Early stopping patience | 60 | 60-100 |
| `--cosine_T_max` | Cosine scheduler T_max | 50 | Half of epochs |
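The recommended schedule (cosine annealing with `T_max` set to half the epoch count) can be sketched in plain PyTorch as follows; the tiny linear model stands in for the real network, and the exact optimizer setup in the repository may differ.

```python
# Sketch of the recommended schedule: AdamW + cosine annealing with
# T_max = epochs // 2. The Linear layer is a placeholder for DCLA-UNet.
import torch

epochs = 100
model = torch.nn.Linear(4, 4)  # stand-in for the real 3D network
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=2e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs // 2, eta_min=1e-6)

for epoch in range(epochs):
    # ... one epoch of training and validation would go here ...
    scheduler.step()  # anneal the learning rate once per epoch
```

With `T_max = epochs // 2`, the learning rate completes one full cosine period over training: it decays to `eta_min` at the midpoint and rises back toward the base rate by the final epoch.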

πŸ“ˆ Monitoring and Visualization

Training Monitoring

  • SwanLab: Use --slb flag for experiment tracking
  • TensorBoard: Use --tb flag for TensorBoard logging

πŸ”¬ Evaluation Metrics

The framework supports comprehensive evaluation metrics:

  • Dice Score: Overlap-based similarity measure
  • Hausdorff Distance: Boundary-based distance measure
  • Sensitivity: True positive rate
  • Specificity: True negative rate
  • Volume Similarity: Volume-based comparison
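Two of these metrics, Dice score and sensitivity, can be computed on binary masks as shown below. This is a minimal sketch; the repository's `metrics.py` may differ in details such as smoothing terms and per-region (WT/TC/ET) aggregation.

```python
# Minimal Dice and sensitivity on binary masks (illustrative sketch;
# the repository's metrics.py may handle multi-region labels differently).
import numpy as np

def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Dice = 2|P ∩ T| / (|P| + |T|), with eps for empty-mask stability."""
    intersection = np.logical_and(pred, target).sum()
    return float((2.0 * intersection + eps) / (pred.sum() + target.sum() + eps))

def sensitivity(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """True positive rate = TP / (TP + FN)."""
    tp = np.logical_and(pred, target).sum()
    fn = np.logical_and(target, np.logical_not(pred)).sum()
    return float((tp + eps) / (tp + fn + eps))
```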

πŸ› οΈ How to Add a New Model?

  1. Create model file in src/nnArchitecture/nets/:
# src/nnArchitecture/nets/your_model.py
import torch.nn as nn

class YourModel(nn.Module):
    def __init__(self, in_channels=4, out_channels=4):
        super().__init__()
        # Your model implementation (minimal placeholder shown)
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        # Forward pass
        return self.conv(x)
  2. Register model in src/model_registry.py:
from nnArchitecture.nets.your_model import YourModel

model_register = {
    "Your Models": {
        "YourModel": YourModel(**BASE_ARGS),
    }
}
  3. Update imports in src/nnArchitecture/nets/__init__.py:
from .your_model import YourModel

πŸ“ How to Add a New Dataset?

  1. Create dataset directory: src/datasets/YourDataset/

  2. Implement dataset class:

# src/datasets/YourDataset/your_dataset.py
from torch.utils.data import Dataset

class YourDataset(Dataset):
    def __init__(self, data_file, transform=None):
        # Dataset initialization (load sample index from data_file)
        self.samples = []
        self.transform = transform

    def __getitem__(self, idx):
        # Return data sample
        sample = self.samples[idx]
        return self.transform(sample) if self.transform else sample

    def __len__(self):
        # Return dataset length
        return len(self.samples)
  3. Update main.py to support your dataset:
if args.datasets == 'YourDataset':
    from datasets.YourDataset import YourDataset as Dataset

πŸ“‹ Project Structure

DCLA-UNet/
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ datasets/           # Dataset implementations
β”‚   β”œβ”€β”€ nnArchitecture/     # Model architectures
β”‚   β”œβ”€β”€ utils/              # Utility functions
β”‚   β”œβ”€β”€ main.py            # Main training script
β”‚   β”œβ”€β”€ train_swanlab.py   # Training with SwanLab
β”‚   β”œβ”€β”€ metrics.py         # Evaluation metrics
β”‚   └── lossFunc.py        # Loss functions
β”œβ”€β”€ visual/                # Visualization tools
β”œβ”€β”€ requirements.txt       # Dependencies
β”œβ”€β”€ run.sh                # Training script
└── README.md             # This file

πŸ› Troubleshooting

Common Issues

  1. CUDA out of memory:

    • Reduce batch size: --batch_size 1
    • Use gradient checkpointing
    • Reduce input size
  2. Dataset loading errors:

    • Check file paths in CSV files
    • Verify data directory structure
    • Ensure all required modalities are present
  3. Training instability:

    • Reduce learning rate: --lr 1e-4
    • Increase weight decay: --wd 2e-5
    • Use gradient clipping
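The stabilization tips above (mixed precision plus gradient clipping) fit together in one training step as sketched below. The model and data are placeholders; note that with a `GradScaler`, gradients must be unscaled before clipping so the norm threshold applies to the true gradients.

```python
# Sketch of one stabilized training step: AMP forward pass, then
# unscale -> clip -> step. Model and data are placeholders.
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=2e-5)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
x, y = torch.randn(4, 8), torch.randn(4, 2)

optimizer.zero_grad()
with torch.autocast("cuda" if use_cuda else "cpu", enabled=use_cuda):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale before clipping (no-op on CPU)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
```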

Performance Optimization

  • Use --num_workers 8 for faster data loading
  • Enable mixed precision training (automatically enabled)
  • Use persistent_workers=True for DataLoader
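The DataLoader settings suggested above can be combined as in the sketch below. The tiny `TensorDataset` is a stand-in for the BraTS dataset class, and the worker count is reduced here for a small demo (the README recommends `--num_workers 8` on a real machine).

```python
# Sketch of the suggested DataLoader configuration. TensorDataset is a
# stand-in for the repository's BraTS dataset class.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(16, 4, 8, 8, 8),
                        torch.zeros(16, dtype=torch.long))
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,            # use 8 on a real machine, per the tip above
    pin_memory=True,          # faster host-to-GPU transfer
    persistent_workers=True,  # keep workers alive between epochs
)
```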

πŸ“„ Citation

If you use this code in your research, please cite:

@article{xiong2025dcla,
  title={DCLA-UNet: Dynamic Cross Layer Attention U-Net for Medical Image Segmentation},
  author={Xiong, Junyin},
  journal={PeerJ Computer Science},
  year={2025}
}

BraTS Dataset Citations

This project uses the BraTS Challenge datasets. The relevant citations are listed below:

BraTS 2019:

@inproceedings{bakas2019brats2019,
  title={Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge},
  author={Bakas, Spyridon and Reyes, Mauricio and Jakab, Andras and others},
  booktitle={International MICCAI Brainlesion Workshop},
  pages={416--428},
  year={2019},
  organization={Springer}
}

BraTS 2021:

@article{baid2021rsna,
  title={The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification},
  author={Baid, Ujjwal and Ghodasara, Satyam and Mohan, Suyash and others},
  journal={arXiv preprint arXiv:2107.02314},
  year={2021}
}

πŸ“œ License

This project is licensed under the MIT License - see the LICENSE file for details.

🀝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

πŸ“ž Contact

For questions and support, please contact:

