-
Notifications
You must be signed in to change notification settings - Fork 137
Open
Description
def folder2lmdb(anno_file, name="train", write_frequency=5000, num_workers=16):
    """Pack the images listed in an annotation file into an LMDB database.

    Every non-blank line of ``anno_file`` is split on whitespace: the first
    token is handed to ``raw_reader`` to load the image and the remaining
    tokens are stored as that image's label. Record ``i`` is written under
    the ASCII key ``str(i)`` as ``dumps_pyarrow((image, label))``. The list of
    record keys and the record count are stored under ``b'__keys__'`` and
    ``b'__len__'``.

    Args:
        anno_file: Path of the annotation text file.
        name: Suffix used to build the output path ``app_<name>.lmdb``.
        write_frequency: Commit the write transaction every this many records.
        num_workers: Unused; kept so existing callers keep working.
    """
    ids = []
    annotation = []
    # Close the annotation file deterministically instead of leaking the handle.
    with open(anno_file, 'r') as anno:
        for line in anno:
            fields = line.strip().split()
            if not fields:
                # A blank line carries no filename; skip it rather than crash.
                continue
            ids.append(fields[0])
            annotation.append(fields[1:])

    lmdb_path = "app_%s.lmdb" % name
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2,  # 2 TiB upper bound on DB size
                   readonly=False,
                   meminit=False, map_async=True)
    try:
        total = len(annotation)
        print(len(ids), total)

        txn = db.begin(write=True)
        for idx, (filename, label) in enumerate(zip(ids, annotation)):
            print(filename, label)
            image = raw_reader(filename)
            txn.put(u'{}'.format(idx).encode('ascii'),
                    dumps_pyarrow((image, label)))
            if idx % write_frequency == 0:
                # Commit periodically so a single transaction does not grow
                # without bound, then start a fresh one.
                print("[%d/%d]" % (idx, total))
                txn.commit()
                txn = db.begin(write=True)

        # finish iterating through dataset
        txn.commit()

        # Exactly one key per record written: 0 .. len(ids) - 1. Using one
        # more than that would advertise a record that does not exist.
        keys = [u'{}'.format(k).encode('ascii') for k in range(len(ids))]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', dumps_pyarrow(keys))
            txn.put(b'__len__', dumps_pyarrow(len(keys)))

        print("Flushing database ...")
        db.sync()
    finally:
        # Release the environment even if reading or writing failed part-way.
        db.close()
class DetectionLMDB(data.Dataset):
    """Detection dataset that reads ``(image, label)`` records from LMDB.

    The database is expected to hold the record count under ``b'__len__'``
    and the list of record keys under ``b'__keys__'`` (the layout written by
    ``folder2lmdb``); each record deserializes to a pair of encoded image
    bytes and a label.
    """

    def __init__(self, db_path, transform=None, target_transform=None, dataset_name='WiderFace'):
        """Open the LMDB environment read-only and load its key index.

        Args:
            db_path: Path of the LMDB database (directory or single file).
            transform: Optional callable invoked as
                ``transform(img, boxes, labels, poses, angles)`` and returning
                the same five items.
            target_transform: Optional callable invoked as
                ``target_transform(target, width, height)``.
            dataset_name: Stored as ``self.name``.
        """
        super().__init__()
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = pa.deserialize(txn.get(b'__len__'))
            self.keys = pa.deserialize(txn.get(b'__keys__'))
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        self.annotation = []
        self.counter = 0

    def __getitem__(self, index):
        """Return ``(image, target)`` for the record at ``index``."""
        im, gt, _height, _width = self.pull_item(index)
        return im, gt

    def pull_item(self, index):
        """Load, decode and transform one record.

        Args:
            index: Position in ``self.keys`` of the record to load.

        Returns:
            ``(image, target, height, width)``. ``image`` is the result of
            ``torch.from_numpy(img).permute(2, 0, 1)``; ``height`` and
            ``width`` are those of the decoded image before any transform.

        Raises:
            IndexError: If the key at ``index`` has no record in the database.
        """
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        if byteflow is None:
            # txn.get yields None for a missing key; fail clearly here rather
            # than with an unrelated error inside the deserializer.
            raise IndexError(
                "no record stored for index %r in %s" % (index, self.db_path))
        unpacked = pa.deserialize(byteflow)

        # load image: decode with PIL as RGB, then reorder channels to BGR
        buf = six.BytesIO(unpacked[0])
        img = Image.open(buf).convert('RGB')
        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        height, width, channels = img.shape

        # load label
        target = unpacked[1]
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            # Columns 0-3, 4, 5 and 6 are handed to the transform as boxes,
            # labels, poses and angles, then reassembled into one array.
            img, boxes, labels, poses, angles = self.transform(
                img, target[:, :4], target[:, 4], target[:, 5], target[:, 6])
            target = np.hstack((boxes, np.expand_dims(labels, axis=1),
                                np.expand_dims(poses, axis=1),
                                np.expand_dims(angles, axis=1)))

        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

    def __len__(self):
        """Return the record count read from the database's ``__len__`` entry."""
        return self.length

    def __repr__(self):
        """Return the class name followed by the database path in parentheses."""
        return self.__class__.__name__ + ' (' + self.db_path + ')'
使用上述代码生成 LMDB,并用 DetectionLMDB 作为 dataset 时,数据读取速度很慢,不知道是什么原因。是不是必须和 DDP 配合使用?
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels