AI-assisted breast cancer detection on screening mammography: RSNA Kaggle Competition.
Collaborative project, M2 Biologie-Informatique, Université Paris Cité
Supervisors: Tatiana Galochkina · Frédéric Guyon · Jean-Christophe Gelly
Assa Diabira, M2 Biologie-Informatique, Université Paris Cité
Multi-Head Expert model: 4 pretrained backbones (3 medical-imaging, 1 ImageNet-21k) + cross-attention fusion
```
Breast_Cancer_Detection/
├── main.py                      ← entry point: train / eval / infer
├── inference.py                 ← load model + predict on DICOM files
├── requirements.txt
├── README.md
├── .gitignore
│
├── core/                        ← shared utilities
│   ├── configuration.py
│   ├── dataset_manager.py
│   └── loader.py
│
├── preprocess/                  ← DICOM preprocessing pipeline
│   ├── pipeline.py              ← PreprocessPipeline (5 modes)
│   ├── cropping.py              ← ROI crop + pectoral muscle removal
│   ├── windowing.py             ← adaptive windowing by BI-RADS density
│   └── resampler.py             ← isotropic resampling
│
├── models/                      ← Multi-Head Expert model
│   ├── multi_head_expert.py     ← 4 expert backbones + fusion
│   ├── baseline_cnn.py          ← baseline for comparison
│   ├── losses.py                ← FocalAUCLoss (70% Focal + 30% AUC)
│   ├── trainer.py               ← Trainer class (OOP)
│   └── dataset.py               ← MammographyDataset (PyTorch)
│
├── notebooks/
│   ├── eda.ipynb                ← Exploratory Data Analysis
│   ├── preprocessing_benchmark.ipynb
│   ├── training_baseline.ipynb
│   └── training_multihead.ipynb
│
└── results/
    ├── metrics/                 ← JSON metrics from Kaggle runs
    └── figures/                 ← ROC curves, confusion matrices
```
```
DICOM image
    ↓  PreprocessPipeline (crop → adaptive windowing → resize)
float32 [1, H, W]  (grayscale, values in [0, 1])
    ↓
┌─────────────────────────────────────────────────────────────┐
│  4 EXPERT BACKBONES                                         │
│  1. EfficientNetV2-S  mammoscreen (RSNA breast, AUC 0.945)  │
│  2. DenseNet121       TorchXRayVision RSNA X-ray            │
│  3. ResNet50          RadImageNet (1.35M medical images)    │
│  4. ConvNeXt-Small    ImageNet-21k (RSNA Kaggle winner)     │
└─────────────────────────────────────────────────────────────┘
    ↓  each → 512-dim embedding
Expert-Aware Fusion:
    cross-attention (4 heads) → per-expert MLP → dynamic gating (softmax)
    ↓
MLP classifier: 512 → 256 → 128 → 1  [GELU, Dropout]
    ↓
BCEWithLogitsLoss (no Sigmoid on model output)
```
106M parameters in total; freezing the backbones in phase 1 leaves 6.3M trainable.
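The fusion stage can be sketched in NumPy as below. This is a hypothetical forward pass with random weights, not the code in `models/multi_head_expert.py`: the name `ExpertFusionSketch`, the residual connection, the single-layer per-expert MLPs and computing the gates from the concatenated expert features are all assumptions, and dropout is left out because only inference is shown. The 4 attention heads, the 512-dim embeddings and the 512 → 256 → 128 → 1 classifier follow the diagram; attention runs across the four expert embeddings, so each expert can attend to the others.

```python
import numpy as np


def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


class ExpertFusionSketch:
    """Hypothetical fusion head: attention across experts -> per-expert MLP -> gating -> classifier."""

    def __init__(self, n_experts=4, dim=512, n_heads=4, seed=0):
        rng = np.random.default_rng(seed)

        def init(n_in, n_out):
            # random weights scaled to keep activations O(1)
            return rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out))

        self.n_heads = n_heads
        self.wq, self.wk, self.wv, self.wo = (init(dim, dim) for _ in range(4))
        self.expert_mlps = [init(dim, dim) for _ in range(n_experts)]  # one layer per expert
        self.w_gate = init(dim * n_experts, n_experts)  # gating logits from all experts
        self.classifier = [init(dim, 256), init(256, 128), init(128, 1)]  # 512 -> 256 -> 128 -> 1

    def attend(self, x):
        # x: (batch, n_experts, dim); every expert token attends to all expert tokens
        b, n, d = x.shape
        h = self.n_heads
        dh = d // h

        def split(t):
            return t.reshape(b, n, h, dh).transpose(0, 2, 1, 3)  # (b, heads, n, dh)

        q, k, v = split(x @ self.wq), split(x @ self.wk), split(x @ self.wv)
        attn = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh))  # (b, heads, n, n)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, n, d)
        return x + out @ self.wo  # residual connection

    def __call__(self, embeddings):
        x = self.attend(embeddings)
        x = np.stack([gelu(x[:, i] @ w) for i, w in enumerate(self.expert_mlps)], axis=1)
        gates = softmax(x.reshape(x.shape[0], -1) @ self.w_gate)  # (b, n_experts), rows sum to 1
        h = (gates[..., None] * x).sum(axis=1)  # gated sum of experts -> (b, dim)
        for w in self.classifier[:-1]:
            h = gelu(h @ w)  # dropout is only active in training, so it is omitted here
        logits = (h @ self.classifier[-1])[:, 0]  # raw logits, no sigmoid (BCEWithLogitsLoss)
        return logits, gates


emb = np.random.default_rng(1).normal(size=(2, 4, 512))  # batch of 2, 4 experts, 512-dim each
logits, gates = ExpertFusionSketch()(emb)
print(logits.shape, gates.shape)  # (2,) (2, 4)
```

The returned `gates` show how much weight each expert received per image, which is one reason a softmax gate is easier to inspect than plain concatenation.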
```bash
git clone https://github.com/assadiab/Breast_Cancer_Detection.git
cd Breast_Cancer_Detection
pip install -r requirements.txt
```
Key dependencies: torch, torchvision, timm, torchxrayvision, monai, huggingface_hub, safetensors, albumentations, pydicom, opencv-python, scikit-learn.
Optional: RadImageNet weights for Expert 3 (ResNet50).
Download them from Google Drive and place them in `checkpoints/radImageNet/`.
The weights for Expert 1 (mammoscreen) and Expert 2 (TorchXRayVision) are downloaded automatically on first run.
```bash
# Multi-Head Expert model
python main.py train \
  --csv train.csv \
  --images-dir /path/to/dicom/images \
  --preprocess-mode full \
  --epochs 20 --batch-size 8 --lr 1e-4 \
  --checkpoint-dir checkpoints/

# Baseline CNN
python main.py train \
  --csv train.csv \
  --images-dir /path/to/dicom/images \
  --model baseline --epochs 30 \
  --checkpoint-dir checkpoints/baseline/
```

```python
from models.multi_head_expert import MultiHeadMammoModel
from models.trainer import Trainer

# train_loader, val_loader (PyTorch DataLoaders) and device are assumed to be defined already
model = MultiHeadMammoModel(embed_dim=512)

# Phase 1: frozen backbones (6.3M trainable params)
model.freeze_backbones()
Trainer(model, train_loader, val_loader, device, lr=1e-3, n_epochs=10).train()

# Phase 2: unfreeze the last 2 blocks
model.unfreeze_backbones(last_n_blocks=2)
Trainer(model, train_loader, val_loader, device, lr=1e-4, n_epochs=10).train()

# Phase 3: full fine-tuning (106M params)
model.unfreeze_all()
Trainer(model, train_loader, val_loader, device, lr=5e-5, n_epochs=10).train()
```

```python
from preprocess.pipeline import PreprocessPipeline

# config (project configuration object) and the per-image metadata below are assumed to be defined
pipeline = PreprocessPipeline(config, mode="full", target_hw=(1024, 512))
# modes: "raw" | "crop_only" | "window_only" | "full" | "full_iso"
img = pipeline.process_one(patient_id, image_id, laterality, view, density, dicom_path)
# returns a float32 NumPy array (H, W) with values in [0, 1]
```

```bash
# Evaluate on the validation set
python main.py eval \
  --csv val.csv \
  --images-dir /path/to/dicom \
  --checkpoint checkpoints/best_model.pth

# Predict on new DICOM files
python main.py infer \
  --images-dir /path/to/new/dicoms \
  --checkpoint checkpoints/best_model.pth \
  --output predictions.csv
```

Three preprocessing stages are applied before the image reaches the model:
1. ROI cropping: Otsu thresholding → morphological operations → pectoral muscle removal (Hough lines, MLO views) → standardized left orientation → bounding box with margins in mm.
2. Adaptive windowing: density-aware percentile clipping → gamma correction → CLAHE (relative tile size) → weighted fusion, tuned per BI-RADS density (A/B/C/D).
3. Resize to the target resolution (default 1024×512).
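A minimal NumPy sketch of these stages, for illustration only and not the code in `preprocess/`: it keeps Otsu thresholding, the left-orientation flip, the mm-based margin, percentile clipping, gamma correction and resizing, but omits morphological operations, pectoral muscle removal, CLAHE and the per-density weighted fusion. The function names, the default margin, percentiles and gamma, and the `pixel_spacing_mm` argument (which would normally come from the DICOM header) are assumptions; nearest-neighbour indexing stands in for a proper interpolating resize.

```python
import numpy as np


def otsu_threshold(img, bins=256):
    """Otsu's method: pick the histogram threshold that maximises between-class variance."""
    hist, edges = np.histogram(img, bins=bins, range=(float(img.min()), float(img.max())))
    p = hist / hist.sum()
    centers = (edges[:-1] + edges[1:]) / 2
    w0 = np.cumsum(p)                 # weight of the "below threshold" class
    w1 = 1.0 - w0                     # weight of the "above threshold" class
    mu = np.cumsum(p * centers)       # cumulative first moment
    between = np.zeros_like(w0)
    ok = (w0 > 1e-12) & (w1 > 1e-12)  # skip thresholds that leave a class empty
    between[ok] = (mu[-1] * w0[ok] - mu[ok]) ** 2 / (w0[ok] * w1[ok])
    return centers[np.argmax(between)]


def crop_breast(img, margin_mm=5.0, pixel_spacing_mm=0.1):
    """Otsu mask -> mirror breast to the left -> bounding box with a margin given in mm."""
    mask = img > otsu_threshold(img)
    half = img.shape[1] // 2
    if mask[:, half:].sum() > mask[:, :half].sum():  # breast on the right: mirror it
        img, mask = img[:, ::-1], mask[:, ::-1]
    margin = int(round(margin_mm / pixel_spacing_mm))
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    r0, r1 = max(rows[0] - margin, 0), min(rows[-1] + margin + 1, img.shape[0])
    c0, c1 = max(cols[0] - margin, 0), min(cols[-1] + margin + 1, img.shape[1])
    return img[r0:r1, c0:c1]


def window(img, low_pct=1.0, high_pct=99.0, gamma=1.0):
    """Percentile clipping + gamma correction (illustrative defaults, not the tuned values)."""
    lo, hi = np.percentile(img, [low_pct, high_pct])
    out = np.clip((img - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
    return (out ** gamma).astype(np.float32)


def resize_nearest(img, target_hw=(1024, 512)):
    """Nearest-neighbour resize by integer index sampling."""
    rows = np.arange(target_hw[0]) * img.shape[0] // target_hw[0]
    cols = np.arange(target_hw[1]) * img.shape[1] // target_hw[1]
    return img[np.ix_(rows, cols)]


img = np.zeros((300, 200), dtype=np.float32)
img[50:250, 120:200] = 0.6  # synthetic "breast" touching the right edge
out = resize_nearest(window(crop_breast(img, margin_mm=1.0, pixel_spacing_mm=0.1)))
print(out.shape, out.dtype)  # (1024, 512) float32
```

Expressing the margin in millimetres rather than pixels keeps the crop consistent across scanners with different pixel spacings.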
| Model | ROC-AUC | PR-AUC | F1 | Recall |
|---|---|---|---|---|
| EfficientNet-B0 | 0.63 | 0.14 | 0.13 | 0.18 |
| ConvNeXt-Base | 0.62 | 0.15 | 0.20 | 0.26 |
| ResNet-50 + Meta | 0.59 | 0.13 | 0.19 | 0.21 |
| Multi-Head v1 | 0.66 | 0.15 | 0.22 | 0.23 |
| Multi-Head v3 | In progress (Kaggle GPU) | - | - | - |
Metrics from Kaggle GPU runs are stored in `results/metrics/`.
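scikit-learn (listed in the dependencies) provides `roc_auc_score`, `average_precision_score`, `f1_score` and `recall_score` for the four columns above; the dependency-free NumPy sketch below only shows what each number measures. It assumes PR-AUC means average precision and uses a 0.5 decision threshold for F1 and recall; the threshold used in the Kaggle runs is not stated in this README.

```python
import numpy as np


def roc_auc(y, s):
    """ROC-AUC as the probability that a random positive outscores a random negative."""
    pos, neg = s[y == 1], s[y == 0]
    diff = pos[:, None] - neg[None, :]  # all positive-negative score pairs
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())  # ties count as half


def average_precision(y, s):
    """PR-AUC as average precision: mean precision at the rank of each positive (no tie handling)."""
    y_sorted = y[np.argsort(-s, kind="stable")]
    precision = np.cumsum(y_sorted) / np.arange(1, len(y) + 1)
    return float(precision[y_sorted == 1].mean())


def f1_recall(y, s, threshold=0.5):
    """F1 and recall after thresholding the predicted probabilities (threshold is an assumption)."""
    pred = s >= threshold
    tp = np.sum(pred & (y == 1))
    fp = np.sum(pred & (y == 0))
    fn = np.sum(~pred & (y == 1))
    recall = tp / (tp + fn) if tp + fn else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return float(f1), float(recall)


y = np.array([0, 0, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8])
print(roc_auc(y, s), average_precision(y, s), f1_recall(y, s))
# ROC-AUC 0.75, PR-AUC ≈ 0.833, F1 ≈ 0.667, recall 0.5
```

With a cancer rate around 2%, PR-AUC is far more sensitive to false positives than ROC-AUC, which is why both are reported.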
RSNA Screening Mammography Breast Cancer Detection dataset: 54,706 DICOM images from 11,913 patients, 2.12% cancer rate.
Patient-wise stratified split (no data leakage): 70% train / 15% val / 15% test.
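A patient-wise split can be sketched with the standard library alone, assuming patient-level stratification (a patient counts as positive if any of their images is positive) and a hypothetical `patient_split` helper; the repository's own splitting code is not shown here and may stratify differently. scikit-learn's `StratifiedGroupKFold` is an off-the-shelf alternative for the same grouping constraint.

```python
import random


def patient_split(patient_ids, labels, fracs=(0.70, 0.15), seed=42):
    """Split image indices into train/val/test so that no patient appears in two splits.

    Stratified at patient level; the test split receives the remaining patients
    (about 15% with the default fracs).
    """
    patient_label = {}
    for pid, y in zip(patient_ids, labels):
        patient_label[pid] = max(patient_label.get(pid, 0), int(y))

    rng = random.Random(seed)
    assignment = {}  # patient id -> 0 (train), 1 (val) or 2 (test)
    for cls in (1, 0):  # split positive and negative patients separately
        group = sorted(p for p, y in patient_label.items() if y == cls)
        rng.shuffle(group)
        n_train = round(fracs[0] * len(group))
        n_val = round(fracs[1] * len(group))
        for rank, pid in enumerate(group):
            assignment[pid] = 0 if rank < n_train else (1 if rank < n_train + n_val else 2)

    splits = ([], [], [])
    for idx, pid in enumerate(patient_ids):
        splits[assignment[pid]].append(idx)  # every image follows its patient
    return splits


# toy data: 100 patients x 4 images, 10 patients with one positive image each
pids = [p for p in range(100) for _ in range(4)]
labels = [int(p < 10 and i == 0) for p in range(100) for i in range(4)]
train_idx, val_idx, test_idx = patient_split(pids, labels)
print(len(train_idx), len(val_idx), len(test_idx))
```

Splitting by image instead would let views of the same breast land in both train and test, inflating validation scores.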
Class imbalance: pos_weight=13.7 + FocalLoss (γ=2.5, α=0.75) + patient-aware oversampling.
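The focal term can be sketched in NumPy with the γ and α values above, together with one possible reading of the 70% Focal + 30% AUC weighting listed for `FocalAUCLoss` in `models/losses.py`. The pairwise squared-hinge AUC surrogate and its `margin` are assumptions, since this README does not specify the AUC term; `pos_weight` (in PyTorch an argument of `BCEWithLogitsLoss`) and the oversampling are not reproduced.

```python
import numpy as np


def focal_loss(logits, targets, gamma=2.5, alpha=0.75):
    """Binary focal loss on raw logits, averaged over the batch."""
    p = np.exp(-np.logaddexp(0.0, -logits))               # stable sigmoid
    p_t = np.where(targets == 1, p, 1.0 - p)               # probability of the true class
    alpha_t = np.where(targets == 1, alpha, 1.0 - alpha)   # alpha weights the positive class
    # cross-entropy -log(p_t), computed stably as softplus(-z) or softplus(z)
    ce = np.where(targets == 1, np.logaddexp(0.0, -logits), np.logaddexp(0.0, logits))
    return float(np.mean(alpha_t * (1.0 - p_t) ** gamma * ce))  # (1 - p_t)^gamma down-weights easy cases


def pairwise_auc_loss(logits, targets, margin=1.0):
    """Hypothetical AUC surrogate: squared hinge on every positive-negative logit gap."""
    pos, neg = logits[targets == 1], logits[targets == 0]
    if pos.size == 0 or neg.size == 0:
        return 0.0  # AUC is undefined without both classes in the batch
    gaps = margin - (pos[:, None] - neg[None, :])
    return float(np.mean(np.maximum(gaps, 0.0) ** 2))


def focal_auc_loss(logits, targets, w_focal=0.7, w_auc=0.3):
    """70% focal + 30% AUC surrogate."""
    return w_focal * focal_loss(logits, targets) + w_auc * pairwise_auc_loss(logits, targets)


targets = np.array([1, 0, 0, 0])
good = np.array([3.0, -3.0, -2.0, -4.0])  # positive ranked first
bad = np.array([-3.0, 3.0, 2.0, 4.0])     # positive ranked last
print(focal_auc_loss(good, targets) < focal_auc_loss(bad, targets))  # True
```

The AUC term only looks at rankings between positives and negatives, so it still provides a signal in batches where the focal term is dominated by easy negatives.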
Data is not versioned here (DICOM files too large for GitHub).
- mammoscreen: https://huggingface.co/ianpan/mammoscreen
- TorchXRayVision: https://github.com/mlmed/torchxrayvision
- RadImageNet: https://pubs.rsna.org/doi/full/10.1148/ryai.210315
- RSNA Kaggle winner (ConvNeXt): https://pmc.ncbi.nlm.nih.gov/articles/PMC11048882/