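TwoTST (Dual-Stream Self-Supervised Pretraining Framework) is a dual-stream Transformer framework for fMRI data analysis. It combines the strengths of a temporal Transformer and a connectivity Transformer, and uses self-supervised pretraining and contrastive learning to improve ASD classification performance.
The TwoTST framework contains two independent Transformer branches:
- TST1 (Transformer-TS): processes the raw fMRI time series and is pretrained with an ROI-level masking strategy
- TST2 (Transformer-FC): processes the upper-triangular vector of the PCC (Pearson correlation coefficient) matrix and is pretrained with an element-level masking strategy
Key features:
- Dual-stream Transformer architecture: processes temporal and connectivity features separately
- Self-supervised pretraining: two masking strategies, each adapted to its data type
- Sequential pretraining: TST1 first, then TST2, learning representations step by step
- Optional contrastive learning: aligns the feature spaces of the two branches
- Multiple fusion strategies: five feature-fusion methods are supported
- 5-fold cross-validation: robust model evaluation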
┌──────────────────────────────────────────────────────────────────────────────┐
│  Data preprocessing                                                          │
│  fmri.npy (N, T, R) → cleaning → (N', T, R)                                  │
│                                      ↓                                       │
│  Time series (N', R, T) + PCC vectors (N', R*(R-1)/2)                        │
└──────────────────────────────────────────────────────────────────────────────┘
                                       ↓
┌──────────────────────────────────────────────────────────────────────────────┐
│  Phase 1: TST1 pretraining                                                   │
│  Time series → ROI-level masking → Transformer-TS → reconstruct time series  │
└──────────────────────────────────────────────────────────────────────────────┘
                                       ↓
┌──────────────────────────────────────────────────────────────────────────────┐
│  Phase 2: TST2 pretraining                                                   │
│  PCC vectors → element-level masking → Transformer-FC → reconstruct PCC      │
└──────────────────────────────────────────────────────────────────────────────┘
                                       ↓
┌──────────────────────────────────────────────────────────────────────────────┐
│  Phase 3: contrastive learning (optional)                                    │
│  TST1 features + TST2 features → InfoNCE loss → feature alignment            │
└──────────────────────────────────────────────────────────────────────────────┘
                                       ↓
┌──────────────────────────────────────────────────────────────────────────────┐
│  Phase 4: fine-tuning for classification                                     │
│  Fused features → MLP classifier → ASD/TC prediction                         │
└──────────────────────────────────────────────────────────────────────────────┘
TwoTST/
├── models/                      # Model definitions
│   ├── transformer_ts.py        # TST1: temporal Transformer
│   ├── transformer_fc.py        # TST2: connectivity Transformer
│   ├── fusion.py                # Fusion modules (5 strategies)
│   ├── dual_stream.py           # Dual-stream model
│   └── __init__.py              # Module exports
│
├── pretrain/                    # Pretraining modules
│   ├── mask_utils.py            # Masking-strategy utilities
│   ├── pretrain_ts.py           # TST1 pretraining script
│   ├── pretrain_fc.py           # TST2 pretraining script
│   ├── contrastive.py           # Contrastive learning module
│   └── __init__.py
│
├── scripts/                     # Training scripts
│   ├── prepare_data.py          # Data preprocessing
│   ├── train_pretrain.py        # Pretraining entry point
│   ├── train_finetune.py        # Fine-tuning entry point
│   └── start_tensorboard.sh     # TensorBoard launch script
│
├── utils/                       # Utility functions
│   ├── data_loader.py           # Data loaders
│   ├── metrics.py               # Evaluation metrics
│   └── __init__.py
│
├── configs/                     # Configuration files
│   ├── default.yaml             # Default configuration
│   └── base_template.yaml       # Configuration template
│
├── requirements.txt             # Dependencies
└── README.md                    # This document
Note: training creates the following directories (adding them to .gitignore is recommended):
- checkpoints/: model checkpoints
- logs/: TensorBoard logs
- data/: data files
- results/: experiment results
System requirements:
- Python >= 3.7
- CUDA >= 10.2 (GPU recommended)
Install dependencies:
# Clone the repository
git clone https://github.com/Leezy-Ray/twoTST.git
cd twoTST
# Install dependencies
pip install -r requirements.txt
Main dependencies:
- PyTorch >= 1.10.0
- NumPy >= 1.20.0
- scikit-learn >= 1.0.0
- TensorBoard >= 2.8.0
- tqdm, PyYAML
Prepare your fMRI data as a NumPy array of shape (n_samples, time_points, n_rois).
# Run the data preprocessing script
python scripts/prepare_data.py \
    --data_path /path/to/your/fmri.npy \
    --output_dir data/processed \
    --n_rois 200 \
    --time_points 100
Data preprocessing features:
- Automatically removes samples that contain all-zero ROIs
- Computes the upper-triangular vector of the PCC (Pearson correlation coefficient) matrix
- Optional sliding-window data augmentation
- Data standardization and splitting (train/validation/test)
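As an illustration of the optional sliding-window augmentation, here is a minimal NumPy sketch. The function name `sliding_windows` is hypothetical, and the real `prepare_data.py` may window the data at a different stage or in a different order; the window size and stride are the values used with `--window_size` and `--stride` elsewhere in this README.

```python
import numpy as np


def sliding_windows(timeseries, labels, window_size=50, stride=25):
    """Cut (N, R, T) time series into overlapping windows along the time axis.

    Returns an array of shape (N * n_windows, R, window_size) and the labels
    repeated once per window. Illustrative only, not the repository's code.
    """
    n_time = timeseries.shape[2]
    starts = range(0, n_time - window_size + 1, stride)
    windows = [timeseries[:, :, s:s + window_size] for s in starts]
    # Windows are stacked window-by-window, so the labels are tiled to match.
    return np.concatenate(windows, axis=0), np.tile(labels, len(windows))
```

If windows from one recording end up in both the training and the test split, the evaluation will be optimistic, so splitting by subject before windowing is the safer order.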
Output data format:
- timeseries: (n_samples, n_rois, time_points) - time series
- pcc_vectors: (n_samples, n_rois*(n_rois-1)/2) - PCC upper-triangular vectors
- labels: (n_samples,) - labels (0=ASD, 1=TC)
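The shapes above can be reproduced with a short NumPy sketch. The function name `preprocess` is hypothetical, and the cleaning rule (drop a sample if any ROI is zero at every time point) is an assumption about what "all-zero ROI" means; standardization and splitting are left out.

```python
import numpy as np


def preprocess(fmri):
    """Sketch: raw (N, T, R) array -> time series (N', R, T) and PCC vectors.

    Illustrative only; assumes at least one sample survives cleaning.
    """
    # Assumed cleaning rule: drop samples with an ROI that is zero throughout.
    has_dead_roi = (np.abs(fmri).sum(axis=1) == 0).any(axis=1)
    cleaned = fmri[~has_dead_roi]                      # (N', T, R)

    # Reorder to (N', R, T) so that each row is one ROI's time series.
    timeseries = np.transpose(cleaned, (0, 2, 1))

    # Pearson correlation between ROIs; keep the strict upper triangle,
    # which has R * (R - 1) / 2 entries.
    n_rois = timeseries.shape[1]
    rows, cols = np.triu_indices(n_rois, k=1)
    pcc_vectors = np.stack([np.corrcoef(ts)[rows, cols] for ts in timeseries])
    return timeseries, pcc_vectors
```

With n_rois = 200 the vector length is 200 * 199 / 2 = 19900, which matches the `pcc_dim` value in the configuration.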
Use the unified entry script to run sequential pretraining:
python scripts/train_pretrain.py \
    --data_path data/processed/processed_data.pkl \
    --pretrain_tst1 \
    --pretrain_tst2 \
    --tst1_epochs 100 \
    --tst2_epochs 100 \
    --batch_size 32 \
    --lr 1e-4 \
    --save_dir checkpoints \
    --log_dir logs
TST1 pretraining (temporal Transformer):
python pretrain/pretrain_ts.py \
    --data_path data/processed/processed_data.pkl \
    --epochs 100 \
    --batch_size 32 \
    --lr 1e-4 \
    --save_dir checkpoints/tst1 \
    --log_dir logs/tst1
TST1 characteristics:
- Input: (batch, n_rois, time_points) - time series
- Masking strategy: ROI-level masking (randomly masks 25% or 50% of the ROIs, each over its full time series)
- Pretraining task: reconstruct the complete time series of the masked ROIs
- Model parameters: ~19M
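The ROI-level masking idea can be sketched in NumPy as follows. The helper names are hypothetical, and two details are assumptions rather than facts about `mask_utils.py`: masked ROIs are replaced with zeros (the repository may use a learned mask token instead), and the reconstruction error is averaged over masked ROIs only.

```python
import numpy as np


def roi_level_mask(x, mask_ratio=0.25, rng=None):
    """Mask whole ROIs in x of shape (batch, n_rois, time_points).

    Returns the masked copy and a boolean (batch, n_rois) mask.
    Illustrative only.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, n_rois, _ = x.shape
    n_masked = int(round(mask_ratio * n_rois))
    mask = np.zeros((batch, n_rois), dtype=bool)
    for b in range(batch):
        # Pick a different random subset of ROIs for every sample.
        mask[b, rng.choice(n_rois, size=n_masked, replace=False)] = True
    # Assumption: masked ROIs are zeroed across all time points.
    return np.where(mask[:, :, None], 0.0, x), mask


def masked_mse(pred, target, mask):
    """Mean squared reconstruction error over the masked ROIs only."""
    diff = (pred - target)[mask]          # (n_masked_total, time_points)
    return float((diff ** 2).mean())
```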
TST2 pretraining (connectivity Transformer):
python pretrain/pretrain_fc.py \
    --data_path data/processed/processed_data.pkl \
    --epochs 100 \
    --batch_size 32 \
    --lr 1e-4 \
    --mask_ratio 0.15 \
    --save_dir checkpoints/tst2 \
    --log_dir logs/tst2
TST2 characteristics:
- Input: (batch, pcc_dim) - PCC upper-triangular vectors
- Masking strategy: element-level masking (randomly masks 15% of the PCC values)
- Pretraining task: reconstruct the masked PCC values
- Model parameters: ~16M
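Element-level masking differs from the ROI-level variant only in what gets hidden: individual entries of the PCC vector rather than whole rows. A minimal NumPy sketch, with a hypothetical function name and the assumption that masked entries are set to zero:

```python
import numpy as np


def element_level_mask(pcc, mask_ratio=0.15, rng=None):
    """Mask individual entries of pcc with shape (batch, pcc_dim).

    Returns the masked copy and a boolean mask of the same shape.
    Illustrative only.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, pcc_dim = pcc.shape
    n_masked = int(round(mask_ratio * pcc_dim))
    mask = np.zeros((batch, pcc_dim), dtype=bool)
    for b in range(batch):
        # Each sample hides its own random subset of PCC values.
        mask[b, rng.choice(pcc_dim, size=n_masked, replace=False)] = True
    # Assumption: masked values are replaced with zeros.
    return np.where(mask, 0.0, pcc), mask
```

The reconstruction loss would then be computed on `pred[mask]` against `pcc[mask]`, so the model is scored only on the values it could not see.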
Load the pretrained weights for the downstream classification task:
python scripts/train_finetune.py \
    --data_path data/processed/processed_data.pkl \
    --tst1_checkpoint checkpoints/tst1/tst1_best.pt \
    --tst2_checkpoint checkpoints/tst2/tst2_best.pt \
    --fusion_type cross_attention \
    --epochs 100 \
    --batch_size 32 \
    --lr 5e-5 \
    --n_folds 5 \
    --use_contrastive  # Optional: enable contrastive learning
Fine-tuning parameters:
- --fusion_type: fusion strategy (concat, gated, cross_attention, bilinear, attention_pooling)
- --n_folds: number of cross-validation folds (default: 5)
- --use_contrastive: whether to run contrastive-learning alignment before fine-tuning
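To make the contrastive alignment step concrete, here is a NumPy sketch of an InfoNCE loss between the two branches. It assumes a symmetric formulation in which the TST1 and TST2 embeddings of the same subject are the positive pair and every other pairing in the batch is a negative; `pretrain/contrastive.py` may differ. The default temperature of 0.07 is the value listed under `contrastive` in `configs/default.yaml`.

```python
import numpy as np


def info_nce(z_ts, z_fc, temperature=0.07):
    """Symmetric InfoNCE loss for paired embeddings of shape (batch, dim).

    Row i of z_ts and row i of z_fc are assumed to come from the same
    subject. Illustrative only.
    """
    # L2-normalize so the dot product is a cosine similarity.
    z_ts = z_ts / np.linalg.norm(z_ts, axis=1, keepdims=True)
    z_fc = z_fc / np.linalg.norm(z_fc, axis=1, keepdims=True)
    logits = z_ts @ z_fc.T / temperature               # (batch, batch)

    def diagonal_cross_entropy(scores):
        # Numerically stable log-softmax over each row.
        scores = scores - scores.max(axis=1, keepdims=True)
        log_prob = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
        # The matching pair sits on the diagonal.
        return -np.mean(np.diag(log_prob))

    # Average the TS-to-FC and FC-to-TS directions.
    return 0.5 * (diagonal_cross_entropy(logits)
                  + diagonal_cross_entropy(logits.T))
```

The loss is small when each subject's two embeddings are closer to each other than to any other subject's, which is the alignment the optional phase aims for.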
During training you can use TensorBoard to view the training curves in real time:
# Start the TensorBoard service
bash scripts/start_tensorboard.sh
# Or start it manually
tensorboard --logdir=logs --port=6006 --host=0.0.0.0
Then open http://localhost:6006 in your browser to view the training curves.
The configuration file is configs/default.yaml. Main options:
# Data configuration
data:
  n_rois: 200
  time_points: 100
  pcc_dim: 19900
# TST1 configuration
tst1:
  emb_dim: 512
  n_heads: 8
  n_layers: 6
  dim_feedforward: 2048
# TST2 configuration
tst2:
  d_model: 256
  n_heads: 8
  n_layers: 2
  dim_feedforward: 512
# Fusion configuration
fusion:
  type: cross_attention  # concat/gated/cross_attention/bilinear/attention_pooling
# Contrastive learning configuration
contrastive:
  enabled: false
  temperature: 0.07
  epochs: 50
The framework supports five fusion strategies:
- ConcatFusion: simple concatenation, [h_ts; h_fc]
- GatedFusion: gated fusion, gate * h_ts + (1-gate) * h_fc
- CrossAttentionFusion: cross-attention fusion (recommended)
- BilinearFusion: bilinear fusion
- AttentionPoolingFusion: attention-pooling fusion
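The two formulas in the list can be written out in a few lines of NumPy. This is a sketch, not the code in `models/fusion.py`: how the gate is computed (a sigmoid over a linear map of the concatenated features, with hypothetical parameters `weight` and `bias`) is an assumption, and both branch outputs are assumed to have already been projected to the same width.

```python
import numpy as np


def concat_fusion(h_ts, h_fc):
    """[h_ts; h_fc]: concatenate along the feature axis."""
    return np.concatenate([h_ts, h_fc], axis=-1)


def gated_fusion(h_ts, h_fc, weight, bias):
    """gate * h_ts + (1 - gate) * h_fc with a per-feature gate in (0, 1).

    h_ts, h_fc: (batch, dim); weight: (2 * dim, dim); bias: (dim,).
    Illustrative only.
    """
    # Assumed gate: sigmoid of a linear map of the concatenated features.
    pre_activation = concat_fusion(h_ts, h_fc) @ weight + bias
    gate = 1.0 / (1.0 + np.exp(-pre_activation))
    return gate * h_ts + (1.0 - gate) * h_fc
```

Because the gate lies strictly between 0 and 1, every fused feature is a convex combination of the two branch features, so the output never leaves the range spanned by its inputs.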
The fine-tuning script automatically computes the following metrics:
- Accuracy
- Precision
- Recall
- F1 Score
- AUC: area under the ROC curve
- Sensitivity/Specificity
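Sensitivity and specificity depend on which class counts as positive, and this README does not say whether `utils/metrics.py` treats ASD (label 0) or TC (label 1) as the positive class. The sketch below therefore takes the positive label as an explicit argument; the function name is hypothetical.

```python
def sensitivity_specificity(y_true, y_pred, positive_label):
    """Return (sensitivity, specificity) for binary labels.

    Sensitivity is the recall of the positive class; specificity is the
    recall of the negative class. Illustrative only.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive_label and p == positive_label for t, p in pairs)
    fn = sum(t == positive_label and p != positive_label for t, p in pairs)
    tn = sum(t != positive_label and p != positive_label for t, p in pairs)
    fp = sum(t != positive_label and p == positive_label for t, p in pairs)
    # Guard against folds in which one class is absent.
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return sensitivity, specificity
```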
Example output:
Cross-Validation Results:
----------------------------------------
Accuracy  : 0.7234 ± 0.0234
Precision : 0.7123 ± 0.0198
Recall    : 0.7345 ± 0.0212
F1        : 0.7231 ± 0.0201
AUC       : 0.7891 ± 0.0156
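A summary in this "mean ± standard deviation" form can be produced from per-fold scores with the standard library alone. The sketch below assumes the population standard deviation (the default behaviour of `numpy.std`); the fine-tuning script may use the sample standard deviation instead. The function name and the per-fold values in the usage note are made up for illustration.

```python
import statistics


def summarize_folds(fold_metrics):
    """Format per-fold scores as 'name : mean ± std' lines.

    fold_metrics maps a metric name to a list with one score per fold.
    Illustrative only.
    """
    lines = []
    for name, scores in fold_metrics.items():
        mean = statistics.mean(scores)
        # Assumption: population standard deviation across folds.
        std = statistics.pstdev(scores)
        lines.append(f"{name:<10}: {mean:.4f} ± {std:.4f}")
    return lines
```

For example, `summarize_folds({"Accuracy": [0.70, 0.72, 0.74, 0.71, 0.73]})` returns a single line reporting the mean and spread of those five hypothetical fold accuracies.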
import torch
from models import create_dual_stream_model
from utils.data_loader import load_processed_data, TwoTSTDataset
# Load the data
data = load_processed_data('data/processed/processed_data.pkl')
dataset = TwoTSTDataset(
    data['timeseries'],
    data['pcc_vectors'],
    data['labels']
)
# Create the model
model = create_dual_stream_model(
    n_rois=200,
    time_points=100,
    pcc_dim=19900,
    fusion_type='cross_attention'
)
# Load the pretrained weights
model.load_pretrained_tst1('checkpoints/tst1/tst1_best.pt')
model.load_pretrained_tst2('checkpoints/tst2/tst2_best.pt')
# Forward pass
timeseries = torch.randn(8, 200, 100)
pcc_vector = torch.randn(8, 19900)
logits = model(timeseries, pcc_vector)
A: You can reduce batch_size or use gradient accumulation to lower memory use:
--batch_size 16  # reduce the batch size
A: You can modify the script so that it loads and uses the pretrained weights of a single branch only.
A: Specify the checkpoint paths in the fine-tuning script:
--tst1_checkpoint checkpoints/tst1/tst1_best.pt
--tst2_checkpoint checkpoints/tst2/tst2_best.pt
A: Enable sliding-window augmentation during data preprocessing:
python scripts/prepare_data.py \
    --use_sliding_window \
    --window_size 50 \
    --stride 25
This project draws on the following work:
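- ROI-level masked pretraining of a temporal Transformer
- Masked pretraining of a connectivity Transformer on PCC upper-triangular vectors
This project is for research use only.
Issues and Pull Requests are welcome!
If you have any questions or suggestions, please:
- Submit an Issue
- Open a Pull Request
For questions, please submit an Issue or contact the project maintainers.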
TwoTST - Dual-Stream Self-Supervised Pretraining Framework for fMRI Analysis