<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Coding Blog - Mrutunjayya</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="stylesheet" href="https://unpkg.com/prismjs/themes/prism.css">
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
background: #f4f4f4;
line-height: 1.6;
}
header {
background: #2c3e50;
color: white;
padding: 1rem;
text-align: center;
}
nav {
background: #34495e;
padding: 1rem;
display: flex;
justify-content: center;
gap: 1.5rem;
}
nav a {
color: white;
text-decoration: none;
font-weight: bold;
}
nav a:hover {
text-decoration: underline;
}
.container {
max-width: 800px;
margin: 2rem auto;
background: white;
padding: 2rem;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
}
pre {
background: #f0f0f0;
padding: 1rem;
overflow-x: auto;
border-left: 5px solid #007acc;
}
footer {
text-align: center;
padding: 1rem;
background: #eee;
margin-top: 2rem;
}
h1, h2, h3, h4, h5 {
margin-bottom: 0.2rem;
}
p {
text-align: justify;
margin-top: 0.2rem;
}
img {
width: 100%; /* Take full width of content box */
display: block;
margin: 1rem 0; /* Space above/below */
border-radius: 8px; /* Optional styling */
}
pre {
position: relative;
margin-top: 0.0rem;
padding: 0;
background-color: #f7f4f2;
border-left: 2px solid #007acc;
border-radius: 2px;
overflow: auto;
}
pre code {
display: block;
padding: 0.0rem 0.0rem;
font-family: 'Courier New', monospace;
font-size: 0.95rem;
line-height: 1.4;
white-space: pre;
}
.copy-btn {
position: absolute;
top: 6px;
right: 10px;
padding: 3px 8px;
font-size: 0.75rem;
border: none;
background-color: #007acc;
color: white;
border-radius: 4px;
cursor: pointer;
opacity: 0.85;
transition: opacity 0.2s;
}
figcaption {
font-family: 'Arial', sans-serif;
font-size: 14px;
color: #333;
text-align: center;
}
.copy-btn:hover {
opacity: 1;
}
@media (max-width: 600px) {
.container {
padding: 1rem;
}
}
</style>
</head>
<body>
<nav>
<a href="index.html">Home</a>
</nav>
<div class="container">
<h2>Achieve a 15-20x performance improvement for vision/perception model inference</h2>
<h4>Introduction:</h4>
<p>Developers aim to run vision perception models, such as YOLO and segmentation algorithms, in real time. The runtime performance of any AI model depends on its size and numerical precision. To improve runtime performance, model developers invest time in optimizing the model's size and architecture, as well as its precision. However, there is a limit to how far size and precision can be reduced without sacrificing model quality.</p>
<p>
What is often overlooked is the pipeline surrounding model inference. This includes converting input images or frames to tensor format (from HWC to CHW), normalizing values to the float32 range 0-1, and transferring data between the CPU and GPU. These pre-processing and post-processing steps, along with data movement, can become bottlenecks that increase the total inference time. In this blog, we explore ways to optimize these steps to achieve faster inference for vision perception models.
</p>
<p>
Here are the naive steps involved in running inference for a vision perception model:
</p>
<ul>
<li>Load the model and allocate the necessary buffers for input and output.</li>
<li>Decode the image or video file.</li>
<li>Normalize the data to a range of 0 to 1.</li>
<li>Convert the data to tensor format (i.e., from HWC to CHW).</li>
<li>Transfer the tensor to the GPU.</li>
<li>Execute the inference process.</li>
<li>Copy the results back to the CPU.</li>
<li>Convert the tensor back to UINT8 format (i.e., from CHW to HWC).</li>
</ul>
<b>Note: Table 1, at the end of Step-0, lists the time required for each step of the naive inference approach.</b>
<p>The following figure summarizes the performance improvements that will be discussed in this blog post.</p>
<figure>
<img src="fps_table.png">
<figcaption>Figure: Performance improvement from optimizing each step. *Numbers are averaged over multiple runs.</figcaption>
</figure>
<p>Complete code for the blog can be found here: <a href="https://github.com/mjayw2014/rvm_perf_inference">link</a> </p>
<h4>Step-0: Basic setup and Naive Inference [CPU Decode, CPU Pre & Post Processing]:</h4>
<h5><u>Environment Setup:</u></h5>
<p>For our experiments, we use the following setup.</p>
<p>Create a Python virtual environment named “seg_env”:</p>
<pre><code class="language-bash">
python -m venv seg_env
# Windows
seg_env\Scripts\activate
# Linux/macOS: source seg_env/bin/activate
</code></pre>
<p>Install the dependencies:</p>
<pre><code class="language-bash">
pip install cuda-python
pip install opencv-python
pip install tensorrt
pip install PyNvVideoCodec
</code></pre>
<h5><u>Inference:</u></h5>
<p>TensorRT setup: when working with NVIDIA TensorRT-based inference, there are two critical steps to handle after exporting your optimized model: first, loading the engine, and second, allocating GPU memory buffers for the input and output tensors.
</p>
<pre><code class="language-python">
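import os
import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
def load_engine(model_file):
    # Deserialize a prebuilt TensorRT engine (plan file) from disk
    assert os.path.exists(model_file)
    trt_runtime = trt.Runtime(trt_logger)
    with open(model_file, "rb") as f:
        return trt_runtime.deserialize_cuda_engine(f.read())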
</code></pre>
<p>The function `load_engine(model_file)` deserializes a TensorRT engine: it loads a serialized engine from disk into memory.
It first creates a TensorRT Runtime object, which is required for deserialization, and then
calls `deserialize_cuda_engine()` to turn the raw bytes back into a usable engine object in memory.
TensorRT engines are pre-optimized binary blobs tailored to specific hardware. Loading a prebuilt engine this way
skips the model parsing and building stage at runtime, so inference can start right away.</p>
<pre><code class="language-python">
import numpy as np
import tensorrt as trt
from cuda import cuda  # cuda-python driver API (newer releases: from cuda.bindings import driver as cuda)

def allocate_buffers(trt_engine):
    inputs = {}
    outputs = {}
    bindings = []
    for i in range(trt_engine.num_io_tensors):
        tensor_name = trt_engine.get_tensor_name(i)
        shape = tuple(trt_engine.get_tensor_shape(tensor_name))
        dtype = trt_engine.get_tensor_dtype(tensor_name)
        size = int(np.prod(shape)) * dtype.itemsize
        # Allocate the tensor on the GPU
        d_memory = checkCudaErrors(cuda.cuMemAlloc(size))
        # TensorRT expects plain integer device addresses in the bindings list
        bindings.append(int(d_memory))
        if trt_engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
            inputs[tensor_name] = {'d_mem': d_memory, 'shape': shape, 'size': size, 'dtype': dtype}
        else:
            outputs[tensor_name] = {'d_mem': d_memory, 'shape': shape, 'size': size, 'dtype': dtype}
    return inputs, outputs, bindings
</code></pre>
<p>The function `allocate_buffers(trt_engine)` sets up memory for inference. Once the engine is loaded,
it allocates GPU memory buffers for all inputs and outputs. "inputs" and "outputs" are dictionaries that store each tensor's GPU memory pointer along with its metadata:
shape, data type, and size in bytes. The "bindings" list contains the device memory addresses in the order that TensorRT expects
during inference. All GPU memory allocation and CUDA operations use the "cuda-python" library. `checkCudaErrors()` is a helper function
that validates the success of CUDA API calls.</p>
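<p>As a quick sanity check on `allocate_buffers()`, the byte size of each binding is simply the element count times the element size. The plain-Python sketch below reproduces that arithmetic; the "src" shape (1, 3, 1080, 1920) is an assumption based on the 1080p engine name used later, not a value read from the engine, and `tensor_nbytes` is a hypothetical helper.</p>
<pre><code class="language-python">
import math

def tensor_nbytes(shape, itemsize):
    # Same formula as allocate_buffers(): number of elements times bytes per element
    return math.prod(shape) * itemsize

# Assumed shape of the "src" input of a 1080p engine (N, C, H, W)
src_shape = (1, 3, 1080, 1920)

fp32_bytes = tensor_nbytes(src_shape, 4)   # float32 tensor
fp16_bytes = tensor_nbytes(src_shape, 2)   # float16 tensor
uint8_bytes = tensor_nbytes(src_shape, 1)  # raw uint8 frame

print(fp32_bytes, fp16_bytes, uint8_bytes)  # 24883200 12441600 6220800
</code></pre>
<p>A raw uint8 frame is a quarter the size of the float32 tensor, so copying the uint8 frame and converting it on the GPU moves four times less data across the bus than copying a pre-converted float32 tensor.</p>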
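<p>Next, let's write an inference function that runs the complete sequence of steps for every video frame: decode it on the CPU, normalize it and convert it to CHW, copy it to the GPU, run TensorRT, copy the mask back to the host, and apply it to the frame.</p>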
<pre><code class="language-python">
import time
import cv2
import numpy as np
from cuda import cuda  # cuda-python driver API

def inference(inputs, outputs, bindings, trt_context, input_video_file):
    cap = cv2.VideoCapture(input_video_file)
    frame_count = 0
    out_mask = None
    start_time = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if 0 == frame_count:
            # First frame: zero the recurrent state inputs r1i..r4i and allocate the host mask
            for i in range(4):
                name = 'r' + str(i + 1) + 'i'
                tmp_zeros = np.zeros(inputs[name]['shape'], dtype=np.float32)
                # Copy the 'ri' tensors to the GPU
                checkCudaErrors(cuda.cuMemcpyHtoD(inputs[name]['d_mem'], tmp_zeros, inputs[name]['size']))
            frame_h, frame_w, frame_c = frame.shape
            out_mask = np.empty((1, frame_h, frame_w), dtype=np.float32)
        # Convert the uint8 HWC frame to an fp32 NCHW tensor in [0, 1]
        frame = frame.astype(np.float32) / 255.0
        frame = np.transpose(frame, (2, 0, 1))[np.newaxis, :]
        # Copy the frame tensor to the GPU (transpose returns a view, so make it contiguous)
        checkCudaErrors(cuda.cuMemcpyHtoD(inputs['src']['d_mem'], np.ascontiguousarray(frame), inputs['src']['size']))
        # Execute the TensorRT engine (synchronous)
        trt_context.execute_v2(bindings)
        # Copy the result mask ("pha") back to the host
        checkCudaErrors(cuda.cuMemcpyDtoH(out_mask, outputs['pha']['d_mem'], outputs['pha']['size']))
        # Apply the mask to the frame and convert back to HWC
        frame = frame * out_mask
        frame = np.squeeze(frame, axis=0)
        frame = np.transpose(frame, (1, 2, 0))
        # cv2.imshow("Frame", frame)
        # if cv2.waitKey(1) & 0xFF == ord('q'):
        #     break
        frame_count += 1
    cap.release()
    end_time = time.time()
    total_time = end_time - start_time
    fps = frame_count / total_time
    print(f"Total Frame: {frame_count}, Sync FPS: {fps}")

def main():
    trt_engine = load_engine("rvm_mobilenetv3_fp32_sim_1080p_modified.trt")
    trt_context = trt_engine.create_execution_context()
    inputs, outputs, bindings = allocate_buffers(trt_engine)
    input_video_file = "codylexi.mp4"
    inference(inputs, outputs, bindings, trt_context, input_video_file)

if __name__ == "__main__":
    main()
</code></pre>
<p>Table 1 breaks down the time taken by each step. Most of the time is spent on pre-processing, post-processing, and data transfer between the CPU and GPU. Running inference in asynchronous mode does not help significantly, because pre-processing and post-processing take far longer than the inference itself.</p>
<figure>
<img src="step_0.png" width="300" height="400">
<figcaption>Table.1: Running time of each step.</figcaption>
</figure>
<h4>Step-1: CPU Decode, GPU Pre & Post Processing:</h4>
<p>To offload the pre-processing and post-processing steps to the GPU, we need to write custom CUDA kernels.
The following two kernels, "pre_process" and "post_process", perform the respective operations on the data.
These are naïve implementations; further optimization is possible but beyond the scope of this article.</p>
<p>The "pre_process" kernel takes a UINT8 (0-255) image array in HWC (Height, Width, Channels) format as input and converts it
into a float32 tensor in CHW (Channels, Height, Width) format, normalized to the range 0-1.</p>
<pre><code class="language-cpp">
// CUDA kernel for PRE PROCESSING
// uint8 -> NHWC -> NCHW -> fp32
extern "C" __global__ void pre_process(unsigned char* input, float* output, int stride)
{
    // One thread per pixel; stride = height * width
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < stride)
    {
        output[0 * stride + tid] = input[tid * 3 + 0] / 255.0f;
        output[1 * stride + tid] = input[tid * 3 + 1] / 255.0f;
        output[2 * stride + tid] = input[tid * 3 + 2] / 255.0f;
    }
}
</code></pre>
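<p>To see exactly what `pre_process` computes, here is a CPU reference in plain Python (a hypothetical helper for checking results, not part of the blog's code). Each loop iteration plays the role of one CUDA thread `tid`, which owns one pixel: it reads the three interleaved bytes at `tid * 3 + c` and writes them into three separate planes at `c * stride + tid`.</p>
<pre><code class="language-python">
def pre_process_reference(hwc_bytes, height, width):
    """CPU emulation of the pre_process kernel: uint8 HWC -> CHW floats in [0, 1]."""
    stride = height * width              # pixels per plane (the kernel's `stride`)
    output = [0.0] * (3 * stride)
    for tid in range(stride):            # one loop iteration == one CUDA thread
        for c in range(3):
            output[c * stride + tid] = hwc_bytes[tid * 3 + c] / 255.0
    return output

# 1x2 image: pixel 0 = (0, 51, 255), pixel 1 = (255, 102, 0)
frame = bytes([0, 51, 255, 255, 102, 0])
chw = pre_process_reference(frame, height=1, width=2)
print(chw)  # [0.0, 1.0, 0.2, 0.4, 1.0, 0.0]
</code></pre>
<p>Note that the kernel keeps the channel order of its input, so frames read with OpenCV (which are BGR) reach the model in BGR order.</p>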
<p>The "post_process" kernel applies the mask ("pha" tensor) generated by the model to the "fgr" (foreground)
tensor, also produced by the model, and converts the float32 CHW tensor back to a UINT8 HWC frame.</p>
<pre><code class="language-cpp">
// CUDA kernel for POST PROCESSING
// 1. Apply mask
// 2. NCHW -> NHWC -> uint8
extern "C" __global__ void post_process(float* fgr, float* pha, unsigned char* output, int stride)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    // Every access stays inside the bounds check. Each thread owns one pixel,
    // so no __syncthreads() is needed.
    if (tid < stride)
    {
        for (int c = 0; c != 3; ++c)
        {
            float value = fgr[c * stride + tid] * pha[tid] * 255.0f;
            value = fmaxf(0.0f, fminf(255.0f, value));
            output[tid * 3 + c] = (unsigned char)value;
        }
    }
}
</code></pre>
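<p>A matching CPU reference for `post_process` (again a hypothetical checking helper, not from the repository) shows the reverse mapping: plane offsets `c * stride + tid` are read and interleaved offsets `tid * 3 + c` are written. It clamps to [0, 255] before the integer conversion, because converting an out-of-range float to an unsigned char is undefined behavior in C++.</p>
<pre><code class="language-python">
def post_process_reference(fgr, pha, height, width):
    """CPU emulation of post_process: out = fgr * pha, float CHW -> uint8 HWC."""
    stride = height * width
    output = bytearray(3 * stride)
    for tid in range(stride):                    # one loop iteration == one CUDA thread
        for c in range(3):
            value = fgr[c * stride + tid] * pha[tid] * 255.0
            value = max(0.0, min(255.0, value))  # clamp before the uint8 conversion
            output[tid * 3 + c] = int(value)     # truncates, like a C cast
    return bytes(output)

# 1x2 image, all-white foreground; pixel 0 fully kept, pixel 1 fully masked out
fgr = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
pha = [1.0, 0.0]
print(list(post_process_reference(fgr, pha, 1, 2)))  # [255, 255, 255, 0, 0, 0]
</code></pre>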
<p>
Next, we set up the host-side code by allocating GPU memory for the input and output frames.
This follows the call to “allocate_buffers()” from Step-0.
We use the “CUDA-Python” library [link], and I’ve adapted the “checkCudaErrors()” [link]
helper function from the official CUDA-Python repository. This function reports an error if the CUDA call fails; otherwise,
it returns the value produced by the call. In the usage below, it returns the GPU memory addresses of the
“input” and “output” frame buffers.
</p>
<pre><code class="language-python">
_, frame_channel, frame_height, frame_width = inputs['src']['shape']
d_frame_size = frame_channel * frame_height * frame_width * np.dtype(np.uint8).itemsize
d_in_frame = checkCudaErrors(cuda.cuMemAlloc(d_frame_size))
d_out_frame = checkCudaErrors(cuda.cuMemAlloc(d_frame_size))
</code></pre>
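<p>For readers who do not want to dig through the samples, the sketch below shows the general shape of a `checkCudaErrors()`-style helper. It is an assumption-based re-implementation named `check_cuda_errors` (hypothetical), not the code from the CUDA-Python samples or this blog's repository. It relies on the fact that cuda-python calls return a tuple whose first element is the error code (0 means success), followed by any results.</p>
<pre><code class="language-python">
def check_cuda_errors(result):
    """Hypothetical helper: unpack a cuda-python (err, *values) tuple or raise."""
    err, *values = result
    # Enum members expose .value; a plain int is accepted as well
    code = int(getattr(err, "value", err))
    if code != 0:  # 0 is CUDA_SUCCESS
        raise RuntimeError(f"CUDA call failed: {err!r} (code {code})")
    if not values:
        return None              # e.g. cuMemcpyHtoD returns only the error code
    if len(values) == 1:
        return values[0]         # e.g. cuMemAlloc returns (err, device_pointer)
    return tuple(values)
</code></pre>
<p>Usage mirrors the snippet above, for example `d_out_frame = check_cuda_errors(cuda.cuMemAlloc(d_frame_size))`. Raising instead of returning the error code keeps a failed allocation from silently flowing into later kernel launches.</p>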
<p>In CUDA-Python, we cannot directly compile and execute a CUDA kernel from a “.cu” file as we can in C/C++.
Instead, we either read the “.cu” file into a string or write the CUDA kernel as a string in Python.
For details on how to format kernels as strings, see “cuda_pre_post_kernels” in this blog's GitHub repository [link].</p>
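<p>If you prefer to keep the kernels in a separate “.cu” file, the source can be read into a string with the standard library before it is compiled. A minimal sketch; the helper `load_kernel_source` and the file name `pre_post_kernels.cu` are hypothetical.</p>
<pre><code class="language-python">
from pathlib import Path

def load_kernel_source(cu_path):
    """Read CUDA C++ source from a .cu file so it can be compiled at runtime."""
    return Path(cu_path).read_text(encoding="utf-8")

# Hypothetical usage; the resulting string takes the place of cuda_pre_post_kernels:
# cuda_pre_post_kernels = load_kernel_source("pre_post_kernels.cu")
</code></pre>
<p>Keeping kernels in a “.cu” file lets editors highlight and lint them as CUDA code, at the cost of one more file to ship alongside the Python script.</p>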
<p>The following code snippet shows how to compile the kernel source string and load the “pre_process” and
“post_process” kernels. I have used the “KernelHelper” class from the “CUDA-Python” samples [link].
It takes the source string containing the pre-processing and post-processing kernels and compiles it at runtime.
For more details, refer to the helper in the repository [link].</p>
<pre><code class="language-python">
cudaKernelHandle = KernelHelper(cuda_pre_post_kernels, int(cudaDevice))
cuda_pre_process_kernel = cudaKernelHandle.getFunction(b'pre_process')
cuda_post_process_kernel = cudaKernelHandle.getFunction(b'post_process')
</code></pre>
<p>Next, we set the kernel arguments and the launch configuration (block and grid dimensions).</p>
<pre><code class="language-python">
import ctypes
# Kernel arguments as (values, types); None marks a CUDA-Python object such as a CUdeviceptr
kernel_args_pre_process = ((d_in_frame, inputs['src']['d_mem'], np.int32(frame_height * frame_width)),
                           (None, None, ctypes.c_int))
kernel_args_post_process = ((outputs['fgr']['d_mem'], outputs['pha']['d_mem'], d_out_frame, np.int32(frame_height * frame_width)),
                            (None, None, None, ctypes.c_int))
# Launch configuration: one thread per pixel, rounded up to whole blocks of 1024 threads
cuda_block_dim = (1024, 1, 1)
cuda_grid_dim = ((frame_height * frame_width + 1023) // 1024, 1, 1)
</code></pre>
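<p>When computing the grid dimension, ceiling division guarantees that every pixel gets a thread. Plain floor division, `(frame_height * frame_width) // 1024`, silently drops the tail whenever H × W is not a multiple of 1024; 1920 × 1080 happens to divide evenly (2025 blocks), which hides the problem at 1080p. The extra threads in the last block must then be filtered out by a bounds check such as `if (tid < stride)` around every memory access. The helper name below is hypothetical.</p>
<pre><code class="language-python">
def grid_blocks(num_pixels, block_size=1024):
    # Ceiling division: enough blocks so that blocks * block_size >= num_pixels
    return (num_pixels + block_size - 1) // block_size

print(grid_blocks(1920 * 1080))  # 2025: 1080p divides evenly by 1024
print(grid_blocks(1000 * 1000))  # 977, whereas floor division would give 976
</code></pre>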
<p>Copy the UINT8 frame to the GPU ("d_in_frame"):</p>
<pre><code class="language-python">
#Copy frame to GPU
checkCudaErrors(cuda.cuMemcpyHtoD(d_in_frame, frame, d_frame_size))
</code></pre>
<p>Launch the ‘pre_process’ kernel to convert the UINT8 HWC frame to a Float32 CHW tensor. The result is stored in the TensorRT input binding (GPU memory location) “inputs['src']['d_mem']”.</p>
<pre><code class="language-python">
# Launch pre-processing kernel
checkCudaErrors(cuda.cuLaunchKernel(cuda_pre_process_kernel,
cuda_grid_dim[0], cuda_grid_dim[1], cuda_grid_dim[2],
cuda_block_dim[0], cuda_block_dim[1], cuda_block_dim[2],
0, 0,
kernel_args_pre_process, 0))
</code></pre>
<p>Run inference with the TensorRT bindings set up in the "allocate_buffers" function:</p>
<pre><code class="language-python">
# Launch TensorRT inference
trt_context.execute_v2(bindings)
</code></pre>
<p>After inference, launch the post-process kernel. This kernel applies the mask to the resulting foreground image and converts the Float32 CHW tensor to a UINT8 HWC frame. After post-processing, copy the result back to the host.</p>
<pre><code class="language-python">
#Launch post-process kernel: Apply mask and convert uint8 image
checkCudaErrors(cuda.cuLaunchKernel(cuda_post_process_kernel,
cuda_grid_dim[0], cuda_grid_dim[1], cuda_grid_dim[2],
cuda_block_dim[0], cuda_block_dim[1], cuda_block_dim[2],
0, 0,
kernel_args_post_process, 0))
#Copy frame to CPU
checkCudaErrors(cuda.cuMemcpyDtoH(frame, d_out_frame, d_frame_size))
</code></pre>
<p>With pre- and post-processing offloaded to the GPU, overall inference reaches ~100 FPS. However, we still pay the overhead of decoding on the CPU and sending the data to the GPU. Table 2 shows the time taken by each step.</p>
<figure>
<img src="step_1.png" width="300" height="400">
<figcaption>Table.2: Running time of each step.</figcaption>
</figure>
<h4>Step-2: GPU Decode, GPU Pre & Post Processing:</h4>
<p>In the previous step, we examined the performance after offloading pre- and post-processing tasks to the GPU. However, we still need to decode the video frames on the CPU, copy them to the GPU, and convert their format. What if we offloaded video decoding to the GPU as well?</p>
<p>To decode video on the GPU, we can use PyNvVideoCodec, which is NVIDIA's Python-based library providing APIs for hardware-accelerated video encoding and decoding on NVIDIA GPUs. The class `nvc.SimpleDecoder` accepts a video file as input and delivers the decoded frames. The following code demonstrates this process. </p>
<p>It's important to note two key parameters: "use_device_memory=True" and "output_color_type=nvc.OutputColorType.RGBP":</p>
<ul>
<li>`use_device_memory=True`: This parameter keeps the decoded frames in GPU memory.</li>
<li>`output_color_type=nvc.OutputColorType.RGBP`: This setting outputs the decoded frames as planar RGB, i.e. in CHW (tensor) layout. This removes the need for layout conversion in our "pre_process" CUDA kernel; we only need to normalize the data to the range 0-1.</li>
</ul>
<pre><code class="language-python">
import PyNvVideoCodec as nvc
decoder = nvc.SimpleDecoder(enc_file_path=input_video_file,
gpu_id=cudaDevice,
use_device_memory=True,
need_scanned_stream_metadata=True,
output_color_type = nvc.OutputColorType.RGBP)
</code></pre>
<p>We now allocate one fewer buffer on the GPU (there is no separate input frame buffer), and there is no manual copy of data from the CPU to the GPU.</p>
<pre><code class="language-python">
_, frame_channel, frame_height, frame_width = inputs['src']['shape']
d_frame_size = frame_channel * frame_height * frame_width * np.dtype(np.uint8).itemsize
d_out_frame = checkCudaErrors(cuda.cuMemAlloc(d_frame_size))
h_out_frame = np.empty((frame_height, frame_width, frame_channel), dtype=np.uint8)
#Load the modified CUDA kernel for normalizing decoded frame to 0-1.
cuda_pre_process_kernel = cudaKernelHandle.getFunction(b'pre_process_2')
</code></pre>
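<p>However, we still need to pass the decoded frame to the pre-process kernel. The following code passes a pointer to the decoded frame as a kernel parameter: "frame.GetPtrToPlane(0)" returns the GPU memory address of the frame, and "cuda.CUdeviceptr" wraps that address so it can be passed as a device pointer.</p>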
<pre><code class="language-python">
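for frame in decoder:
    # GPU address of the decoded RGBP frame
    frame_plane_ptr = frame.GetPtrToPlane(0)
    kernel_args_pre_process = ((cuda.CUdeviceptr(frame_plane_ptr), inputs['src']['d_mem'], np.int32(frame_height * frame_width)),
                               (None, None, ctypes.c_int))
    # ... launch pre_process_2, run TensorRT and post-process as in Step-1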
</code></pre>
<p>The following is the modified CUDA kernel, which only performs 0-1 normalization:</p>
<pre><code class="language-cpp">
// Planar uint8 (RGBP) -> planar fp32, normalized to 0-1
extern "C" __global__ void pre_process_2(unsigned char* input, float* output, int stride)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    // Input and output are both planar, so the offsets are identical; only scaling is needed
    if (tid < stride)
    {
        output[0 * stride + tid] = input[0 * stride + tid] / 255.0f;
        output[1 * stride + tid] = input[1 * stride + tid] / 255.0f;
        output[2 * stride + tid] = input[2 * stride + tid] / 255.0f;
    }
}
</code></pre>
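<p>The reason `pre_process_2` is simpler than `pre_process` is the memory layout. With RGBP output, the decoder already stores each channel as its own plane, so source and destination offsets are identical. The plain-Python sketch below compares the two offset formulas; the function names are illustrative, and with `tid = y * width + x` they reduce to the kernels' `tid * 3 + c` and `c * stride + tid`.</p>
<pre><code class="language-python">
def interleaved_offset(c, y, x, width, channels=3):
    # HWC / packed layout (e.g. an OpenCV frame): the channels of a pixel are adjacent
    return (y * width + x) * channels + c

def planar_offset(c, y, x, width, height):
    # CHW / planar layout (e.g. RGBP or a TensorRT input tensor): one full plane per channel
    return c * (height * width) + y * width + x

# Green (c=1) value of pixel (y=0, x=1) in a 2x2 frame
print(interleaved_offset(1, 0, 1, width=2))       # 4
print(planar_offset(1, 0, 1, width=2, height=2))  # 5
</code></pre>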
<p>Table 3 shows the speedup achieved by offloading video decoding to the GPU.</p>
<figure>
<img src="step_2.png" width="300" height="400">
<figcaption>Table.3: Running time of each step.</figcaption>
</figure>
<h4>Step-3: [FP16] GPU Decode, GPU Pre & Post Processing:</h4>
<p>We have achieved a significant performance improvement by offloading CPU tasks to the GPU. Further speedups can be gained by using a lower-precision model. In this phase, we experiment with an FP16 model. Please refer to the blog for instructions on how to convert the ONNX model to FP16. Our experimental model accepts FP16 inputs, so we need to modify the pre-processing and post-processing kernels accordingly. Below is the CUDA kernel used to convert UINT8 to FP16; it uses the CUDA intrinsic `__float2half` for the conversion.</p>
<pre><code class="language-cpp">
#include &lt;cuda_fp16.h&gt;  // __half, __float2half

// PRE PROCESSING
// uint8 (planar RGB) -> fp16, normalized to 0-1
extern "C" __global__ void pre_process_fp16(unsigned char* input, __half* output, int stride)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    // One thread per pixel; each thread converts that pixel's 3 planar values
    if (tid < stride)
    {
        output[0 * stride + tid] = __float2half(input[0 * stride + tid] / 255.0f);
        output[1 * stride + tid] = __float2half(input[1 * stride + tid] / 255.0f);
        output[2 * stride + tid] = __float2half(input[2 * stride + tid] / 255.0f);
    }
}
</code></pre>
<p>Similarly, the post-process kernel uses the `__half2float` intrinsic to convert FP16 to FP32 before the final conversion to UINT8.</p>
<pre><code class="language-cpp">
// POST PROCESSING
// 1. Apply mask
// 2. NCHW -> NHWC, fp16 -> fp32 -> uint8
extern "C" __global__ void post_process_fp16(__half* fgr, __half* pha, unsigned char* output, int stride)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    // Each thread owns one pixel, so no __syncthreads() is needed
    if (tid < stride)
    {
        float alpha = __half2float(pha[tid]);
        for (int c = 0; c != 3; ++c)
        {
            float value = __half2float(fgr[c * stride + tid]) * alpha * 255.0f;
            value = fmaxf(0.0f, fminf(255.0f, value));
            output[tid * 3 + c] = static_cast&lt;unsigned char&gt;(value);
        }
    }
}
</code></pre>
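<p>FP16 has only 10 explicit mantissa bits, so it is fair to ask whether `__float2half(input / 255.0f)` loses information. Python's `struct` module supports IEEE 754 half precision (format character `e`), which allows a quick CPU-side check. This is an illustrative sketch, not part of the blog's code; it divides in double precision rather than the kernel's float32, which should not change the conclusion.</p>
<pre><code class="language-python">
import struct

def to_fp16(x):
    # Round a Python float to IEEE 754 half precision and back ('e' = binary16)
    return struct.unpack("e", struct.pack("e", x))[0]

halves = [to_fp16(k / 255.0) for k in range(256)]

# All 256 normalized uint8 levels stay distinct in FP16 ...
print(len(set(halves)))                                          # 256
# ... and each one maps back to its original byte when rounded
print(all(round(h * 255.0) == k for k, h in enumerate(halves)))  # True
</code></pre>
<p>The worst-case rounding error for values below 1.0 is 2<sup>-12</sup>, which is about 0.06 after scaling by 255, comfortably below half a grey level. Note that the post-process kernels truncate rather than round, so a value that lands slightly below an integer can come out one level lower.</p>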
<p>Apart from switching to the FP16 engine, the only change needed was in the CUDA kernels, which gives the following speedup:</p>
<figure>
<img src="step_3.png" width="300" height="400">
<figcaption>Table.4: Running time of each step.</figcaption>
</figure>
<h4>Step-4: [FP16 Batching] GPU Decode, GPU Pre & Post Processing:</h4>
<p>Building on the FP16 model, we can give the GPU more work per launch by processing frames in batches. To enable batch-wise inference, we modified the CUDA kernels. The input to the pre-processing kernel is a list of pointers to frames, with the length of the list equal to “batch_size”. Both the pre-processing and post-processing kernels iterate over each frame in the batch and perform the required operations.</p>
<pre><code class="language-cpp">
// Batched PRE PROCESSING: planar uint8 frames -> fp16 NCHW batch, normalized to 0-1
extern "C" __global__ void pre_process_fp16_batched(uint64_t* frame_ptrs, __half* output,
                                                    int height, int width, int channel, int batch_size)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    int batch_stride = height * width * channel;  // elements per frame
    int stride = height * width;                  // pixels per plane
    // One thread per pixel; the thread loops over every frame in the batch
    if (tid < stride)
    {
        for (int i = 0; i < batch_size; i++)
        {
            unsigned char* frame = reinterpret_cast&lt;unsigned char*&gt;(frame_ptrs[i]);
            output[(0 * stride + tid) + (i * batch_stride)] = __float2half(frame[0 * stride + tid] / 255.0f);
            output[(1 * stride + tid) + (i * batch_stride)] = __float2half(frame[1 * stride + tid] / 255.0f);
            output[(2 * stride + tid) + (i * batch_stride)] = __float2half(frame[2 * stride + tid] / 255.0f);
        }
    }
}
</code></pre>
<pre><code class="language-cpp">
// Batched POST PROCESSING: apply mask, NCHW -> NHWC, fp16 -> uint8
extern "C" __global__ void post_process_fp16_batched(__half* fgr, __half* pha, unsigned char* output,
                                                     int height, int width, int channel, int batch_size)
{
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    int stride = height * width;                  // pixels per plane
    int batch_stride = height * width * channel;  // elements per frame
    // Every access stays inside the bounds check; each thread owns one pixel of every frame
    if (tid < stride)
    {
        for (int i = 0; i < batch_size; i++)
        {
            // pha has a single plane per frame
            float alpha = __half2float(pha[(i * stride) + tid]);
            for (int c = 0; c != 3; ++c)
            {
                float value = __half2float(fgr[(i * batch_stride) + (c * stride + tid)]) * alpha * 255.0f;
                value = fmaxf(0.0f, fminf(255.0f, value));
                output[(i * batch_stride) + (tid * 3 + c)] = static_cast&lt;unsigned char&gt;(value);
            }
        }
    }
}
</code></pre>
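<p>The only new ingredient in the batched kernels is the offset arithmetic: frame `i` of the "fgr" tensor starts at `i * batch_stride` (H × W × C elements), while frame `i` of the single-channel "pha" tensor starts at `i * stride` (H × W elements). The plain-Python sketch below mirrors the indexing of `post_process_fp16_batched` so it can be checked on tiny inputs; the function name is hypothetical, and values are Python floats rather than FP16.</p>
<pre><code class="language-python">
def batched_post_reference(fgr, pha, height, width, batch_size, channel=3):
    """CPU emulation of post_process_fp16_batched's indexing."""
    stride = height * width                # pixels per plane
    batch_stride = stride * channel        # elements per frame
    output = bytearray(batch_size * batch_stride)
    for tid in range(stride):              # one loop iteration == one CUDA thread
        for i in range(batch_size):        # each thread walks every frame of the batch
            alpha = pha[i * stride + tid]  # pha has a single plane per frame
            for c in range(channel):
                value = fgr[i * batch_stride + c * stride + tid] * alpha * 255.0
                value = max(0.0, min(255.0, value))
                output[i * batch_stride + tid * channel + c] = int(value)
    return bytes(output)

# Batch of two 1x1 frames: frame 0 kept (alpha 1), frame 1 masked out (alpha 0)
fgr = [1.0, 0.5, 0.0, 1.0, 1.0, 1.0]
pha = [1.0, 0.0]
print(list(batched_post_reference(fgr, pha, 1, 1, batch_size=2)))  # [255, 127, 0, 0, 0, 0]
</code></pre>
<p>A useful property to check is that processing a batch gives the same bytes as processing each frame on its own and concatenating the results.</p>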
<p>Next, we need to decode frames in batches. The PyNvVideoCodec library offers the "nvc.ThreadedDecoder()" API, which returns decoded frames according to the specified batch size.</p>
<pre><code class="language-python">
batch_size, frame_channel, frame_height, frame_width = inputs['src']['shape']
decoder = nvc.ThreadedDecoder(enc_file_path=input_video_file,
buffer_size=batch_size,
cuda_context=0,
cuda_stream=0,
use_device_memory=True,
output_color_type=nvc.OutputColorType.RGBP)
</code></pre>
<pre><code class="language-python">
d_frame_size_bytes = frame_channel * frame_height * frame_width * np.dtype(np.uint8).itemsize
d_frame_batch_size_bytes = d_frame_size_bytes * batch_size
d_out_frame = checkCudaErrors(cuda.cuMemAlloc(d_frame_batch_size_bytes))
h_out_frame = np.empty((batch_size, frame_height, frame_width, frame_channel), dtype=np.uint8)
cudaKernelHandle = KernelHelper(cuda_pre_post_kernels_fp16, int(cudaDevice))
cuda_pre_process_kernel = cudaKernelHandle.getFunction(b'pre_process_fp16_batched')
cuda_post_process_kernel = cudaKernelHandle.getFunction(b'post_process_fp16_batched')
</code></pre>
<p>For each frame in the batch, we collect its GPU memory address into a list, copy that list to the GPU, and pass it to the pre-processing kernel.</p>
<pre><code class="language-python">
d_frame_ptr_array = None
while True:
    frames = decoder.get_batch_frames(batch_size)
    if len(frames) == 0:
        break
    # Collect the device address of each decoded frame
    frame_ptr_list = []
    for frame in frames:
        frame_ptr_list.append(int(frame.GetPtrToPlane(0)))
    frame_ptr_array = np.array(frame_ptr_list, dtype=np.uint64)
    if d_frame_ptr_array is None:
        # Allocate the device-side pointer array once, sized for a full batch
        d_frame_ptr_array = checkCudaErrors(cuda.cuMemAlloc(batch_size * np.dtype(np.uint64).itemsize))
    # Copy the frame pointer list to the device
    checkCudaErrors(cuda.cuMemcpyHtoD(d_frame_ptr_array, frame_ptr_array, frame_ptr_array.nbytes))
    # The last batch may contain fewer than batch_size frames
    num_frames = len(frames)
    kernel_args_pre_process = ((d_frame_ptr_array, inputs['src']['d_mem'], np.int32(frame_height), np.int32(frame_width), np.int32(frame_channel), np.int32(num_frames)),
                               (None, None, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int))
</code></pre>
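<p>On the device side, `frame_ptrs` is simply a contiguous array of 64-bit addresses, which is why the host builds it with `dtype=np.uint64`. The standard-library sketch below (illustrative; the helper names and addresses are made up) produces the same byte layout and shows which bytes `frame_ptrs[i]` picks up inside the kernel.</p>
<pre><code class="language-python">
import struct

def pack_frame_pointers(ptrs):
    # Native byte order, standard 8-byte unsigned integers laid out back to back,
    # i.e. the same bytes as np.array(ptrs, dtype=np.uint64).tobytes()
    return struct.pack("=%dQ" % len(ptrs), *ptrs)

def read_pointer(buf, i):
    # What frame_ptrs[i] reads inside the kernel: the 8 bytes at offset i * 8
    return struct.unpack_from("=Q", buf, i * 8)[0]

# Three made-up device addresses
buf = pack_frame_pointers([0x7F0000000000, 0x7F0000600000, 0x7F0000C00000])
print(len(buf))                   # 24
print(hex(read_pointer(buf, 1)))  # 0x7f0000600000
</code></pre>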
<p>The following table shows the performance gain achieved with all optimizations combined with batch processing. The total time can be divided by the batch size to estimate the time for each frame.</p>
<figure>
<img src="step_4.png" width="300" height="400">
<figcaption>Table.5: Running time of each step.</figcaption>
</figure>
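<p>As a worked example of dividing a batched timing by the batch size, the sketch below converts the time for one batch into per-frame latency and throughput. The numbers are illustrative only, not measurements from the tables above, and the helper name is hypothetical.</p>
<pre><code class="language-python">
def per_frame_stats(batch_time_ms, batch_size):
    # One batch of `batch_size` frames takes `batch_time_ms` in total
    ms_per_frame = batch_time_ms / batch_size
    fps = 1000.0 / ms_per_frame
    return ms_per_frame, fps

# Illustrative numbers only: a batch of 8 frames processed in 40 ms
print(per_frame_stats(40.0, 8))  # (5.0, 200.0)
</code></pre>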
<h4>References</h4>
<ul>
<li>PyNvVideoCodec: <a href="https://catalog.ngc.nvidia.com/orgs/nvidia/resources/pynvvideocodec">https://catalog.ngc.nvidia.com/orgs/nvidia/resources/pynvvideocodec</a></li>
<li>CUDA-Python: <a href="https://github.com/NVIDIA/cuda-python">https://github.com/NVIDIA/cuda-python</a></li>
</ul>
</div>
<footer>
© 2025 Mrutunjayya | All rights reserved.
</footer>
<script src="https://unpkg.com/prismjs/prism.js"></script>
<script src="https://unpkg.com/prismjs/components/prism-python.min.js"></script>
<script src="https://unpkg.com/prismjs/components/prism-c.min.js"></script>
<script src="https://unpkg.com/prismjs/components/prism-cpp.min.js"></script>
<script src="https://unpkg.com/prismjs/components/prism-bash.min.js"></script>
<script>
document.addEventListener("DOMContentLoaded", () => {
document.querySelectorAll("pre > code").forEach(codeBlock => {
const pre = codeBlock.parentNode;
const button = document.createElement("button");
button.className = "copy-btn";
button.textContent = "Copy";
pre.appendChild(button);
button.addEventListener("click", () => {
navigator.clipboard.writeText(codeBlock.innerText).then(() => {
button.textContent = "Copied!";
setTimeout(() => (button.textContent = "Copy"), 1500);
});
});
});
});
</script>
</body>
</html>