55 changes: 27 additions & 28 deletions main.py
@@ -1,21 +1,20 @@
import argparse
import os
import random
import shutil
import time
import warnings

import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import warnings
import os

model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
@@ -68,7 +67,7 @@
def main():
global args, best_acc1
args = parser.parse_args()

if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -78,17 +77,17 @@ def main():
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')

if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely '
'disable data parallelism.')

args.distributed = args.world_size > 1

if args.distributed:
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size)

# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
@@ -108,14 +107,14 @@
model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(args.gpu)

optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)

# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
@@ -131,27 +130,27 @@
print("=> no checkpoint found at '{}'".format(args.resume))

cudnn.benchmark = True

# Data loading code
traindir = os.path.join(args.data, 'train')
valdir = os.path.join(args.data, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))

traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
@@ -165,11 +164,11 @@
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)

if args.evaluate:
validate(val_loader, model, criterion)
return

for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
@@ -191,8 +190,8 @@
'best_acc1': best_acc1,
'optimizer' : optimizer.state_dict(),
}, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
@@ -239,7 +238,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
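main.py keeps the plain ImageFolder pipeline; the LMDB path lives in tools/folder2lmdb.py below. For reference, a minimal sketch of how the converted dataset could stand in for the ImageFolder block above (the tools.folder2lmdb import path and the train.lmdb filename are assumptions, not part of this diff):

```python
# Minimal sketch: swap ImageFolder for the LMDB-backed dataset in main.py.
# Assumes train.lmdb was built by folder2lmdb(data_root, name="train") and
# that tools/folder2lmdb.py is importable from the repo root.
import os.path as osp

import torch.utils.data
import torchvision.transforms as transforms

from tools.folder2lmdb import ImageFolderLMDB  # import path is an assumption


def make_lmdb_train_loader(data_root, batch_size, workers):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImageFolderLMDB(
        osp.join(data_root, 'train.lmdb'),
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)
```

Note that the converter below stores the label via label.numpy(), so the target may come back as a one-element numpy array rather than a plain int.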
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,4 +1,3 @@
torch
torchvision
lmdb
pyarrow
lmdb
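Dropping pyarrow tracks the serialization change in tools/folder2lmdb.py: pa.serialize/pa.deserialize were deprecated upstream and removed in later pyarrow releases, and stdlib pickle round-trips the same tuples. A minimal before/after sketch:

```python
# Round-trip sketch: pickle replaces the deprecated pyarrow serialize API.
import pickle

# Old (no longer available in recent pyarrow releases):
#   buf = pa.serialize(obj).to_buffer()
#   obj = pa.deserialize(buf)

# New: the stdlib equivalent used throughout this PR.
payload = pickle.dumps((b'raw-jpeg-bytes', 42))
imgbuf, label = pickle.loads(payload)
assert imgbuf == b'raw-jpeg-bytes' and label == 42
```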
82 changes: 37 additions & 45 deletions tools/folder2lmdb.py
@@ -1,50 +1,45 @@
import os
import os.path as osp
import os, sys
import os.path as osp
from PIL import Image
import six
import string

import lmdb
import pickle
import msgpack
import tqdm
import pyarrow as pa

import torch
import torch.utils.data as data
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torchvision import transforms, datasets


class ImageFolderLMDB(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
self.db_path = db_path
self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
# self.length = txn.stat()['entries'] - 1
self.length =pa.deserialize(txn.get(b'__len__'))
self.keys= pa.deserialize(txn.get(b'__keys__'))

self.transform = transform
self.target_transform = target_transform

env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with env.begin(write=False) as txn:
self.length = pickle.loads(txn.get(b'__len__'))
self.keys = pickle.loads(txn.get(b'__keys__'))

def open_lmdb(self):
self.env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
self.txn = self.env.begin(write=False, buffers=True)

def __getitem__(self, index):
if not hasattr(self, 'txn'):
self.open_lmdb()

img, target = None, None
env = self.env
with env.begin(write=False) as txn:
byteflow = txn.get(self.keys[index])
unpacked = pa.deserialize(byteflow)
byteflow = self.txn.get(self.keys[index])
unpacked = pickle.loads(byteflow)

# load image
imgbuf = unpacked[0]
buf = six.BytesIO()
buf.write(imgbuf)
buf.write(imgbuf[0])
buf.seek(0)
img = Image.open(buf).convert('RGB')

@@ -65,7 +60,6 @@ def __len__(self):
def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'


class ImageFolderLMDB_old(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
import lmdb
@@ -113,43 +107,42 @@ def __len__(self):
def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'


def raw_reader(path):
with open(path, 'rb') as f:
bin_data = f.read()
return bin_data


def dumps_pyarrow(obj):
def dumps_pickle(obj):
"""
Serialize an object.

Returns:
Implementation-dependent bytes-like object
Returns:
The pickled representation of the object obj as a bytes object
"""
return pa.serialize(obj).to_buffer()

return pickle.dumps(obj)

def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=16):
def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=0):
directory = osp.expanduser(osp.join(dpath, name))
print("Loading dataset from %s" % directory)
dataset = ImageFolder(directory, loader=raw_reader)
data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=lambda x: x)
data_loader = DataLoader(dataset, num_workers=num_workers)

lmdb_path = osp.join(dpath, "%s.lmdb" % name)
isdir = os.path.isdir(lmdb_path)

print("Generate LMDB to %s" % lmdb_path)
map_size = 30737418240 # this should be adjusted based on OS/db size
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
map_size=map_size, readonly=False,
meminit=False, map_async=True)

print(len(dataset), len(data_loader))
txn = db.begin(write=True)
for idx, data in enumerate(data_loader):
for idx, (data, label) in enumerate(data_loader):
# print(type(data), data)
image, label = data[0]
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
image = data
label = label.numpy()
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pickle((image, label)))
if idx % write_frequency == 0:
print("[%d/%d]" % (idx, len(data_loader)))
txn.commit()
@@ -159,21 +152,20 @@ def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=16):
txn.commit()
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps_pyarrow(keys))
txn.put(b'__len__', dumps_pyarrow(len(keys)))
txn.put(b'__keys__', dumps_pickle(keys))
txn.put(b'__len__', dumps_pickle(len(keys)))

print("Flushing database ...")
db.sync()
db.close()


if __name__ == "__main__":
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--folder", type=str)
parser.add_argument('-s', '--split', type=str, default="val")
parser.add_argument('--out', type=str, default=".")
parser.add_argument('-p', '--procs', type=int, default=20)
parser.add_argument('-p', '--procs', type=int, default=0)

args = parser.parse_args()

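Two design notes on the rewritten reader: the environment is now opened lazily in open_lmdb() on the first __getitem__ rather than in __init__, so every DataLoader worker opens its own LMDB handle after fork (LMDB environments are not safely shareable across processes), and __init__ opens a short-lived environment only to read __len__ and __keys__. End-to-end usage would look roughly like this (the dataset root is a placeholder):

```python
# Usage sketch: build an LMDB from an ImageFolder tree, then read it back.
# "/data/imagenet" is a placeholder path with val/<class>/<image> inside.
from tools.folder2lmdb import ImageFolderLMDB, folder2lmdb

# Writes /data/imagenet/val.lmdb, committing every 5000 records.
folder2lmdb("/data/imagenet", name="val", write_frequency=5000)

# Reading: __getitem__ returns a PIL image (RGB) and the stored label.
dataset = ImageFolderLMDB("/data/imagenet/val.lmdb")
img, target = dataset[0]
print(len(dataset), img.size, target)
```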