|
21 | 21 |
|
22 | 22 | S = 7 |
23 | 23 | B = 2 |
24 | | -C = 3 |
| 24 | +# C = 3 |
| 25 | +# cate_list = ['cucumber', 'eggplant', 'mushroom'] |
| 26 | + |
| 27 | +C = 20 |
| 28 | +cate_list = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', |
| 29 | + 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] |
25 | 30 |
|
26 | 31 |
|
27 | 32 | def load_data(img_path, xml_path): |
@@ -63,7 +68,9 @@ def load_data(img_path, xml_path): |
63 | 68 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
64 | 69 | device = "cpu" |
65 | 70 |
|
66 | | - img, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml') |
| 71 | + # img, data_dict = load_data('../imgs/cucumber_9.jpg', '../imgs/cucumber_9.xml') |
| 72 | + # img, data_dict = load_data('../imgs/000012.jpg', '../imgs/000012.xml') |
| 73 | + img, data_dict = load_data('../imgs/000007.jpg', '../imgs/000007.xml') |
67 | 74 | model = file.load_model(device, S, B, C) |
68 | 75 | # 计算 |
69 | 76 | outputs = model.forward(img.to(device)).cpu().squeeze(0) |
@@ -92,8 +99,8 @@ def load_data(img_path, xml_path): |
92 | 99 | # 预测边界框的缩放,回到原始图像 |
93 | 100 | pred_bboxs = util.deform_bboxs(pred_cate_bboxs, data_dict, S) |
94 | 101 | # 在原图绘制标注边界框和预测边界框 |
95 | | - dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], pred_bboxs, pred_cates, |
96 | | - pred_cate_probs) |
| 102 | + dst = draw.plot_bboxs(data_dict['src'], data_dict['bndboxs'], data_dict['name_list'], cate_list, |
| 103 | + pred_bboxs, pred_cates, pred_cate_probs) |
97 | 104 | cv2.imwrite('./detect.png', dst) |
98 | 105 | # BGR -> RGB |
99 | 106 | dst = cv2.cvtColor(dst, cv2.COLOR_BGR2RGB) |
|
0 commit comments