1 change: 1 addition & 0 deletions .gitignore
@@ -108,4 +108,5 @@ venv.bak/
 *.jpg
 *.jpeg
 *.pt
+*.onnx
 .DS_Store
16 changes: 11 additions & 5 deletions dpt/vit.py
@@ -318,11 +318,17 @@ def forward(self, x):
         out_size = torch.Size((h // self.patch_size[1], w // self.patch_size[0]))

         if not self.hybrid_backbone:
-            layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size))
-            layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size))
-
-            layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size))
-            layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size))
+            # according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114
+            # layer_1 = self.act_postprocess1(layer_1.unflatten(2, out_size))
+            # layer_2 = self.act_postprocess2(layer_2.unflatten(2, out_size))
+            layer_1 = self.act_postprocess1(layer_1.view(layer_1.shape[0], layer_1.shape[1], *out_size))
+            layer_2 = self.act_postprocess2(layer_2.view(layer_2.shape[0], layer_2.shape[1], *out_size))
+
+            # according to https://github.com/isl-org/DPT/issues/42#issuecomment-944657114
+            # layer_3 = self.act_postprocess3(layer_3.unflatten(2, out_size))
+            # layer_4 = self.act_postprocess4(layer_4.unflatten(2, out_size))
+            layer_3 = self.act_postprocess3(layer_3.view(layer_3.shape[0], layer_3.shape[1], *out_size))
+            layer_4 = self.act_postprocess4(layer_4.view(layer_4.shape[0], layer_4.shape[1], *out_size))

         return layer_1, layer_2, layer_3, layer_4

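This change works around `Tensor.unflatten` not being supported by the ONNX exporter at the time (see the linked issue): a `view` with an explicit target shape traces cleanly and is numerically identical on contiguous activations. A minimal sketch of the equivalence, using a hypothetical ViT activation shape that is not taken from this PR:

```python
import torch

# hypothetical (batch, channels, num_patches) activation
x = torch.randn(2, 768, 24 * 24)
out_size = torch.Size((24, 24))

a = x.unflatten(2, out_size)                   # original formulation
b = x.view(x.shape[0], x.shape[1], *out_size)  # ONNX-export-friendly rewrite
assert torch.equal(a, b)                       # same values, same shape
```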
163 changes: 163 additions & 0 deletions export_monodepth_onnx.py
@@ -0,0 +1,163 @@
import torch
import argparse
import onnx
import onnxruntime
import json
import numpy as np
import cv2

from dpt.models import DPTDepthModel
from dpt.midas_net import MidasNet_large
import util.io


def main(model_path, model_type, output_path, batch_size, test_image_path):
    # load network
    if model_type == "dpt_large":  # DPT-Large
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )
        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1
    elif model_type == "dpt_hybrid_kitti":
        net_w = 1216
        net_h = 352

        model = DPTDepthModel(
            path=model_path,
            scale=0.00006016,
            shift=0.00579,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 256
    elif model_type == "dpt_hybrid_nyu":
        net_w = 640
        net_h = 480

        model = DPTDepthModel(
            path=model_path,
            scale=0.000305,
            shift=0.1378,
            invert=True,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=False,
        )

        normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        prediction_factor = 1000.0
    elif model_type == "midas_v21":  # Convolutional model
        net_w = net_h = 384

        model = MidasNet_large(model_path, non_negative=True)
        normalization = dict(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        prediction_factor = 1
    else:
        assert (
            False
        ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]"

    model.eval()

    dummy_input = torch.zeros((batch_size, 3, net_h, net_w))
    # TODO: right now, the batch size is not dynamic due to the PyTorch tracer
    # treating the batch size as constant (see get_attention() in vit.py).
    # Therefore you have to use a batch size of one to use this together with
    # run_monodepth_onnx.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # store image size, normalization, and prediction factor as ONNX metadata
    model_onnx = onnx.load(output_path)
    meta_imagesize = model_onnx.metadata_props.add()
    meta_imagesize.key = "ImageSize"
    meta_imagesize.value = json.dumps([net_w, net_h])
    meta_normalization = model_onnx.metadata_props.add()
    meta_normalization.key = "Normalization"
    meta_normalization.value = json.dumps(normalization)
    meta_prediction_factor = model_onnx.metadata_props.add()
    meta_prediction_factor.key = "PredictionFactor"
    meta_prediction_factor.value = str(prediction_factor)
    onnx.save(model_onnx, output_path)
    del model_onnx

    if test_image_path is not None:
        # load test image
        img = util.io.read_image(test_image_path)

        # resize; cv2.resize expects dsize as (width, height) and the
        # interpolation mode as a keyword argument
        img_input = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_AREA)

        # normalize
        img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"])

        # transpose from HWC to CHW
        img_input = img_input.transpose(2, 0, 1)

        # add batch dimension
        img_input = np.stack([img_input] * batch_size)

        # validate accuracy of exported model
        torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy()
        session = onnxruntime.InferenceSession(
            output_path,
            providers=[
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0]

        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04)
        print("Exported model predictions match original")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_weights", help="path to input model weights")
    parser.add_argument("output_path", help="path to the exported ONNX model")
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        help="model type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch size used for tracing"
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        help="path to some image to test the accuracy of the exported model against the original",
    )

    args = parser.parse_args()
    main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path)
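The three metadata entries written above travel with the ONNX file, so a consumer does not need the original PyTorch checkpoint to preprocess and scale correctly. A minimal sketch of reading them back with onnxruntime; the "model.onnx" path and CPU-only provider are placeholder assumptions, not part of this PR:

```python
import json
import onnxruntime

# read back the metadata written by export_monodepth_onnx.py
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
meta = session.get_modelmeta().custom_metadata_map
net_w, net_h = json.loads(meta["ImageSize"])
normalization = json.loads(meta["Normalization"])
prediction_factor = float(meta["PredictionFactor"])
print(net_w, net_h, normalization, prediction_factor)
```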
114 changes: 114 additions & 0 deletions export_segmentation_onnx.py
@@ -0,0 +1,114 @@
import torch
import argparse
import onnx
import onnxruntime
import json
import numpy as np
import cv2

from dpt.models import DPTSegmentationModel
import util.io


def main(model_path, model_type, output_path, batch_size, test_image_path):
    net_w = net_h = 480
    normalization = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # load network
    if model_type == "dpt_large":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitl16_384",
        )
    elif model_type == "dpt_hybrid":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitb_rn50_384",
        )
    else:
        assert (
            False
        ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"

    model.eval()

    dummy_input = torch.zeros((batch_size, 3, net_h, net_w))
    # TODO: right now, the batch size is not dynamic due to the PyTorch tracer
    # treating the batch size as constant (see get_attention() in vit.py).
    # Therefore you have to use a batch size of one to use this together with
    # run_monodepth_onnx.py.
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # store image size and normalization as ONNX metadata
    model_onnx = onnx.load(output_path)
    meta_imagesize = model_onnx.metadata_props.add()
    meta_imagesize.key = "ImageSize"
    meta_imagesize.value = json.dumps([net_w, net_h])
    meta_normalization = model_onnx.metadata_props.add()
    meta_normalization.key = "Normalization"
    meta_normalization.value = json.dumps(normalization)
    onnx.save(model_onnx, output_path)
    del model_onnx

    if test_image_path is not None:
        # load test image
        img = util.io.read_image(test_image_path)

        # resize; cv2.resize expects dsize as (width, height) and the
        # interpolation mode as a keyword argument
        img_input = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_AREA)

        # normalize
        img_input = (img_input - np.array(normalization["mean"])) / np.array(normalization["std"])

        # transpose from HWC to CHW
        img_input = img_input.transpose(2, 0, 1)

        # add batch dimension
        img_input = np.stack([img_input] * batch_size)

        # validate accuracy of exported model
        torch_out = model(torch.from_numpy(img_input.astype(np.float32))).detach().cpu().numpy()
        session = onnxruntime.InferenceSession(
            output_path,
            providers=[
                "TensorrtExecutionProvider",
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
        onnx_out = session.run(["output"], {"input": img_input.astype(np.float32)})[0]

        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-02, atol=1e-04)
        print("Exported model predictions match original")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_weights", help="path to input model weights")
    parser.add_argument("output_path", help="path to the exported ONNX model")
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        help="model type [dpt_large|dpt_hybrid]",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch size used for tracing"
    )
    parser.add_argument(
        "--test_image_path",
        type=str,
        help="path to some image to test the accuracy of the exported model against the original",
    )

    args = parser.parse_args()
    main(args.model_weights, args.model_type, args.output_path, args.batch_size, args.test_image_path)
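For completeness, a hedged sketch of consuming the exported segmentation model: it assumes the network emits a (batch, 150, height, width) logit map, as the 150-class DPTSegmentationModel constructed above suggests, and "segmentation.onnx" is a placeholder path:

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("segmentation.onnx", providers=["CPUExecutionProvider"])
# a zero image stands in for a real preprocessed (1, 3, 480, 480) input
dummy = np.zeros((1, 3, 480, 480), dtype=np.float32)
logits = session.run(["output"], {"input": dummy})[0]
class_map = logits.argmax(axis=1)  # per-pixel class indices, shape (batch, H_out, W_out)
```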