-
Notifications
You must be signed in to change notification settings - Fork 47
Expand file tree
/
Copy pathTensorRT_InferenceEngine.cpp
More file actions
497 lines (424 loc) · 16.6 KB
/
TensorRT_InferenceEngine.cpp
File metadata and controls
497 lines (424 loc) · 16.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
#include "TRT_InferenceEngine/TensorRT_InferenceEngine.h"
#include <NvOnnxParser.h>
namespace
{
std::string get_devicename_from_deviceid(int device_id)
{
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device_id);
auto device_name = std::string(prop.name);
// Remove spaces from device name
device_name.erase(
std::remove_if(device_name.begin(), device_name.end(), ::isspace),
device_name.end());
return device_name;
}
}// namespace
/// Creates an engine wrapper: sets up the TensorRT logger and keeps a copy of
/// the optimizer parameters. No model is loaded until load_model() is called.
///
/// @param optimization_params Build/run configuration; copied into the object.
/// @param logging_level       Numeric nvinfer1::ILogger::Severity for the logger.
inference_backend::TensorRTInferenceEngine::TensorRTInferenceEngine(
        TRTOptimizerParams &optimization_params, u_int8_t logging_level)
{
    // The two setup steps touch disjoint members, so their order is free.
    this->_init_TRT_logger(logging_level);
    this->_set_optimization_params(optimization_params);
}
/// Releases the device buffers owned by this object (allocated with
/// cudaMalloc in _allocate_buffers) and the CUDA stream, if one exists.
inference_backend::TensorRTInferenceEngine::~TensorRTInferenceEngine()
{
    // cudaFree(nullptr) is a no-op, so unallocated slots are harmless.
    for (void *buffer: _buffers)
        cudaFree(buffer);
    _buffers.clear();
    // Nothing in this file creates `_cuda_stream`, so only destroy it when a
    // stream handle is actually present.
    // NOTE(review): assumes the header default-initializes _cuda_stream to
    // nullptr -- confirm.
    if (_cuda_stream)
        cudaStreamDestroy(_cuda_stream);
}
/// Stores a copy of the optimizer parameters.
/// get_engine_path(), _build_engine() and _allocate_buffers() read them later.
///
/// @param params Parameters to copy into `_optimization_params`.
void inference_backend::TensorRTInferenceEngine::_set_optimization_params(
        const TRTOptimizerParams &params)
{
    _optimization_params = params;
}
/// (Re)creates the TensorRT logger used by all engine operations.
///
/// @param logging_level Numeric value converted directly to
///        nvinfer1::ILogger::Severity.
void inference_backend::TensorRTInferenceEngine::_init_TRT_logger(
        u_int8_t logging_level)
{
    const auto severity =
            static_cast<nvinfer1::ILogger::Severity>(logging_level);
    _logger = std::make_unique<TRTLogger>(severity);
}
bool inference_backend::TensorRTInferenceEngine::load_model(
const std::string &onnx_model_path)
{
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Loading ONNX model from path: ")
.append(onnx_model_path)
.c_str());
// Check if ONNX model exists
if (!file_exists(onnx_model_path))
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("ONNX model not found at path: ")
.append(onnx_model_path)
.c_str());
return false;
}
// Ensure gpu_id is valid
int numGPUs{0};
cudaGetDeviceCount(&numGPUs);
if (_optimization_params.gpu_id >= numGPUs)
{
int numGPUs{0};
cudaGetDeviceCount(&numGPUs);
_logger->log(nvinfer1::ILogger::Severity::kERROR,
("Unable to set GPU device index to: " +
std::to_string(_optimization_params.gpu_id) +
". Note, your device has " + std::to_string(numGPUs) +
" CUDA-capable GPU(s).")
.c_str());
}
// Check if engine exists, if not build it
std::string engine_path = get_engine_path(onnx_model_path);
if (!file_exists(engine_path))
{
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Engine not found at path: ")
.append(engine_path)
.c_str());
_build_engine(onnx_model_path);
}
// Set device index
if (auto ret = cudaSetDevice(_optimization_params.gpu_id); ret != 0)
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
("Failed to set device index to: " +
std::to_string(_optimization_params.gpu_id))
.c_str());
return false;
}
// Deserialize engine
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Deserializing engine from path: ")
.append(engine_path)
.c_str());
if (_deserialize_engine(engine_path))
{
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Engine deserialized successfully. Allocating "
"buffers..")
.c_str());
_allocate_buffers();
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Engine loaded successfully").c_str());
return true;
}
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Failed to load engine").c_str());
return false;
}
std::string inference_backend::TensorRTInferenceEngine::get_engine_path(
const std::string &onnx_model_path) const
{
// Parent director + model name
std::string engine_path =
boost::filesystem::path(onnx_model_path).parent_path().string() +
"/" + boost::filesystem::path(onnx_model_path).stem().string();
// Hostname
char hostname[1024];
gethostname(hostname, sizeof(hostname));
std::string suffix(hostname);
suffix.append("_GPU_" +
get_devicename_from_deviceid(_optimization_params.gpu_id));
// TensorRT version
suffix.append("_TRT" + std::to_string(NV_TENSORRT_VERSION));
// CUDA version
suffix.append("_CUDA" + std::to_string(CUDART_VERSION));
// Batch size
suffix.append("_" + std::to_string(_optimization_params.batch_size));
// int8, tf32, fp16, fp32
if (_optimization_params.int8)
suffix.append("_INT8");
else
{
if (_optimization_params.tf32)
suffix.append("_TF32");
_optimization_params.fp16 ? suffix.append("_FP16")
: suffix.append("_FP32");
}
// Engine path = parent_dir/model_name_hostname_GPU_device_name_TRT_version_CUDA_version_batch_size_int8_fp16_fp32.engine
engine_path.append("_" + suffix + ".engine");
return engine_path;
}
bool inference_backend::TensorRTInferenceEngine::file_exists(
const std::string &name) const
{
return boost::filesystem::exists(name);
}
/// Computes the size of a tensor: the product of its extents times
/// `element_size`.
///
/// Non-positive extents count as 1. This covers 0, and -1 for a dynamic
/// dimension, which the previous truthiness test let through so that it
/// wrapped to a huge size_t.
/// NOTE(review): a dynamic dimension therefore sizes as 1, which is only
/// correct for engines that are effectively static -- confirm for dynamic
/// shape models.
///
/// @param dims         Tensor dimensions.
/// @param element_size Bytes per element (pass 1 to get an element count).
/// @return Total size of the tensor.
size_t inference_backend::TensorRTInferenceEngine::get_size_by_dims(
        const nvinfer1::Dims &dims, int element_size) const
{
    size_t size = 1;
    for (int32_t i = 0; i < dims.nbDims; ++i)
        size *= dims.d[i] > 0 ? static_cast<size_t>(dims.d[i]) : size_t{1};
    return size * static_cast<size_t>(element_size);
}
void inference_backend::TensorRTInferenceEngine::_allocate_buffers()
{
for (void *&buffer: _buffers)
cudaFree(buffer);
_buffers.clear();
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
_buffers = std::vector<void *>(_engine->getNbBindings());
#else
_buffers = std::vector<void *>(_engine->getNbIOTensors());
#endif
size_t output_idx = 0;
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
for (size_t i = 0; i < _engine->getNbBindings(); ++i)
#else
for (size_t i = 0; i < _engine->getNbIOTensors(); ++i)
#endif
{
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
nvinfer1::Dims dims = _engine->getBindingDimensions(i);
#else
const char *name = _engine->getIOTensorName(i);
nvinfer1::Dims dims = _engine->getTensorShape(name);
#endif
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
nvinfer1::DataType dtype = _engine->getBindingDataType(i);
#else
nvinfer1::DataType dtype = _engine->getTensorDataType(name);
#endif
size_t total_size = get_size_by_dims(dims, sizeof(float));
cudaMalloc(&_buffers[i], total_size);
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
if (_engine->getBindingName(i) == _optimization_params.input_layer_name)
#else
if (std::string(name) == _optimization_params.input_layer_name)
#endif
{
_input_dims.emplace_back(dims);
_input_idx = i;
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Found input layer with name: ")
.append(_engine->getBindingName(i))
.c_str());
#else
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Found input layer with name: ")
.append(name)
.c_str());
#endif
}
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
else if (std::find(_optimization_params.output_layer_names.begin(),
_optimization_params.output_layer_names.end(),
_engine->getBindingName(i)) !=
_optimization_params.output_layer_names.end())
#else
else if (std::find(_optimization_params.output_layer_names.begin(),
_optimization_params.output_layer_names.end(),
name) !=
_optimization_params.output_layer_names.end())
#endif
{
_output_dims.emplace_back(dims);
_output_idx.emplace_back(i);
++output_idx;
#if NVINFER_MAJOR == 8 && NVINFER_MINOR <= 5
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Found output layer with name: ")
.append(_engine->getBindingName(i))
.c_str());
#else
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Found output layer with name: ")
.append(name)
.c_str());
#endif
}
}
if (_input_dims.empty())
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Input layer not found").c_str());
return;
}
if (_output_dims.empty())
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Output layer not found").c_str());
return;
}
}
/// Reads a serialized engine from disk, deserializes it and creates an
/// execution context.
///
/// Previously loaded context/engine/runtime objects are destroyed first, in
/// dependency order.
///
/// @param engine_path Path to the serialized engine file.
/// @return true if `_runtime`, `_engine` and `_context` were all created.
bool inference_backend::TensorRTInferenceEngine::_deserialize_engine(
        const std::string &engine_path)
{
    // Reference: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#perform_inference_c
    std::ifstream engine_file(engine_path, std::ios::binary);
    if (!engine_file)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     std::string("Failed to open engine file").c_str());
        return false;
    }
    // Read engine file. tellg() returns -1 on failure; an empty file cannot
    // hold an engine either.
    engine_file.seekg(0, std::ios::end);
    const std::streamoff size = engine_file.tellg();
    if (size <= 0)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Engine file is empty or unreadable");
        return false;
    }
    engine_file.seekg(0, std::ios::beg);
    std::vector<char> trt_model_stream(static_cast<size_t>(size));
    if (!engine_file.read(trt_model_stream.data(), size))
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Failed to read engine file");
        return false;
    }
    engine_file.close();
    // On reload, tear down the old objects in dependency order: the context
    // depends on the engine, and the engine must not outlive its runtime.
    // Replacing _runtime first would destroy it under a live engine.
    _context.reset();
    _engine.reset();
    _runtime.reset();
    // Deserialize engine
    // Runtime must outlive the engine. Keep it as a member
    _runtime = makeUnique(nvinfer1::createInferRuntime(*_logger));
    if (!_runtime)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Failed to create TensorRT runtime");
        return false;
    }
    _engine = makeUnique(_runtime->deserializeCudaEngine(
            trt_model_stream.data(), trt_model_stream.size()));
    if (!_engine)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     std::string("Failed to deserialize engine").c_str());
        return false;
    }
    // Create execution context
    _context = makeUnique(_engine->createExecutionContext());
    if (!_context)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     std::string("Failed to create execution context").c_str());
        return false;
    }
    return true;
}
void inference_backend::TensorRTInferenceEngine::_build_engine(
const std::string &onnx_model_path)
{
// Reference: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#c_topics
// Network builder
auto builder = std::unique_ptr<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(*_logger));
// Network definition
uint32_t flag =
1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(
builder->createNetworkV2(flag));
// ONNX parser
auto parser = std::unique_ptr<nvonnxparser::IParser>(
nvonnxparser::createParser(*network, *_logger));
// Parse ONNX model
int verbosity = static_cast<int>(_logSeverity);
parser->parseFromFile(onnx_model_path.c_str(), verbosity);
for (int32_t i = 0; i < parser->getNbErrors(); ++i)
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
parser->getError(i)->desc());
}
// Optimization profile
// TODO: Check more about optimization profiles
auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(
builder->createBuilderConfig());
auto profile = builder->createOptimizationProfile();
profile->setDimensions(network->getInput(0)->getName(),
nvinfer1::OptProfileSelector::kMIN,
network->getInput(0)->getDimensions());
profile->setDimensions(network->getInput(0)->getName(),
nvinfer1::OptProfileSelector::kOPT,
network->getInput(0)->getDimensions());
profile->setDimensions(network->getInput(0)->getName(),
nvinfer1::OptProfileSelector::kMAX,
network->getInput(0)->getDimensions());
config->addOptimizationProfile(profile);
if (_optimization_params.int8)
{
// TODO: Add int8 calibration
_logger->log(nvinfer1::ILogger::Severity::kWARNING,
std::string("INT8 calibration is not supported yet. "
"Switching to FP16 or FP32 calibration")
.c_str());
}
else
{
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("FP16 or FP32 calibration").c_str());
config->setFlag(nvinfer1::BuilderFlag::kFP16);
if (_optimization_params.tf32)
config->setFlag(nvinfer1::BuilderFlag::kTF32);
}
// Build engine
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Building engine").c_str());
std::unique_ptr<nvinfer1::IHostMemory> engine_plan{
builder->buildSerializedNetwork(*network, *config)};
std::unique_ptr<nvinfer1::IRuntime> runtime{
nvinfer1::createInferRuntime(*_logger)};
std::shared_ptr<nvinfer1::ICudaEngine> engine =
std::shared_ptr<nvinfer1::ICudaEngine>(
runtime->deserializeCudaEngine(engine_plan->data(),
engine_plan->size()),
TRTDeleter());
if (!engine)
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Failed to build engine").c_str());
return;
}
// Serialize engine
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Serializing engine").c_str());
std::string engine_path = get_engine_path(onnx_model_path);
std::ofstream engine_file(engine_path, std::ios::binary);
if (!engine_file)
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Failed to open engine file").c_str());
return;
}
std::unique_ptr<nvinfer1::IHostMemory> serialized_engine{
engine->serialize()};
if (!serialized_engine)
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Failed to serialize engine").c_str());
return;
}
engine_file.write(reinterpret_cast<const char *>(serialized_engine->data()),
serialized_engine->size());
if (engine_file.fail())
{
_logger->log(nvinfer1::ILogger::Severity::kERROR,
std::string("Failed to write engine file").c_str());
return;
}
_logger->log(nvinfer1::ILogger::Severity::kINFO,
std::string("Engine serialized successfully").c_str());
engine_file.close();
}
/// Runs synchronous inference on one image.
///
/// The image is converted with cv::dnn::blobFromImage to a CV_32F blob sized
/// to the input layer. Here d[3] of the input dims is used as width and d[2]
/// as height. The blob is copied to the device, executeV2() runs, and every
/// configured output layer is copied back as a vector of floats.
///
/// @param input_image Image to run inference on.
/// @return One float vector per configured output layer, in the order they
///         were found; an empty ModelPredictions on any failure (logged).
inference_backend::ModelPredictions
inference_backend::TensorRTInferenceEngine::forward(const cv::Mat &input_image)
{
    // Reference: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#perform-inference
    if (_buffers.empty() || _input_dims.empty() || !_context)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     std::string("Buffers not allocated").c_str());
        return ModelPredictions();
    }
    if (input_image.empty())
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Input image is empty");
        return ModelPredictions();
    }
    // Create a blob from the image
    cv::Mat image_blob;
    cv::dnn::blobFromImage(input_image, image_blob, 1.0,
                           cv::Size(_input_dims[0].d[3], _input_dims[0].d[2]),
                           cv::Scalar(), false, false, CV_32F);
    // Ensure the image blob size matches the input layer size. This is a
    // runtime check rather than an assert: in release builds a mismatch
    // (e.g. wrong channel count) would make the copy read past the blob.
    const size_t input_bytes =
            get_size_by_dims(_input_dims[0], sizeof(float));
    if (image_blob.total() * image_blob.elemSize() != input_bytes)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Image blob size does not match input layer size");
        return ModelPredictions();
    }
    // Copy image blob to CUDA input buffer. A blocking copy makes the
    // ordering with executeV2() explicit.
    if (cudaMemcpy(_buffers[_input_idx], image_blob.data, input_bytes,
                   cudaMemcpyHostToDevice) != cudaSuccess)
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Failed to copy input to device");
        return ModelPredictions();
    }
    // Run inference
    if (!_context->executeV2(_buffers.data()))
    {
        _logger->log(nvinfer1::ILogger::Severity::kERROR,
                     "Inference execution failed");
        return ModelPredictions();
    }
    // Copy CUDA output buffers to host
    ModelPredictions predictions;
    for (size_t i = 0; i < _output_idx.size(); ++i)
    {
        // element_size 1 => element count
        std::vector<float> output(get_size_by_dims(_output_dims[i], 1));
        if (cudaMemcpy(output.data(), _buffers[_output_idx[i]],
                       output.size() * sizeof(float),
                       cudaMemcpyDeviceToHost) != cudaSuccess)
        {
            _logger->log(nvinfer1::ILogger::Severity::kERROR,
                         "Failed to copy output from device");
            return ModelPredictions();
        }
        predictions.emplace_back(std::move(output));
    }
    return predictions;
}