104 changes: 5 additions & 99 deletions .gitignore
@@ -1,101 +1,7 @@
# Byte-compiled / optimized / DLL files
#python cache
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
img_align_celeba/
model/*/*.pth
!model/*/*epoch_500*.pth
results/
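The rewrite drops the boilerplate Python template in favour of this project's actual run artifacts: the CelebA image folder, checkpoints, and results. Note the ordering of the two checkpoint rules: in .gitignore, a negated pattern (`!model/*/*epoch_500*.pth`) re-includes files matched by an earlier wildcard (`model/*/*.pth`), so it must come after the rule it carves an exception from; the net effect is that only epoch-500 checkpoints stay tracked.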
2 changes: 1 addition & 1 deletion dataset.py
@@ -29,7 +29,7 @@ def loadFromFile(path, datasize):
f=open(path)
data=[]
label=[]
for idx in xrange(0, datasize):
for idx in range(0, datasize):
line = f.readline().split()
data.append(line[0])
label.append(line[1])
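The one-line dataset.py change is the Python 3 port of this loop: `xrange` was removed in Python 3, and `range` now returns a lazy sequence with the same memory behaviour. For context, a minimal sketch of the loop under an assumed list-file format (one image path and one label per whitespace-split line; the exact format is not shown in this diff):

```python
# Hypothetical sketch of the list file loadFromFile consumes (format inferred,
# not confirmed by the diff): one "<image-path> <label>" pair per line.
def load_from_file(path, datasize):
    data, label = [], []
    with open(path) as f:          # context manager instead of a bare open()
        for _ in range(datasize):  # Python 3: range replaces xrange
            line = f.readline().split()
            data.append(line[0])
            label.append(line[1])
    return data, label
```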
99 changes: 62 additions & 37 deletions main.py
@@ -19,18 +19,19 @@
from math import log10
import torchvision
import cv2
from PIL import Image, ImageOps

parser = argparse.ArgumentParser()
parser.add_argument('--test', action='store_true', help='enables test during training')
parser.add_argument('--mse_avg', action='store_true', help='enables mse avg')
parser.add_argument('--num_layers_res', type=int, help='number of the layers in residual block', default=2)
parser.add_argument('--nrow', type=int, help='number of the rows to save images', default=10)
parser.add_argument('--trainfiles', default="path/celeba/train.list", type=str, help='the list of training files')
parser.add_argument('--dataroot', default="path/celeba", type=str, help='path to dataset')
parser.add_argument('--testfiles', default="path/test.list", type=str, help='the list of testing files')
parser.add_argument('--testroot', default="path/celeba", type=str, help='path to dataset')
parser.add_argument('--trainsize', type=int, help='number of training data', default=162770)
parser.add_argument('--testsize', type=int, help='number of testing data', default=19962)
parser.add_argument('--trainfiles', default="train.list", type=str, help='the list of training files')
parser.add_argument('--dataroot', default="img_align_celeba", type=str, help='path to dataset')
parser.add_argument('--testfiles', default="test.list", type=str, help='the list of testing files')
parser.add_argument('--testroot', default="img_align_celeba", type=str, help='path to dataset')
parser.add_argument('--trainsize', type=int, help='number of training data', default=4000)
parser.add_argument('--testsize', type=int, help='number of testing data', default=500)
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--test_batchSize', type=int, default=64, help='test batch size')
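The retargeted defaults point at a local img_align_celeba/ directory and shrink the train/test sizes from the full CelebA splits (162770/19962) to 4000/500, so a smoke-test run needs no path flags. A hypothetical invocation, assuming the list files sit beside main.py and that the elided part of the parser still defines the --cuda flag used later in main():

```
python main.py --test --cuda
```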
@@ -56,7 +57,15 @@
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")

# NetStruct_WaveletDepth_batch-testBatch_Change#1_Change#2...
TAG = "removeBNinResBlocks_3_16-16_testAtLastEpoch"

def main():

if not os.path.exists("model/{0}".format(TAG)):
os.makedirs("model/{0}".format(TAG))

f_psrn = open("model/{0}/psnr_{0}.txt".format(TAG,TAG), "a")

global opt, model
opt = parser.parse_args()
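The new TAG constant keys every artifact of a run: checkpoints go to model/&lt;TAG&gt;/ and the per-epoch test PSNR is appended to model/&lt;TAG&gt;/psnr_&lt;TAG&gt;.txt. Because the file is opened in append mode, repeated runs with the same TAG accumulate into one log. A minimal sketch of reading it back, assuming one PSNR value per line as produced by the f_psrn.write call added below (the helper name is hypothetical):

```python
# Hypothetical helper: parse the per-epoch PSNR log written by this PR.
# Assumes one "{:.4f}"-formatted value per line.
def read_psnr_log(tag):
    with open("model/{0}/psnr_{0}.txt".format(tag)) as f:
        return [float(line) for line in f if line.strip()]

psnrs = read_psnr_log("removeBNinResBlocks_3_16-16_testAtLastEpoch")
```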
@@ -140,6 +149,7 @@ def main():
crop_height=None, crop_width=None,
is_random_crop=False, is_mirror=False, is_gray=False,
upscale=mag, is_scale_back=is_scale_back)

test_data_loader = torch.utils.data.DataLoader(test_set, batch_size=opt.test_batchSize,
shuffle=False, num_workers=int(opt.workers))

@@ -153,37 +163,16 @@
save_checkpoint(srnet, epoch, 0, 'sr_')

for iteration, batch in enumerate(train_data_loader, 0):
#--------------test-------------
if iteration % opt.test_iter is 0 and opt.test:
srnet.eval()
avg_psnr = 0
for titer, batch in enumerate(test_data_loader,0):
input, target = Variable(batch[0]), Variable(batch[1])
if opt.cuda:
input = input.cuda()
target = target.cuda()

wavelets = forward_parallel(srnet, input, opt.ngpu)
prediction = wavelet_rec(wavelets)
mse = criterion_m(prediction, target)
psnr = 10 * log10(1 / (mse.data[0]) )
avg_psnr += psnr

save_images(prediction, "Epoch_{:03d}_Iter_{:06d}_{:02d}_o.jpg".format(epoch, iteration, titer),
path=opt.outf, nrow=opt.nrow)


print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(test_data_loader)))
srnet.train()


#--------------train------------
input, target = Variable(batch[0]), Variable(batch[1], requires_grad=False)
if opt.cuda:
input = input.cuda()
target = target.cuda()

target_wavelets = wavelet_dec(target)

batch_size = target.size(0)
wavelets_lr = target_wavelets[:,0:3,:,:]
wavelets_sr = target_wavelets[:,3:,:,:]
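The loss terms below split the target's wavelet coefficients by channel: the first three channels are treated as the low-frequency ("lr") band of an RGB image and the remaining channels as high-frequency detail. A minimal one-level 2D Haar decomposition produces exactly that layout; this is an assumption about what wavelet_dec computes (the repository's WaveletTransform may differ in scaling, depth, and sub-band order):

```python
import torch

# One-level 2D Haar decomposition sketch (an assumption, not the repo's code).
# Input (N, 3, H, W) with even H, W -> output (N, 12, H/2, W/2), where
# channels 0:3 are the low-frequency (LL) band that the loss treats as "lr".
def haar_dec(x):
    a = x[:, :, 0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[:, :, 0::2, 1::2]  # top-right
    c = x[:, :, 1::2, 0::2]  # bottom-left
    d = x[:, :, 1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 4  # low-frequency approximation
    lh = (a + b - c - d) / 4  # vertical detail
    hl = (a - b + c - d) / 4  # horizontal detail
    hh = (a - b - c + d) / 4  # diagonal detail
    return torch.cat([ll, lh, hl, hh], dim=1)
```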
@@ -204,11 +193,43 @@
optimizer_sr.step()

info = "===> Epoch[{}]({}/{}): time: {:4.4f}:".format(epoch, iteration, len(train_data_loader), time.time()-start_time)
info += "Rec: {:.4f}, {:.4f}, {:.4f}, Texture: {:.4f}".format(loss_lr.data[0], loss_sr.data[0],
loss_img.data[0], loss_textures.data[0])
info += "Rec: {:.4f}, {:.4f}, {:.4f}, Texture: {:.4f}".format(loss_lr.item(), loss_sr.item(),
loss_img.item(), loss_textures.item())

print(info)




#--------------test-------------
#if iteration % opt.test_iter is 0 and opt.test:
if iteration == len(train_data_loader) - 1 and opt.test:
srnet.eval()
avg_psnr = 0
for titer, batch in enumerate(test_data_loader,0):
input, target = Variable(batch[0]), Variable(batch[1])
if opt.cuda:
input = input.cuda()
target = target.cuda()

wavelets = forward_parallel(srnet, input, opt.ngpu)
prediction = wavelet_rec(wavelets)
mse = criterion_m(prediction, target)
psnr = 10 * log10(1 / (mse.item()) )
avg_psnr += psnr

save_images(prediction, "Epoch_{:03d}_Iter_{:06d}_{:02d}_o.jpg".format(epoch, iteration, titer),
path=opt.outf, nrow=opt.nrow)

if epoch%50 == 0:
save_images(prediction, "Epoch_{:03d}_Iter_{:06d}_{:02d}_o.jpg".format(epoch, iteration, titer),
path="result_step50/{0}/".format(TAG), nrow=opt.nrow)


print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(test_data_loader)))
f_psrn.write("{:.4f}\n".format(avg_psnr / len(test_data_loader)))
srnet.train()

f_psrn.close()
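The PSNR logged above is the usual 10·log10(MAX²/MSE) with MAX = 1, which presumes predictions and targets normalized to [0, 1]; for [0, 255] data the constant would be 255². Note also that it averages per-batch PSNRs over the test loader rather than computing one PSNR over all pixels. A self-contained sketch of the formula:

```python
import torch
from math import log10

# PSNR for images normalized to [0, 1]: 10 * log10(1 / MSE).
# (For [0, 255] data, use 10 * log10(255 ** 2 / mse) instead.)
def psnr(prediction, target):
    mse = torch.mean((prediction - target) ** 2).item()
    return 10 * log10(1 / mse)

x = torch.rand(1, 3, 64, 64)
noisy = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)
print("PSNR: {:.2f} dB".format(psnr(noisy, x)))
```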

def forward_parallel(net, input, ngpu):
if ngpu > 1:
@@ -217,16 +238,20 @@ def forward_parallel(net, input, ngpu):
return net(input)

def save_checkpoint(model, epoch, iteration, prefix=""):
model_out_path = "model/" + prefix +"model_epoch_{}_iter_{}.pth".format(epoch, iteration)
model_out_path = "model/{0}/".format(TAG) + prefix +"model_epoch_{}_iter_{}.pth".format(epoch, iteration)
state = {"epoch": epoch ,"model": model}
if not os.path.exists("model/"):
os.makedirs("model/")
if not os.path.exists("model/{0}/".format(TAG)):
os.makedirs("model/{0}/".format(TAG))

torch.save(state, model_out_path)

print("Checkpoint saved to {}".format(model_out_path))

def save_images(images, name, path, nrow=10):

# Create the target directory if it doesn't exist
if not os.path.exists(path):
os.makedirs(path)
#print(images.size())
img = images.cpu()
im = img.data.numpy().astype(np.float32)