Skip to content

vepiset/vEpiSleepNet

Repository files navigation

vEpiSleepNet

基于 CNN-Transformer 混合架构的自动睡眠分期深度学习框架vEpiSleepNet,支持多通道 EEG 信号输入,可对 ISRUC 公开数据集及临床多导睡眠图(eeg)数据进行五期睡眠自动分类。

目录


项目简介

vEpiSleepNet 以 30 秒 EEG 片段为基本单元,实现端到端的五期睡眠自动分期(Wake / N1 / N2 / N3 / REM)。核心思路:

  • 使用预训练的 EfficientNet-B0 对 EEG 频谱图进行局部特征提取;
  • 通过**膨胀卷积(Dilated CNN)**增大感受野;
  • 再接单层 Transformer Encoder(ParallelScalingBlock)建模全局时序依赖;
  • CLS token 聚合序列表示,完成五分类输出。

此外,base_trainer/ 目录还保留了多个可替换的对比模型(DeepSleepNet、SeqSleepNet、AttnSleep、SleepTransformer 等),便于消融实验。


模型架构

输入: EEG 多通道原始信号  [B, C, 15000]
         ↓
CNNEmbedding
  ├─ EfficientNet-B0 (pretrained, 单通道输入)  → 空间特征图
  └─ Dilated CNN Head (Conv1d × 3, dilation=1/2/3)  → [B, 768, T]
         ↓
位置编码 + CLS Token  → [B, T+1, 768]
         ↓
Transformer Encoder (ParallelScalingBlock × 1, heads=8)
         ↓
LayerNorm + Dropout(0.5)
         ↓
Linear(768 → 5)  → 分类 logits
超参数
Embed Dim 768
Transformer Heads 8
Transformer Blocks 1
Drop Path 0.2
Dropout 0.5

睡眠分期标签

标签值 睡眠阶段
0 Wake(清醒期)
1 N1(非快速眼动 1 期)
2 N2(非快速眼动 2 期)
3 N3(慢波睡眠)
4 REM(快速眼动期)

ISRUC 原始标签中 REM 为 5,预处理脚本已自动映射为 4。


数据集支持

ISRUC-S1(100 名受试者,8 通道)

bash get_ISRUC_S1.sh          # 下载原始数据
python ISRUC_S1_data_preprocess.py  # 预处理 → npy 片段
  • 采样率重采样至 200 Hz,每 epoch 6000 个采样点
  • 自动识别 A 参考 / M 参考两种命名方案
  • 滤波:50 Hz 陷波 + 0.3–50 Hz 带通
  • 输出路径:./data/ISRUC_S1/PerSegment_8ch/
  • 文件命名格式:ISRUC_S3_sub{id}_{seg:03d}__{label}.npy

ISRUC-S3(10 名受试者,8通道)

bash get_ISRUC_S3.sh          # 下载原始数据
python ISRUC_S3_data_preprocess.py  # 预处理 → npy 片段
  • 采样率重采样至 100 Hz,每 epoch 3000 个采样点

临床数据

项目同时支持xx医院临床eeg数据,使用 29 导联配置(如下所示)。

通道列表:

Fp1, Fp2, F3, F4, C3, C4, P3, P4, O1, O2,
F7, F8, T3, T4, T5, T6, Fz, Cz, Pz,
PG1, PG2, A1, A2

实际送入模型的通道为 23 个 EEG 通道归一化后,再拼接 A1/A2 参考差分导联(19+19=38 维),截断/补零至长度 15000。


环境依赖

pip install -r requirements.txt
主要依赖 版本
torch 1.13.0
torchvision 0.14.0
torchaudio 0.13.0
timm 0.9.2
mne 1.4.2
numpy 1.25.2
pandas 2.0.1
scikit-learn 1.2.2

建议使用支持 CUDA 的 GPU 进行训练(代码默认使用 GPU 2,7,可在 train.py 中修改 CUDA_VISIBLE_DEVICES)。


快速开始

1. 准备数据

参照数据集支持完成数据下载与预处理,并生成包含以下字段的 CSV 索引文件:

字段 说明
file_path npy 文件绝对路径
target 睡眠分期标签(0–4)
fold 交叉验证折编号(0–4)
name EEG 记录名(用于按受试者分折)

使用 k_fold_split_by_eeg.py 生成 5 折分割:

python k_fold_split_by_eeg.py

2. 修改配置

编辑 train_config.py,更新以下路径:

config.DATA.data_file   = '/path/to/your/data.csv'
config.MODEL.model_path = '/path/to/save/models'

3. 训练

python train.py
  • 默认执行 5 折交叉验证
  • 每折每个 epoch 结束后保存 .pth 权重
  • 训练日志实时输出到控制台

4. 评估

python eval.py --weight /path/to/model.pth \
               --test_path /path/to/test.csv \
               --save_predict_csv False

--save_predict_csv True 会将预测结果写回 CSV(添加 pseudo_labelpred_prob 列)。


项目结构

vEpiSleepNet/
├── train.py                      # 训练入口(5 折交叉验证)
├── eval.py                       # 评估脚本
├── train_config.py               # 训练超参数配置
├── k_fold_split_by_eeg.py        # 按 EEG 记录生成 K-Fold 索引
├── ISRUC_S1_data_preprocess.py   # ISRUC-S1 数据预处理(8 通道)
├── ISRUC_S3_data_preprocess.py   # ISRUC-S3 数据预处理(单通道)
├── get_ISRUC_S1.sh               # ISRUC-S1 数据集下载脚本
├── get_ISRUC_S3.sh               # ISRUC-S3 数据集下载脚本
├── requirements.txt              # Python 依赖列表
├── base_trainer/
│   ├── transformer.py            # 主模型:EfficientNet + Transformer
│   ├── model.py                  # 备选模型:ResNet1D+LSTM、TinyConvNet
│   ├── dataietr.py               # Dataset / DataLoader 实现
│   ├── net_work.py               # 训练循环(Train 类)
│   ├── metric.py                 # 评估指标(F1、混淆矩阵、ROC-AUC)
│   ├── attnsleep.py              # 对比模型:AttnSleep
│   ├── deepsleepnet.py           # 对比模型:DeepSleepNet
│   ├── seqsleepnet.py            # 对比模型:SeqSleepNet
│   ├── mmasleepnet.py            # 对比模型:MMASleepNet
│   ├── sleeptransformer.py       # 对比模型:SleepTransformer
│   └── msa_cnn.py                # 对比模型:MSA-CNN
└── utils/
    ├── logger.py                 # 日志工具
    └── seed_utils.py             # 随机种子固定

训练配置

主要超参数均在 train_config.py 中集中管理:

参数 默认值 说明
batch_size 64 批大小
epoch 20 最大训练轮数
init_lr 0.0005 初始学习率
lr_scheduler cos 学习率调度器(cos / ReduceLROnPlateau)
warmup_step 1500 Warmup 步数
opt Adamw 优化器
weight_decay_factor 0.01 L2 正则化系数
gradient_clip 5 梯度裁剪阈值
early_stop 20 早停耐心轮数
k_fold 5 交叉验证折数
num_classes 2 输出类别数(实际模型输出 5 类)
SEED 10086 随机种子

评估

训练过程中每个 epoch 自动计算并打印以下指标:

  • Macro F1-score(主要指标,用于模型保存判断)
  • Macro Precision / Recall
  • 5×5 混淆矩阵(Wake / N1 / N2 / N3 / REM)

eval.py 额外支持:

  • ROC 曲线与 AUC(二分类场景)
  • Precision-Recall 曲线
  • 特异性-灵敏度曲线(Specificity-Sensitivity)
  • 伪标签生成(将预测结果写回 CSV 供半监督训练使用)

About

Multi-channel vision Transformer-based automated sleep staging for EEG monitoring

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors