Skip to content

Commit 700c835

Browse files
Sarah5567r83575
andauthored
Final gpu pipeline (#171)
* basic implementation * Finalize pipeline and verify runtime execution Co-authored-by: Sarah Gershuni <sarah556726@gmail.com> Co-authored-by: ruti_cohen <r0583283575@gmail.com> --------- Co-authored-by: r83575 <r0583283575@gmail.com>
1 parent 05fa9db commit 700c835

3 files changed

Lines changed: 188 additions & 2 deletions

File tree

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import os
2+
from datetime import datetime
3+
from typing import Any, Dict, List
4+
import gi
5+
import numpy as np
6+
import cv2
7+
import pyds
8+
9+
gi.require_version("Gst", "1.0")
10+
from gi.repository import GLib, Gst
11+
12+
from src.frame_comparison.frame_change_detector import FrameChangeDetector
13+
from src.model_conversion.onnx_to_trt import build_engine_if_missing
14+
15+
from src.deepstream.helpers.load_class_labels import load_class_labels
16+
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
17+
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction, SoftmaxTopKClassifier
18+
from src.deepstream.helpers.plant_msg_meta_builder import NvdsPlantEventBuilder, PlantEvent
19+
from src.deepstream.helpers.pipeline_runner import run_pipeline
20+
21+
from src.deepstream.probes.background_removal.gpu import gpu_background_probe
22+
from src.deepstream.probes.db_message_meta_probe import DbMessageMetaProbe
23+
from src.deepstream.probes.frame_comparison.gpu.frame_skipping_probe import frame_skipping_probe, GPUFrameChangeDetector
24+
25+
# Configuration
26+
RTSP_PORT = os.environ.get("RTSP_PORT", "8554")
27+
RTSP_URL = f"rtsp://127.0.0.1:{RTSP_PORT}/test"
28+
29+
CONFIG_FILE: str = "/workspace/configs/resnet18.txt"
30+
MSGCONV_CONFIG: str = "/workspace/configs/nvmsgbroker_msgconv_config.txt"
31+
MQTT_CONN_STR = "172.17.0.1;1883;agstream-client"
32+
MQTT_TOPIC = "deepstream/predictions"
33+
34+
CLASS_LABELS = load_class_labels()
35+
36+
stats: Dict[str, int] = {"total": 0, "skipped": 0, "processed": 0}
37+
38+
def gpu_frame_skip_probe(pad: Gst.Pad, info: Gst.PadProbeInfo, detector: GPUFrameChangeDetector) -> Gst.PadProbeReturn:
39+
buffer_ptr: int = hash(info.get_buffer())
40+
batch_id: int = 0
41+
should_process: bool = frame_skipping_probe(buffer_ptr, batch_id, detector)
42+
if should_process:
43+
stats["processed"] += 1
44+
print(f"✅ PROCESSING frame {stats['total']}")
45+
return Gst.PadProbeReturn.OK
46+
else:
47+
stats["skipped"] += 1
48+
print(f"⏭️ SKIPPING frame {stats['total']}")
49+
return Gst.PadProbeReturn.DROP
50+
51+
def _post_error_from_pad(pad: Gst.Pad, text: str, debug: str = "") -> None:
52+
"""Post an error message to the GStreamer bus from a pad."""
53+
elem = pad.get_parent_element()
54+
if not elem:
55+
return
56+
gerr = GLib.Error(text)
57+
msg = Gst.Message.new_error(elem, gerr, debug)
58+
elem.post_message(msg)
59+
60+
def background_removal_gpu_probe(pad: Gst.Pad, info: Gst.PadProbeInfo) -> Gst.PadProbeReturn:
61+
"""GPU pad probe that applies background removal to each frame."""
62+
gst_buffer = info.get_buffer()
63+
if not gst_buffer:
64+
_post_error_from_pad(pad, "missing GstBuffer in probe", "background_removal_gpu_probe: buf is None")
65+
if STRICT_ERRORS:
66+
raise RuntimeError("background_removal_gpu_probe: missing GstBuffer")
67+
return Gst.PadProbeReturn.DROP
68+
69+
try:
70+
buffer_ptr: int = hash(gst_buffer)
71+
batch_id: int = 0
72+
73+
success: bool = gpu_background_probe.remove_background_gpu_probe_pipeline(buffer_ptr, batch_id)
74+
if not success:
75+
_post_error_from_pad(pad, "GPU background removal failed")
76+
if STRICT_ERRORS:
77+
raise RuntimeError("GPU background removal failed")
78+
return Gst.PadProbeReturn.DROP
79+
80+
except Exception as e:
81+
_post_error_from_pad(pad, "GPU processing error", str(e))
82+
if STRICT_ERRORS:
83+
raise
84+
return Gst.PadProbeReturn.DROP
85+
86+
return Gst.PadProbeReturn.OK
87+
88+
def build_pipeline() -> Gst.Pipeline:
89+
"""Build DeepStream pipeline with background removal, frame skipping, and message broker."""
90+
pipeline = Gst.Pipeline.new("final-cpu-pipeline")
91+
92+
# Elements
93+
rtspsrc = Gst.ElementFactory.make("rtspsrc", "source")
94+
depay = Gst.ElementFactory.make("rtph264depay", "depay")
95+
parse = Gst.ElementFactory.make("h264parse", "parse")
96+
decode = Gst.ElementFactory.make("decodebin", "decode")
97+
convert = Gst.ElementFactory.make("videoconvert", "convert")
98+
nvvideoconvert = Gst.ElementFactory.make("nvvideoconvert", "nvvideoconvert")
99+
capsfilter = Gst.ElementFactory.make("capsfilter", "capsfilter")
100+
streammux = Gst.ElementFactory.make("nvstreammux", "streammux")
101+
nvinfer = Gst.ElementFactory.make("nvinfer", "nvinfer")
102+
nvmsgconv = Gst.ElementFactory.make("nvmsgconv", "nvmsgconv")
103+
nvmsgbroker = Gst.ElementFactory.make("nvmsgbroker", "nvmsgbroker")
104+
105+
for e in [
106+
rtspsrc,
107+
depay,
108+
parse,
109+
decode,
110+
convert,
111+
nvvideoconvert,
112+
capsfilter,
113+
streammux,
114+
nvinfer,
115+
nvmsgconv,
116+
nvmsgbroker,
117+
]:
118+
assert e is not None, f"Failed to create element {e}"
119+
pipeline.add(e)
120+
121+
# Configure elements
122+
rtspsrc.set_property("location", RTSP_URL)
123+
rtspsrc.set_property("latency", 200)
124+
streammux.set_property("batch-size", 1)
125+
streammux.set_property("width", 256)
126+
streammux.set_property("height", 256)
127+
nvinfer.set_property("config-file-path", CONFIG_FILE)
128+
caps = Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA")
129+
capsfilter.set_property("caps", caps)
130+
nvmsgconv.set_property("config", MSGCONV_CONFIG)
131+
nvmsgconv.set_property("payload-type", 0)
132+
nvmsgbroker.set_property("proto-lib", "/opt/nvidia/deepstream/deepstream-6.4/lib/libnvds_mqtt_proto.so")
133+
nvmsgbroker.set_property("conn-str", MQTT_CONN_STR)
134+
nvmsgbroker.set_property("topic", MQTT_TOPIC)
135+
nvmsgbroker.set_property("sync", False)
136+
137+
# Dynamic pad linking
138+
def on_pad_added_rtspsrc(src: Any, pad: Any) -> None:
139+
sinkpad = depay.get_static_pad("sink")
140+
if not sinkpad.is_linked():
141+
pad.link(sinkpad)
142+
143+
rtspsrc.connect("pad-added", on_pad_added_rtspsrc)
144+
145+
def on_pad_added_decode(src: Any, pad: Any) -> None:
146+
sinkpad = convert.get_static_pad("sink")
147+
if not sinkpad.is_linked():
148+
pad.link(sinkpad)
149+
150+
decode.connect("pad-added", on_pad_added_decode)
151+
152+
# Link capsfilter → streammux
153+
depay.link(parse)
154+
parse.link(decode)
155+
convert.link(nvvideoconvert)
156+
nvvideoconvert.link(capsfilter)
157+
srcpad = capsfilter.get_static_pad("src")
158+
sinkpad = streammux.get_request_pad("sink_0")
159+
srcpad.link(sinkpad)
160+
161+
streammux.link(nvinfer)
162+
nvinfer.link(nvmsgconv)
163+
nvmsgconv.link(nvmsgbroker)
164+
165+
detector = GPUFrameChangeDetector()
166+
167+
streammux_src_pad: Gst.Pad = streammux.get_static_pad("src")
168+
streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, gpu_frame_skip_probe, detector)
169+
170+
streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, background_removal_gpu_probe)
171+
172+
tensor_extractor = TensorExtractor()
173+
classifier = SoftmaxTopKClassifier(CLASS_LABELS)
174+
plant_event_builder = NvdsPlantEventBuilder()
175+
db_message_meta_probe = DbMessageMetaProbe(tensor_extractor, classifier, plant_event_builder)
176+
177+
nvinfer_src_pad = nvinfer.get_static_pad("src")
178+
nvinfer_src_pad.add_probe(Gst.PadProbeType.BUFFER, db_message_meta_probe.pad_probe)
179+
180+
return pipeline
181+
182+
if __name__ == "__main__":
183+
Gst.init(None)
184+
build_engine_if_missing(CONFIG_FILE)
185+
pipeline = build_pipeline()
186+
run_pipeline(pipeline)

src/deepstream/probes/frame_comparison/gpu/frame_change_detector.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include <cmath>
66
#include <vector>
77

8-
const double MSE_THRESH = 20.0;
9-
const double SSIM_THRESH = 0.998;
8+
const double MSE_THRESH = 99.0;
9+
const double SSIM_THRESH = 0.996;
1010
const double FLOW_THRESH = 0.5;
1111
const double OPTICAL_FLOW_ACTIVE_THRESH = 0.5;
1212

Binary file not shown.

0 commit comments

Comments
 (0)