Skip to content

Commit f0dc455

Browse files
committed
perf(dataset): 更新文件读取方式
1 parent 7ee913f commit f0dc455

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

py/lib/models/location_dataset.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import cv2
1111
import os
12+
import glob
1213
import torch
1314
from torch.utils.data import DataLoader
1415
from torch.utils.data import Dataset
@@ -35,14 +36,8 @@ def __init__(self, root_dir, cate_list, transform=None, S=7, B=2, C=20):
3536
self.C = C
3637
self.cate_list = cate_list
3738

38-
jpeg_path_list = []
39-
xml_path_list = []
40-
for name in cate_list:
41-
for i in range(1, 61):
42-
jpeg_path_list.append(os.path.join(root_dir, 'imgs', '%s_%d.jpg' % (name, i)))
43-
xml_path_list.append(os.path.join(root_dir, 'annotations', '%s_%d.xml' % (name, i)))
44-
self.jpeg_path_list = jpeg_path_list
45-
self.xml_path_list = xml_path_list
39+
self.jpeg_path_list = glob.glob(os.path.join(root_dir, 'imgs', '*.jpg'))
40+
self.xml_path_list = glob.glob(os.path.join(root_dir, 'annotations', '*.xml'))
4641

4742
def __getitem__(self, index):
4843
"""
@@ -146,8 +141,8 @@ def __len__(self):
146141
print(target.shape)
147142
print(target)
148143

149-
# data_loader = DataLoader(data_set, shuffle=True, batch_size=8, num_workers=8)
150-
# items = next(iter(data_loader))
151-
# inputs, labels = items
152-
# print(inputs.shape)
153-
# print(labels.shape)
144+
data_loader = DataLoader(data_set, shuffle=True, batch_size=8, num_workers=8)
145+
items = next(iter(data_loader))
146+
inputs, labels = items
147+
print(inputs.shape)
148+
print(labels.shape)

0 commit comments

Comments
 (0)