Skip to content

不使用DDP,只用lmdb,速度很慢,比原始imread还慢 #20

@Edwardmark

Description

@Edwardmark
def folder2lmdb(anno_file, name="train", write_frequency=5000, num_workers=16):
    ids = []
    annotation = []
    for line in open(anno_file,'r'):
        filename = line.strip().split()[0]
        ids.append(filename)
        annotation.append(line.strip().split()[1:])
    lmdb_path = osp.join("app_%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)
    
    print(len(ids), len(annotation))
    txn = db.begin(write=True)
    idx = 0
    for filename, label in zip(ids, annotation):
        print(filename, label)
        image = raw_reader(filename)
        txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(annotation)))
            txn.commit()
            txn = db.begin(write=True)
        idx += 1

    # finish iterating through dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_pyarrow(keys))
        txn.put(b'__len__', dumps_pyarrow(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()

class DetectionLMDB(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None, dataset_name='WiderFace'):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            # self.length = txn.stat()['entries'] - 1
            self.length =pa.deserialize(txn.get(b'__len__'))
            self.keys= pa.deserialize(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform


        self.name = dataset_name
        self.annotation = list()
        self.counter = 0

    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)
        return im, gt

    def pull_item(self, index):
        img, target = None, None
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        unpacked = pa.deserialize(byteflow)

        # load image
        imgbuf = unpacked[0]
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert('RGB')
        img = cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR)  
        height, width, channels = img.shape
        # load label
        target = unpacked[1]

        if self.target_transform is not None:
            target = self.target_transform(target, width, height)

        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels, poses, angles = self.transform(img, target[:, :4], target[:, 4], target[:,5], target[:,6])
            target = np.hstack((boxes, np.expand_dims(labels, axis=1),
                                       np.expand_dims(poses, axis=1),
                                       np.expand_dims(angles, axis=1)))

        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'

使用上述代码生成lmdb并用DetectionLMDB作为dataset,速度很慢,不知道为啥,是不是必须跟DDP混合使用呢?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions