30 changes: 24 additions & 6 deletions inference.py
@@ -6,7 +6,7 @@
import numpy as np
from PIL import Image
from tqdm import tqdm
from scipy.ndimage.filters import maximum_filter
from scipy.ndimage import maximum_filter
from shapely.geometry import Polygon

import torch
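For reference, newer SciPy releases drop the legacy scipy.ndimage.filters alias, which is why the import above moves to scipy.ndimage. If older SciPy versions still need to work, a small compatibility shim (a sketch, not part of this diff) could be used instead:

    # Sketch: prefer the modern import, fall back to the legacy namespace on old SciPy
    try:
        from scipy.ndimage import maximum_filter
    except ImportError:
        from scipy.ndimage.filters import maximum_filter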
@@ -76,9 +76,12 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
# Network feedforward (with testing augmentation)
x, aug_type = augment(x, flip, rotate)
y_bon_, y_cor_ = net(x.to(device))
output_npy = np.vstack((y_bon_[0], y_cor_[0, :]))

y_bon_ = augment_undo(y_bon_.cpu(), aug_type).mean(0)
y_cor_ = augment_undo(torch.sigmoid(y_cor_).cpu(), aug_type).mean(0)

# output_npy = np.vstack((y_bon_[0], y_cor_[0, :]))
# Visualize raw model output
if visualize:
vis_out = visualize_a_data(x[0],
@@ -138,18 +141,20 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
cor_id[:, 0] /= W
cor_id[:, 1] /= H

return cor_id, z0, z1, vis_out
return cor_id, z0, z1, vis_out, output_npy
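A note on the new return value: output_npy is assembled from the raw network outputs before augment_undo and before the corner sigmoid, so with test-time flip/rotate it only reflects the first augmented copy. Assuming the stock HorizonNet head (a (2, W) boundary map and a (1, W) corner map per sample), a caller could unpack it along these lines (a sketch, not part of this diff):

    # Hypothetical unpacking of the extra return value
    cor_id, z0, z1, vis_out, output_npy = inference(net=net, x=x, device=device)
    y_bon_raw = output_npy[:2]   # raw ceiling/floor boundary predictions, shape (2, W)
    y_cor_raw = output_npy[2]    # raw corner scores (pre-sigmoid), shape (W,)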


if __name__ == '__main__':

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--pth', required=True,
help='path to load saved checkpoint.')
parser.add_argument('--img_glob', required=True,
parser.add_argument('--img_glob', required=False,
help='NOTE: Remember to quote your glob path. '
'All the given images are assumed to be aligned '
'or you should use preprocess.py to do so.')
parser.add_argument('--dataset_dir', required=False,
help='Path to the dataset directory; it must contain '
'vo_final/keyframe_list.txt and an rgb/ folder with the keyframe images.')
parser.add_argument('--output_dir', required=True)
parser.add_argument('--visualize', action='store_true')
# Augmentation related
@@ -171,7 +176,12 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
args = parser.parse_args()

# Prepare images to be processed
paths = sorted(glob.glob(args.img_glob))
# paths = sorted(glob.glob(args.img_glob))

# Build the image list from the keyframe indices in vo_final/keyframe_list.txt
with open(os.path.join(args.dataset_dir, 'vo_final/keyframe_list.txt'), 'r') as f:
kf_list = f.readlines()
paths = [os.path.join(args.dataset_dir, "rgb/", str(int(kf)) + ".png") for kf in kf_list]  # int() tolerates the trailing newline
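This path construction assumes every index in keyframe_list.txt has a matching image under rgb/. A slightly more defensive variant (a sketch under that assumption, not part of this diff) would drop keyframes whose image is missing:

    # Sketch: skip keyframes without a matching rgb image instead of failing later
    missing = [p for p in paths if not os.path.isfile(p)]
    if missing:
        print('Skipping %d keyframes without rgb images' % len(missing))
    paths = [p for p in paths if os.path.isfile(p)]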

if len(paths) == 0:
print('no images found')
for path in paths:
@@ -182,11 +192,15 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
print('Output directory %s does not exist. Creating one.' % args.output_dir)
os.makedirs(args.output_dir)
device = torch.device('cpu' if args.no_cuda else 'cuda')
device = torch.device('cpu')  # force CPU inference regardless of --no_cuda
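Hard-coding the device to CPU makes the --no_cuda flag a no-op. If GPU inference should remain available, one option (a sketch, not part of this diff) is to fall back to CPU only when CUDA is absent or explicitly disabled:

    # Sketch: keep CUDA optional instead of forcing CPU
    use_cpu = args.no_cuda or not torch.cuda.is_available()
    device = torch.device('cpu' if use_cpu else 'cuda')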

# Load trained model
net = utils.load_trained_model(HorizonNet, args.pth).to(device)
net.eval()

npy_dir = os.path.join(os.path.dirname(args.output_dir), "vo_final/npy/")
if not os.path.isdir(npy_dir):
print('Output directory %s does not exist. Creating one.' % npy_dir)
os.makedirs(npy_dir)
# Inferencing
with torch.no_grad():
for i_path in tqdm(paths, desc='Inferencing'):
@@ -200,7 +214,7 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
x = torch.FloatTensor([img_ori / 255])

# Inferencing corners
cor_id, z0, z1, vis_out = inference(net=net, x=x, device=device,
cor_id, z0, z1, vis_out, output_npy = inference(net=net, x=x, device=device,
flip=args.flip, rotate=args.rotate,
visualize=args.visualize,
force_cuboid=args.force_cuboid,
@@ -215,6 +229,10 @@ def inference(net, x, device, flip=False, rotate=[], visualize=False,
'uv': [[float(u), float(v)] for u, v in cor_id],
}, f)

npy_filepath = os.path.join(npy_dir, k + '.npy')
np.save(npy_filepath, output_npy)
print(f'Saving npy file at {npy_filepath}')
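Downstream consumers can reload these files with np.load; a minimal sketch (hypothetical keyframe id, shapes as assumed above):

    # Sketch: reload the raw predictions saved for one keyframe
    raw = np.load(os.path.join(npy_dir, '12.npy'))   # '12' is a hypothetical keyframe id
    y_bon_raw, y_cor_raw = raw[:2], raw[2]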

if vis_out is not None:
vis_path = os.path.join(args.output_dir, k + '.raw.png')
vh, vw = vis_out.shape[:2]