Skip to content

SDPM-lab/TengTengDiff

Repository files navigation

TengTengDiff

基於 DualAnoDiff 方法的工業異常檢測與生成系統。利用 Stable Diffusion 搭配 LoRA 微調,同時生成高品質的合成異常影像與對應遮罩,用於訓練下游分割模型以進行工業缺陷檢測。

環境設定

系統需求

  • Python 3.8+
  • CUDA 相容 GPU(建議 12GB+ 顯存)
  • CUDA Toolkit

安裝

# 建立虛擬環境
python -m venv .env
source .env/bin/activate

# 安裝依賴
pip install -r requirements.txt

資料集準備

本專案使用 MVTec AD 資料集,包含 15 個工業產品類別。

將資料集放置於 datasets/mvtec_ad/,結構如下:

datasets/mvtec_ad/
├── bottle/
│   ├── train/good/           # 正常訓練影像
│   ├── test/
│   │   ├── good/             # 正常測試影像
│   │   └── {defect_type}/    # 異常影像(如 broken_large, contamination)
│   └── ground_truth/
│       └── {defect_type}/    # 二值遮罩
├── hazelnut/
│   ├── ...(crack, cut, hole, print)
├── cable/
│   ├── ...
└── ...(共 15 個類別)

支援的 15 個類別: bottle, cable, capsule, carpet, grid, hazelnut, leather, metal_nut, pill, screw, tile, toothbrush, transistor, wood, zipper

模型準備

將以下預訓練模型放置於 models/ 目錄:

models/
├── stable-diffusion-v1-5/    # Stable Diffusion v1.5 基礎模型

外部依賴倉庫放置於 extra_repository/

extra_repository/
├── dinov2/                   # DINOv2 倉庫

使用方式

訓練

使用 DreamBooth + LoRA 對 Stable Diffusion 進行雙提示詞微調:

accelerate launch train/stage2-dual/train.py \
    --pretrained_model_name_or_path=models/stable-diffusion-v1-5 \
    --instance_data_dir=datasets/mvtec_ad \
    --mvtec_name=hazelnut \
    --mvtec_anamaly_name=hole \
    --output_dir=all_generate/hazelnut/stage2-hole-dual \
    --instance_prompt_blend="a vfx with sks" \
    --instance_prompt_fg="sks" \
    --resolution=512 \
    --train_batch_size=1 \
    --learning_rate=2e-5 \
    --lr_scheduler=constant \
    --max_train_steps=5000 \
    --rank=32 \
    --seed=32 \
    --gradient_checkpointing \
    --report_to=tensorboard

啟用 DINOv2 感知損失:

# 在上述命令後追加
    --use_dinov2_loss \
    --dinov2_loss_weight=1.0 \
    --dinov2_model_name=vitb14

啟用 DINOv3 感知損失:

    --use_dinov3_loss \
    --dinov3_loss_weight=0.1 \
    --dinov3_model_name=vitb16

使用腳本訓練:

bash script/train/train_stage2-dual.sh

批次訓練所有類別:

# 生成所有類別的訓練腳本
python batch_scripts/generate_train_scripts.py

推理

生成合成異常影像與遮罩配對:

python inference/inference_dual.py \
    --model_name=models/stable-diffusion-v1-5 \
    --lora_weights=all_generate/hazelnut/stage2-hole-dual/checkpoint-5000 \
    --prompt_blend="a vfx with sks" \
    --prompt_fg="sks" \
    --num_images=100 \
    --num_inference_steps=100 \
    --guidance_scale=1.0 \
    --seed=42 \
    --output_dir=generate_data/hazelnut/stage2-hole-dual/checkpoint-5000

使用腳本推理:

bash script/inference/inference_stage2-dual.sh

生成結果保存在:

  • {output_dir}/blend/ — 完整異常影像
  • {output_dir}/fg/ — 異常區域遮罩

評估

執行所有評估指標

# 方式 1:指定類別名稱
bash script/eval/eval_all.sh hazelnut generate_data datasets/mvtec_ad 0

# 方式 2:直接指定影像路徑
bash script/eval/eval_all.sh generate_data/hazelnut/stage2-hole-dual/checkpoint-5000/image datasets/mvtec_ad 0

個別指標

IC-LPIPS(Intra-Cluster LPIPS)— 衡量生成多樣性:

python eval/compute-ic-lpips.py \
    --sample_name hazelnut \
    --generate_data_path generate_data \
    --mvtec_path datasets/mvtec_ad \
    --output ic_lpips_results.csv

Inception Score — 衡量生成品質與多樣性:

python eval/compute-is.py \
    --sample_name hazelnut \
    --generate_data_path generate_data \
    --output IS_results.csv

彙整結果

python summarize_results.py \
    --generate-dir generate_data/hazelnut \
    --anomalies crack hole print cut

定位模型

使用合成資料訓練異常分割模型:

訓練分割模型:

bash script/eval/train_localization.sh

測試分割模型:

bash script/eval/test_localization.sh

定位評估指標: AUC-P(像素級 AUROC)、AP-P(像素級平均精度)、F1-P(像素級 F1)、AP-I(影像級平均精度)、AU-PRO

主要參數說明

訓練參數

參數 預設值 說明
--resolution 512 影像解析度
--train_batch_size 4 批次大小
--learning_rate 5e-4 學習率
--max_train_steps - 最大訓練步數
--rank 4 LoRA 秩(建議 32)
--seed - 隨機種子
--checkpointing_steps 500 檢查點儲存間隔
--gradient_checkpointing false 梯度檢查點(節省顯存)

感知損失參數

參數 預設值 說明
--use_dinov2_loss false 啟用 DINOv2 感知損失
--dinov2_loss_weight 0.1 DINOv2 損失權重
--dinov2_model_name vitb14 DINOv2 模型變體 (vits14/vitb14/vitl14/vitg14)
--use_dinov3_loss false 啟用 DINOv3 感知損失
--dinov3_loss_weight 0.1 DINOv3 損失權重
--dinov3_model_name vitb16 DINOv3 模型變體 (vits16/vitb16/vitl16/vith16)
--use_alignment_loss false 啟用 blend/fg 空間對齊損失
--alignment_loss_weight 0.1 對齊損失權重

推理參數

參數 預設值 說明
--num_images - 生成影像數量
--num_inference_steps 100 去噪步數
--guidance_scale 1.0 CFG 引導尺度
--seed - 隨機種子(確保 blend/fg 對齊)

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors