diff --git a/main.py b/main.py
index a8984db..81d8bef 100644
--- a/main.py
+++ b/main.py
@@ -1,21 +1,20 @@
 import argparse
-import os
-import random
 import shutil
 import time
-import warnings
-
+import random
 import torch
 import torch.nn as nn
 import torch.nn.parallel
+import torch.optim
 import torch.backends.cudnn as cudnn
 import torch.distributed as dist
-import torch.optim
 import torch.utils.data
 import torch.utils.data.distributed
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
+import warnings
+import os
 
 model_names = sorted(name for name in models.__dict__
     if name.islower() and not name.startswith("__")
@@ -68,7 +67,7 @@ def main():
     global args, best_acc1
     args = parser.parse_args()
 
-    
+
     if args.seed is not None:
         random.seed(args.seed)
         torch.manual_seed(args.seed)
@@ -78,17 +77,17 @@ def main():
                       'which can slow down your training considerably! '
                       'You may see unexpected behavior when restarting '
                       'from checkpoints.')
-    
+
     if args.gpu is not None:
         warnings.warn('You have chosen a specific GPU. This will completely '
                       'disable data parallelism.')
-    
+
     args.distributed = args.world_size > 1
-    
+
     if args.distributed:
         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                 world_size=args.world_size)
-    
+
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
@@ -108,14 +107,14 @@ def main():
             model.cuda()
     else:
         model = torch.nn.DataParallel(model).cuda()
-    
+
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
 
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
-    
+
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
@@ -131,27 +130,27 @@ def main():
             print("=> no checkpoint found at '{}'".format(args.resume))
 
     cudnn.benchmark = True
-    
+
     # Data loading code
     traindir = os.path.join(args.data, 'train')
     valdir = os.path.join(args.data, 'val')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
-    
+
     train_dataset = datasets.ImageFolder(
-       traindir,
-       transforms.Compose([
-           transforms.RandomResizedCrop(224),
-           transforms.RandomHorizontalFlip(),
-           transforms.ToTensor(),
-           normalize,
-       ]))
-    
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
     else:
         train_sampler = None
-    
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
@@ -165,11 +164,11 @@ def main():
         ])),
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)
-    
+
     if args.evaluate:
         validate(val_loader, model, criterion)
         return
-    
+
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
@@ -191,8 +190,8 @@
             'best_acc1': best_acc1,
             'optimizer' : optimizer.state_dict(),
         }, is_best)
-    
-    
+
+
 def train(train_loader, model, criterion, optimizer, epoch):
     batch_time = AverageMeter()
     data_time = AverageMeter()
@@ -239,7 +238,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
                   'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    epoch, i, len(train_loader), batch_time=batch_time,
-                    data_time=data_time, loss=losses, top1=top1, top5=top5))
+                   data_time=data_time, loss=losses, top1=top1, top5=top5))
 
 
 def validate(val_loader, model, criterion):
diff --git a/requirements.txt b/requirements.txt
index eca22b4..6e04fd3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
 torch
 torchvision
-lmdb
-pyarrow
+lmdb
\ No newline at end of file
diff --git a/tools/folder2lmdb.py b/tools/folder2lmdb.py
index 65c2af4..f25ed84 100755
--- a/tools/folder2lmdb.py
+++ b/tools/folder2lmdb.py
@@ -1,50 +1,45 @@
 import os
 import os.path as osp
-import os, sys
-import os.path as osp
-from PIL import Image
 import six
-import string
-
 import lmdb
 import pickle
 import msgpack
-import tqdm
-import pyarrow as pa
-
-import torch
 import torch.utils.data as data
+from PIL import Image
 from torch.utils.data import DataLoader
-from torchvision.transforms import transforms
 from torchvision.datasets import ImageFolder
-from torchvision import transforms, datasets
-
 
 class ImageFolderLMDB(data.Dataset):
     def __init__(self, db_path, transform=None, target_transform=None):
         self.db_path = db_path
-        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
-                             readonly=True, lock=False,
-                             readahead=False, meminit=False)
-        with self.env.begin(write=False) as txn:
-            # self.length = txn.stat()['entries'] - 1
-            self.length =pa.deserialize(txn.get(b'__len__'))
-            self.keys= pa.deserialize(txn.get(b'__keys__'))
-
         self.transform = transform
         self.target_transform = target_transform
+
+        env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
+                        readonly=True, lock=False,
+                        readahead=False, meminit=False)
+        with env.begin(write=False) as txn:
+            self.length = pickle.loads(txn.get(b'__len__'))
+            self.keys = pickle.loads(txn.get(b'__keys__'))
+
+    def open_lmdb(self):
+        self.env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
+                             readonly=True, lock=False,
+                             readahead=False, meminit=False)
+        self.txn = self.env.begin(write=False, buffers=True)
 
     def __getitem__(self, index):
+        if not hasattr(self, 'txn'):
+            self.open_lmdb()
+
         img, target = None, None
-        env = self.env
-        with env.begin(write=False) as txn:
-            byteflow = txn.get(self.keys[index])
-        unpacked = pa.deserialize(byteflow)
+        byteflow = self.txn.get(self.keys[index])
+        unpacked = pickle.loads(byteflow)
 
         # load image
         imgbuf = unpacked[0]
         buf = six.BytesIO()
-        buf.write(imgbuf)
+        buf.write(imgbuf[0])
         buf.seek(0)
         img = Image.open(buf).convert('RGB')
 
@@ -65,7 +60,6 @@ def __len__(self):
     def __repr__(self):
         return self.__class__.__name__ + ' (' + self.db_path + ')'
 
-
 class ImageFolderLMDB_old(data.Dataset):
     def __init__(self, db_path, transform=None, target_transform=None):
         import lmdb
@@ -113,43 +107,42 @@ def __len__(self):
     def __repr__(self):
         return self.__class__.__name__ + ' (' + self.db_path + ')'
 
-
 def raw_reader(path):
     with open(path, 'rb') as f:
         bin_data = f.read()
     return bin_data
 
-
-def dumps_pyarrow(obj):
+def dumps_pickle(obj):
     """
     Serialize an object.
-
-    Returns:
-        Implementation-dependent bytes-like object
+
+    Returns:
+        The pickled representation of the object obj as a bytes object
     """
-    return pa.serialize(obj).to_buffer()
-
+    return pickle.dumps(obj)
 
-def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=16):
+def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=0):
     directory = osp.expanduser(osp.join(dpath, name))
     print("Loading dataset from %s" % directory)
     dataset = ImageFolder(directory, loader=raw_reader)
-    data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=lambda x: x)
+    data_loader = DataLoader(dataset, num_workers=num_workers)
 
     lmdb_path = osp.join(dpath, "%s.lmdb" % name)
     isdir = os.path.isdir(lmdb_path)
 
     print("Generate LMDB to %s" % lmdb_path)
+    map_size = 30737418240  # this should be adjusted based on OS/db size
     db = lmdb.open(lmdb_path, subdir=isdir,
-                   map_size=1099511627776 * 2, readonly=False,
+                   map_size=map_size, readonly=False,
                    meminit=False, map_async=True)
 
     print(len(dataset), len(data_loader))
     txn = db.begin(write=True)
-    for idx, data in enumerate(data_loader):
+    for idx, (data, label) in enumerate(data_loader):
         # print(type(data), data)
-        image, label = data[0]
-        txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
+        image = data
+        label = label.numpy()
+        txn.put(u'{}'.format(idx).encode('ascii'), dumps_pickle((image, label)))
         if idx % write_frequency == 0:
             print("[%d/%d]" % (idx, len(data_loader)))
             txn.commit()
@@ -159,21 +152,20 @@ def folder2lmdb(dpath, name="train", write_frequency=5000, num_workers=16):
     txn.commit()
     keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
-        txn.put(b'__keys__', dumps_pyarrow(keys))
-        txn.put(b'__len__', dumps_pyarrow(len(keys)))
+        txn.put(b'__keys__', dumps_pickle(keys))
+        txn.put(b'__len__', dumps_pickle(len(keys)))
 
     print("Flushing database ...")
     db.sync()
     db.close()
 
-
-if __name__ == "__main__":
+if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument("-f", "--folder", type=str)
    parser.add_argument('-s', '--split', type=str, default="val")
     parser.add_argument('--out', type=str, default=".")
-    parser.add_argument('-p', '--procs', type=int, default=20)
+    parser.add_argument('-p', '--procs', type=int, default=0)
     args = parser.parse_args()
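
For reference, a minimal consumer-side sketch of the lazy-open pattern the patch introduces: ImageFolderLMDB now defers lmdb.open() until the first __getitem__ call, so each DataLoader worker opens its own LMDB environment instead of inheriting a handle across fork(), which LMDB does not support. The database path, transform values, batch size, and worker count below are illustrative assumptions, not part of the patch; the class and its open_lmdb() method are from tools/folder2lmdb.py above.

    # usage sketch -- assumes tools/folder2lmdb.py has already produced val.lmdb,
    # e.g. via: python tools/folder2lmdb.py -f /path/to/imagenet -s val
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from tools.folder2lmdb import ImageFolderLMDB

    transform = transforms.Compose([
        transforms.Resize(256),       # illustrative eval-time transform
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    # Each worker process triggers open_lmdb() on its first __getitem__ call,
    # so no LMDB environment handle is shared across the fork boundary.
    dataset = ImageFolderLMDB('val.lmdb', transform=transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    for images, targets in loader:
        pass  # forward the batch through a model here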