Skip to content

Commit 47d1da4

Browse files
authored
Add hybrid CPU-GPU DeepStream pipeline (#173)
1 parent 700c835 commit 47d1da4

1 file changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import os
2+
from datetime import datetime
3+
from typing import Any, Dict, List
4+
import gi
5+
import numpy as np
6+
import cv2
7+
import pyds
8+
9+
gi.require_version("Gst", "1.0")
10+
from gi.repository import GLib, Gst
11+
12+
from src.frame_comparison.frame_change_detector import FrameChangeDetector
13+
from src.model_conversion.onnx_to_trt import build_engine_if_missing
14+
15+
from src.deepstream.helpers.load_class_labels import load_class_labels
16+
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
17+
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction, SoftmaxTopKClassifier
18+
from src.deepstream.helpers.plant_msg_meta_builder import NvdsPlantEventBuilder, PlantEvent
19+
from src.deepstream.helpers.pipeline_runner import run_pipeline
20+
21+
from src.deepstream.probes.background_removal.cpu.background_removal_probe import remove_background_probe
22+
from src.deepstream.probes.db_message_meta_probe import DbMessageMetaProbe
23+
from src.deepstream.probes.frame_comparison.gpu.frame_skipping_probe import frame_skipping_probe, GPUFrameChangeDetector
24+
25+
# Configuration
26+
RTSP_PORT = os.environ.get("RTSP_PORT", "8554")
27+
RTSP_URL = f"rtsp://127.0.0.1:{RTSP_PORT}/test"
28+
29+
CONFIG_FILE: str = "/workspace/configs/resnet18.txt"
30+
MSGCONV_CONFIG: str = "/workspace/configs/nvmsgbroker_msgconv_config.txt"
31+
MQTT_CONN_STR = "172.17.0.1;1883;agstream-client"
32+
MQTT_TOPIC = "deepstream/predictions"
33+
34+
CLASS_LABELS = load_class_labels()
35+
36+
stats: Dict[str, int] = {"total": 0, "skipped": 0, "processed": 0}
37+
38+
def gpu_frame_skip_probe(pad: Gst.Pad, info: Gst.PadProbeInfo, detector: GPUFrameChangeDetector) -> Gst.PadProbeReturn:
39+
buffer_ptr: int = hash(info.get_buffer())
40+
batch_id: int = 0
41+
should_process: bool = frame_skipping_probe(buffer_ptr, batch_id, detector)
42+
if should_process:
43+
stats["processed"] += 1
44+
print(f"✅ PROCESSING frame {stats['total']}")
45+
return Gst.PadProbeReturn.OK
46+
else:
47+
stats["skipped"] += 1
48+
print(f"⏭️ SKIPPING frame {stats['total']}")
49+
return Gst.PadProbeReturn.DROP
50+
51+
def build_pipeline() -> Gst.Pipeline:
52+
"""Build DeepStream pipeline with background removal, frame skipping, and message broker."""
53+
pipeline = Gst.Pipeline.new("final-cpu-pipeline")
54+
55+
# Elements
56+
rtspsrc = Gst.ElementFactory.make("rtspsrc", "source")
57+
depay = Gst.ElementFactory.make("rtph264depay", "depay")
58+
parse = Gst.ElementFactory.make("h264parse", "parse")
59+
decode = Gst.ElementFactory.make("decodebin", "decode")
60+
convert = Gst.ElementFactory.make("videoconvert", "convert")
61+
nvvideoconvert = Gst.ElementFactory.make("nvvideoconvert", "nvvideoconvert")
62+
capsfilter = Gst.ElementFactory.make("capsfilter", "capsfilter")
63+
streammux = Gst.ElementFactory.make("nvstreammux", "streammux")
64+
nvinfer = Gst.ElementFactory.make("nvinfer", "nvinfer")
65+
nvmsgconv = Gst.ElementFactory.make("nvmsgconv", "nvmsgconv")
66+
nvmsgbroker = Gst.ElementFactory.make("nvmsgbroker", "nvmsgbroker")
67+
68+
for e in [
69+
rtspsrc,
70+
depay,
71+
parse,
72+
decode,
73+
convert,
74+
nvvideoconvert,
75+
capsfilter,
76+
streammux,
77+
nvinfer,
78+
nvmsgconv,
79+
nvmsgbroker,
80+
]:
81+
assert e is not None, f"Failed to create element {e}"
82+
pipeline.add(e)
83+
84+
# Configure elements
85+
rtspsrc.set_property("location", RTSP_URL)
86+
rtspsrc.set_property("latency", 200)
87+
streammux.set_property("batch-size", 1)
88+
streammux.set_property("width", 256)
89+
streammux.set_property("height", 256)
90+
nvinfer.set_property("config-file-path", CONFIG_FILE)
91+
caps = Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA")
92+
capsfilter.set_property("caps", caps)
93+
nvmsgconv.set_property("config", MSGCONV_CONFIG)
94+
nvmsgconv.set_property("payload-type", 0)
95+
nvmsgbroker.set_property("proto-lib", "/opt/nvidia/deepstream/deepstream-6.4/lib/libnvds_mqtt_proto.so")
96+
nvmsgbroker.set_property("conn-str", MQTT_CONN_STR)
97+
nvmsgbroker.set_property("topic", MQTT_TOPIC)
98+
nvmsgbroker.set_property("sync", False)
99+
100+
# Dynamic pad linking
101+
def on_pad_added_rtspsrc(src: Any, pad: Any) -> None:
102+
sinkpad = depay.get_static_pad("sink")
103+
if not sinkpad.is_linked():
104+
pad.link(sinkpad)
105+
106+
rtspsrc.connect("pad-added", on_pad_added_rtspsrc)
107+
108+
def on_pad_added_decode(src: Any, pad: Any) -> None:
109+
sinkpad = convert.get_static_pad("sink")
110+
if not sinkpad.is_linked():
111+
pad.link(sinkpad)
112+
113+
decode.connect("pad-added", on_pad_added_decode)
114+
115+
# Link capsfilter → streammux
116+
depay.link(parse)
117+
parse.link(decode)
118+
convert.link(nvvideoconvert)
119+
nvvideoconvert.link(capsfilter)
120+
srcpad = capsfilter.get_static_pad("src")
121+
sinkpad = streammux.get_request_pad("sink_0")
122+
srcpad.link(sinkpad)
123+
124+
streammux.link(nvinfer)
125+
nvinfer.link(nvmsgconv)
126+
nvmsgconv.link(nvmsgbroker)
127+
128+
detector = GPUFrameChangeDetector()
129+
130+
streammux_src_pad: Gst.Pad = streammux.get_static_pad("src")
131+
streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, gpu_frame_skip_probe, detector)
132+
133+
streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, remove_background_probe)
134+
135+
tensor_extractor = TensorExtractor()
136+
classifier = SoftmaxTopKClassifier(CLASS_LABELS)
137+
plant_event_builder = NvdsPlantEventBuilder()
138+
db_message_meta_probe = DbMessageMetaProbe(tensor_extractor, classifier, plant_event_builder)
139+
140+
nvinfer_src_pad = nvinfer.get_static_pad("src")
141+
nvinfer_src_pad.add_probe(Gst.PadProbeType.BUFFER, db_message_meta_probe.pad_probe)
142+
143+
return pipeline
144+
145+
if __name__ == "__main__":
146+
Gst.init(None)
147+
build_engine_if_missing(CONFIG_FILE)
148+
pipeline = build_pipeline()
149+
run_pipeline(pipeline)

0 commit comments

Comments
 (0)