-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
88 lines (70 loc) · 3.2 KB
/
test.py
File metadata and controls
88 lines (70 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset1 import FlameDataset
from model.unet import UNet
# ========== IoU & Dice computation ==========
def compute_iou(pred, target, threshold=0.5):
    """Return the intersection-over-union of a binarized prediction and target.

    Both tensors are binarized before comparison (``pred > threshold`` and
    ``target > 0.5``) and summed over *all* elements, so a batch containing
    several images yields one pooled score rather than a per-image mean.
    A 1e-8 term is added to numerator and denominator, so two empty masks
    score 1.0 instead of dividing by zero.

    Args:
        pred: tensor of predicted scores (presumably sigmoid probabilities).
        target: ground-truth tensor; must be broadcast-compatible with ``pred``.
        threshold: cut-off applied to ``pred``; the target cut-off is fixed at 0.5.

    Returns:
        A 0-d tensor holding the IoU.
    """
    eps = 1e-8
    pred_bin = pred.gt(threshold).float()
    target_bin = target.gt(0.5).float()
    overlap = (pred_bin * target_bin).sum()
    combined = pred_bin.sum() + target_bin.sum()
    # union = |A| + |B| - |A ∩ B|, smoothed by eps
    return (overlap + eps) / (combined - overlap + eps)
def compute_dice(pred, target, threshold=0.5):
    """Return the Dice coefficient of a binarized prediction and target.

    ``pred`` is binarized with ``threshold`` and ``target`` with a fixed 0.5.
    Sums run over *all* elements, so a multi-image batch produces one pooled
    score. A 1e-8 smoothing term keeps the ratio defined (and equal to 1.0)
    when both masks are empty.

    Args:
        pred: tensor of predicted scores (presumably sigmoid probabilities).
        target: ground-truth tensor; must be broadcast-compatible with ``pred``.
        threshold: cut-off applied to ``pred``.

    Returns:
        A 0-d tensor holding the Dice score, 2|A ∩ B| / (|A| + |B|).
    """
    smooth = 1e-8
    pred_bin = pred.gt(threshold).float()
    target_bin = target.gt(0.5).float()
    overlap = (pred_bin * target_bin).sum()
    numerator = 2. * overlap + smooth
    denominator = pred_bin.sum() + target_bin.sum() + smooth
    return numerator / denominator
# ========== Main function ==========
def _select_device(gpu):
    """Return the torch device for GPU index ``gpu``.

    Falls back to the CPU when ``gpu`` is negative or CUDA is unavailable.

    Raises:
        ValueError: if CUDA is available but ``gpu`` is not a valid device index.
    """
    if gpu < 0 or not torch.cuda.is_available():
        return torch.device("cpu")
    n_devices = torch.cuda.device_count()
    if gpu >= n_devices:
        raise ValueError(
            f"--gpu {gpu} is out of range: only {n_devices} CUDA device(s) available"
        )
    return torch.device(f"cuda:{gpu}")


def main(args):
    """Evaluate a trained U-Net checkpoint on the Flame test split.

    Loads weights from ``args.checkpoint`` and runs the model over the images
    in ``<args.dataset_root>/test/images``, scoring against the masks in
    ``<args.dataset_root>/test/masks``. It then prints the mean IoU and mean
    Dice over all test samples.

    When ``args.save_pred`` is set, each sigmoid probability map is also
    written to ``args.output_dir`` as an 8-bit grayscale PNG named
    ``pred_XXXX.png`` (XXXX = sample index).

    Args:
        args: parsed command-line namespace with ``checkpoint``,
            ``dataset_root``, ``output_dir``, ``gpu`` and ``save_pred``.

    Raises:
        ValueError: if ``args.gpu`` names a CUDA device that does not exist.
    """
    # Honor the requested GPU index; "cuda" alone would always use the default device.
    device = _select_device(args.gpu)
    print(f"🔍 使用设备:{device}")
    print(f"🔍 加载模型自 {args.checkpoint}")

    # Load the model; the checkpoint is passed straight to load_state_dict,
    # so it must be a plain state_dict.
    model = UNet(in_channels=3, out_channels=1).to(device)
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.eval()

    # Load the test split. batch_size=1 matters: the metric helpers sum over
    # the whole batch, so with one image per batch each score is per-image.
    test_image_dir = os.path.join(args.dataset_root, "test/images")
    test_mask_dir = os.path.join(args.dataset_root, "test/masks")
    test_dataset = FlameDataset(test_image_dir, test_mask_dir, image_size=(256, 256))
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # Only create the output directory when something will be written to it.
    if args.save_pred:
        os.makedirs(args.output_dir, exist_ok=True)

    iou_scores = []
    dice_scores = []
    print("🚀 开始测试...")
    with torch.no_grad():
        for i, (image, mask) in enumerate(tqdm(test_loader)):
            image = image.to(device)
            mask = mask.to(device)
            # Raw model output (presumably logits) -> per-pixel probabilities.
            pred_sigmoid = torch.sigmoid(model(image))

            iou_scores.append(compute_iou(pred_sigmoid, mask).item())
            dice_scores.append(compute_dice(pred_sigmoid, mask).item())

            # Save the soft probability map scaled to 0-255 (not a thresholded mask).
            if args.save_pred:
                # squeeze() assumes a single-image, single-channel output -- TODO
                # confirm if batch_size or out_channels ever change.
                pred_np = (pred_sigmoid.squeeze().cpu().numpy() * 255).astype(np.uint8)
                out_path = os.path.join(args.output_dir, f"pred_{i:04d}.png")
                Image.fromarray(pred_np).save(out_path)

    # Guard against an empty test split: np.mean([]) would report NaN
    # (with a RuntimeWarning) instead of a meaningful score.
    if not iou_scores:
        print(f"⚠️ No test samples found under {test_image_dir}; nothing to evaluate.")
        return

    # Aggregate per-sample scores.
    mean_iou = np.mean(iou_scores)
    mean_dice = np.mean(dice_scores)
    print("\n📊 测试集评估结果:")
    print(f" 🔹 Mean IoU : {mean_iou:.4f}")
    print(f" 🔹 Mean Dice: {mean_dice:.4f}")
if __name__ == "__main__":
    # Command-line entry point. Every option has a default, so the script runs
    # with no arguments; --gpu defaults to -1, i.e. CPU evaluation.
    cli = argparse.ArgumentParser(description="Test U-Net model on Flame_256 test set")
    option_specs = (
        ("--checkpoint", dict(type=str, default="checkpoints/best_model.pth",
                              help="Path to model weights")),
        ("--dataset_root", dict(type=str, default="/root/autodl-tmp/dataset/Flame_256",
                                help="Dataset root directory")),
        ("--output_dir", dict(type=str, default="predictions",
                              help="Where to save predictions")),
        ("--gpu", dict(type=int, default=-1,
                       help="Use GPU ID or -1 for CPU")),
        ("--save_pred", dict(action="store_true",
                             help="Whether to save predicted masks")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    main(cli.parse_args())