Skip to content

Commit 978c8eb

Browse files
authored
Final cpu pipeline (#169)
* basic implementation of deepstream_pipeline_cpu.py
* full implementation for square videos
* update output resolution to 256x256 to match model training size
* remove unused probe
* extract shared helper, remove comments, and apply linting
* move load_class_labels to a shared helper and update all files that used it
1 parent 8a5499a commit 978c8eb

10 files changed

Lines changed: 322 additions & 37 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import List

# Path to the newline-delimited class-label file baked into the container.
PLANT_LABELS = "/workspace/configs/crop_and_weed_83_classes.txt"

# Number of classes produced by the crop-and-weed model; used for the fallback.
_NUM_CLASSES = 83


def load_class_labels() -> List[str]:
    """Load class labels from PLANT_LABELS, one label per line.

    Returns:
        The stripped label lines, or dummy labels ("class_0" .. "class_82")
        if the file cannot be read.
    """
    try:
        with open(PLANT_LABELS, "r") as f:
            return [line.strip() for line in f]
    except OSError:
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; only I/O failures should trigger the fallback.
        # Dummy labels keep the pipeline runnable when the config volume
        # is not mounted (e.g. in tests).
        return [f"class_{i}" for i in range(_NUM_CLASSES)]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import sys
2+
from pydantic import BaseModel
3+
4+
import gi
5+
gi.require_version("Gst", "1.0")
6+
7+
from gi.repository import Gst
8+
import pyds
9+
10+
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction
11+
12+
class PlantEvent(BaseModel):
    """A single classification result for one plant in one frame.

    Carried from the inference probe to NvdsPlantEventBuilder, which turns
    it into NvDs event message metadata for the message broker.
    """

    frame_id: int  # index of the source frame this prediction belongs to
    plant_id: str  # plant identifier; becomes msg_meta.objectId downstream
    prediction: ClassificationPrediction  # classifier output for this plant
16+
17+
18+
class NvdsPlantEventBuilder:
    """Builds NvDsEventMsgMeta for a PlantEvent and attaches it to frame meta.

    The attached metadata is consumed downstream by nvmsgconv/nvmsgbroker to
    publish the event. For the reference implementation, please check
    https://github.com/NVIDIA-AI-IOT/deepstream_python_apps/blob/9b27f02ffea46a3ded2ad26b3eea27ef3e2dfded/apps/deepstream-test4/deepstream_test_4.py
    """

    def build(self, batch_meta, frame_meta, event: PlantEvent) -> None:
        """Attach an NVDS_EVENT_MSG_META user meta describing `event` to `frame_meta`.

        Args:
            batch_meta: pyds batch metadata the user meta pool is acquired from.
            frame_meta: pyds frame metadata the event meta is attached to.
            event: the plant classification result to publish.
        """
        user_event_meta = pyds.nvds_acquire_user_meta_from_pool(batch_meta)
        msg_meta = pyds.NvDsEventMsgMeta.cast(pyds.alloc_nvds_event_msg_meta(user_event_meta))

        # Generic IDs / bookkeeping.
        # NOTE(review): trackingId reuses the frame id — there is no real
        # object tracker in this pipeline; confirm downstream consumers
        # don't expect a per-object track id.
        msg_meta.frameId = event.frame_id
        msg_meta.objectId = event.plant_id
        msg_meta.trackingId = event.frame_id
        msg_meta.objClassId = event.prediction.class_id
        msg_meta.confidence = float(event.prediction.confidence)

        # Link to sensor0 / place0 / analytics0 from the msgconv config.
        msg_meta.sensorId = 0
        msg_meta.placeId = 0
        msg_meta.moduleId = 0

        msg_meta.type = pyds.NvDsEventType.NVDS_EVENT_CUSTOM
        # PERSON object type is deliberately repurposed to carry plant info,
        # since pyds only allocates extension payloads for known object types.
        msg_meta.objType = pyds.NvDsObjectType.NVDS_OBJECT_TYPE_PERSON

        # RFC 3339 timestamp written into a 32-byte buffer owned by msg_meta.
        msg_meta.ts = pyds.alloc_buffer(32)
        pyds.generate_ts_rfc3339(msg_meta.ts, 32)

        # Placeholder bbox (-1/-1/0/0): no per-plant ROI is available yet;
        # plug in real coordinates once ROIs exist.
        msg_meta.bbox.top = -1
        msg_meta.bbox.left = -1
        msg_meta.bbox.width = 0
        msg_meta.bbox.height = 0

        # Attach a person object repurposed to carry plant info; the apparel
        # field smuggles the predicted class name through the schema.
        obj = pyds.alloc_nvds_person_object()
        person = pyds.NvDsPersonObject.cast(obj)

        person.age = 0
        person.gender = "plant"
        person.hair = "none"
        person.cap = "none"
        person.apparel = event.prediction.class_name

        msg_meta.extMsg = obj
        # NOTE(review): getsizeof of the *class*, not the instance — this
        # matches the deepstream-test4 sample, but confirm the C side only
        # needs a non-zero size here.
        msg_meta.extMsgSize = sys.getsizeof(pyds.NvDsPersonObject)

        user_event_meta.user_meta_data = msg_meta
        user_event_meta.base_meta.meta_type = pyds.NvDsMetaType.NVDS_EVENT_MSG_META
        pyds.nvds_add_user_meta_to_frame(frame_meta, user_event_meta)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
import cv2
3+
4+
# Side length of the elliptical structuring element used to close mask holes.
MORPH_KERNEL = 34


def remove_background(frame_bgr: np.ndarray) -> np.ndarray:
    """Black out non-vegetation pixels of a BGR frame.

    Segments vegetation with the Excess Green index (ExG = 2G - R - B),
    binarises it via Otsu thresholding, closes small holes in the mask with
    an elliptical kernel, and returns a new frame where every background
    pixel is zero.
    """
    channels = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    red = channels[..., 0]
    green = channels[..., 1]
    blue = channels[..., 2]

    # Excess Green emphasises vegetation; rescale to 0-255 for thresholding.
    excess_green = 2.0 * green - red - blue
    scaled = cv2.normalize(excess_green, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

    # Otsu picks the threshold automatically (the 0 here is ignored).
    _, vegetation_mask = cv2.threshold(scaled, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (MORPH_KERNEL, MORPH_KERNEL))
    vegetation_mask = cv2.morphologyEx(vegetation_mask, cv2.MORPH_CLOSE, ellipse)

    result = np.zeros_like(frame_bgr)
    result[vegetation_mask > 0] = frame_bgr[vegetation_mask > 0]
    return result
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import os
2+
from datetime import datetime
3+
from typing import Any, Dict, List
4+
import gi
5+
import numpy as np
6+
import cv2
7+
import pyds
8+
9+
gi.require_version("Gst", "1.0")
10+
from gi.repository import GLib, Gst
11+
12+
from src.frame_comparison.frame_change_detector import FrameChangeDetector
13+
from src.model_conversion.onnx_to_trt import build_engine_if_missing
14+
15+
from src.deepstream.helpers.load_class_labels import load_class_labels
16+
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
17+
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction, SoftmaxTopKClassifier
18+
from src.deepstream.helpers.plant_msg_meta_builder import NvdsPlantEventBuilder, PlantEvent
19+
from src.deepstream.helpers.pipeline_runner import run_pipeline
20+
21+
from src.deepstream.probes.background_removal.cpu.background_removal_probe import remove_background_probe
22+
from src.deepstream.probes.db_message_meta_probe import DbMessageMetaProbe
23+
from src.deepstream.probes.frame_comparison.cpu.frame_skipping_probe import frame_skip_probe
24+
25+
# Configuration
# RTSP test stream served locally; the port can be overridden via env var.
RTSP_PORT = os.environ.get("RTSP_PORT", "8554")
RTSP_URL = f"rtsp://127.0.0.1:{RTSP_PORT}/test"

# nvinfer model config and nvmsgconv schema config baked into the container.
CONFIG_FILE: str = "/workspace/configs/resnet18.txt"
MSGCONV_CONFIG: str = "/workspace/configs/nvmsgbroker_msgconv_config.txt"
# "host;port;client-id" — MQTT broker on the Docker host bridge address.
MQTT_CONN_STR = "172.17.0.1;1883;agstream-client"
MQTT_TOPIC = "deepstream/predictions"

# Human-readable class names for the classifier (falls back to dummy labels
# when the label file is missing).
CLASS_LABELS = load_class_labels()
35+
36+
def build_pipeline() -> Gst.Pipeline:
    """Build the CPU DeepStream pipeline with background removal, frame
    skipping, and an MQTT message broker sink.

    Topology: rtspsrc -> rtph264depay -> h264parse -> decodebin ->
    videoconvert -> nvvideoconvert -> capsfilter(RGBA, NVMM) -> nvstreammux
    -> nvinfer -> nvmsgconv -> nvmsgbroker.

    Returns:
        The fully linked, not-yet-started Gst.Pipeline.
    """
    pipeline = Gst.Pipeline.new("final-cpu-pipeline")

    def _make(factory: str, name: str) -> Any:
        # The original assert printed the element variable itself, which is
        # always None on failure — report the factory/name instead so a
        # missing plugin is identifiable from the message.
        element = Gst.ElementFactory.make(factory, name)
        assert element is not None, f"Failed to create element '{name}' (factory '{factory}')"
        pipeline.add(element)
        return element

    # Elements (created, checked, and added to the pipeline in one step)
    rtspsrc = _make("rtspsrc", "source")
    depay = _make("rtph264depay", "depay")
    parse = _make("h264parse", "parse")
    decode = _make("decodebin", "decode")
    convert = _make("videoconvert", "convert")
    nvvideoconvert = _make("nvvideoconvert", "nvvideoconvert")
    capsfilter = _make("capsfilter", "capsfilter")
    streammux = _make("nvstreammux", "streammux")
    nvinfer = _make("nvinfer", "nvinfer")
    nvmsgconv = _make("nvmsgconv", "nvmsgconv")
    nvmsgbroker = _make("nvmsgbroker", "nvmsgbroker")

    # Configure elements
    rtspsrc.set_property("location", RTSP_URL)
    rtspsrc.set_property("latency", 200)
    # 256x256 matches the model's training resolution.
    streammux.set_property("batch-size", 1)
    streammux.set_property("width", 256)
    streammux.set_property("height", 256)
    nvinfer.set_property("config-file-path", CONFIG_FILE)
    caps = Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA")
    capsfilter.set_property("caps", caps)
    nvmsgconv.set_property("config", MSGCONV_CONFIG)
    nvmsgconv.set_property("payload-type", 0)
    nvmsgbroker.set_property("proto-lib", "/opt/nvidia/deepstream/deepstream-6.4/lib/libnvds_mqtt_proto.so")
    nvmsgbroker.set_property("conn-str", MQTT_CONN_STR)
    nvmsgbroker.set_property("topic", MQTT_TOPIC)
    nvmsgbroker.set_property("sync", False)

    # Dynamic pad linking: rtspsrc and decodebin expose pads only once the
    # stream is negotiated, so link them from pad-added callbacks.
    def on_pad_added_rtspsrc(src: Any, pad: Any) -> None:
        sinkpad = depay.get_static_pad("sink")
        if not sinkpad.is_linked():
            pad.link(sinkpad)

    rtspsrc.connect("pad-added", on_pad_added_rtspsrc)

    def on_pad_added_decode(src: Any, pad: Any) -> None:
        sinkpad = convert.get_static_pad("sink")
        if not sinkpad.is_linked():
            pad.link(sinkpad)

    decode.connect("pad-added", on_pad_added_decode)

    # Static links, plus capsfilter -> streammux via a request pad.
    depay.link(parse)
    parse.link(decode)
    convert.link(nvvideoconvert)
    nvvideoconvert.link(capsfilter)
    srcpad = capsfilter.get_static_pad("src")
    sinkpad = streammux.get_request_pad("sink_0")
    srcpad.link(sinkpad)

    streammux.link(nvinfer)
    nvinfer.link(nvmsgconv)
    nvmsgconv.link(nvmsgbroker)

    # Probes: skip unchanged frames, then remove the background, both on the
    # muxed stream before inference; publish predictions after inference.
    frame_change_detector = FrameChangeDetector()

    streammux_src_pad = streammux.get_static_pad("src")
    streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, frame_skip_probe, frame_change_detector)
    streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, remove_background_probe)

    tensor_extractor = TensorExtractor()
    classifier = SoftmaxTopKClassifier(CLASS_LABELS)
    plant_event_builder = NvdsPlantEventBuilder()
    db_message_meta_probe = DbMessageMetaProbe(tensor_extractor, classifier, plant_event_builder)

    nvinfer_src_pad = nvinfer.get_static_pad("src")
    nvinfer_src_pad.add_probe(Gst.PadProbeType.BUFFER, db_message_meta_probe.pad_probe)

    return pipeline
129+
130+
if __name__ == "__main__":
    # GStreamer must be initialised before any element is created.
    Gst.init(None)
    # Build the TensorRT engine referenced by the nvinfer config if it does
    # not exist yet (first-run ONNX -> TRT conversion).
    build_engine_if_missing(CONFIG_FILE)
    pipeline = build_pipeline()
    run_pipeline(pipeline)

src/deepstream/pipelines/nvmsgbroker_pipeline.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
gi.require_version("Gst", "1.0")
1111

1212

13+
from src.deepstream.helpers.load_class_labels import load_class_labels
1314
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
1415
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction, SoftmaxTopKClassifier
1516
from src.model_conversion.onnx_to_trt import build_engine_if_missing
@@ -20,17 +21,6 @@
2021

2122
CONFIG_FILE = "/workspace/configs/mobilenet.txt"
2223
MSGCONV_CONFIG = "/workspace/configs/nvmsgbroker_msgconv_config.txt"
23-
CLASS_LABELS_FILE = "/workspace/configs/crop_and_weed_83_classes.txt"
24-
25-
26-
def load_class_labels() -> List[str]:
27-
try:
28-
with open(CLASS_LABELS_FILE) as f:
29-
return [line.strip() for line in f]
30-
except Exception:
31-
# Fallback to dummy labels if file is missing
32-
return [f"class_{i}" for i in range(83)]
33-
3424

3525
CLASS_LABELS = load_class_labels()
3626

src/deepstream/pipelines/pipeline_onnx_real_input.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,13 @@
88
gi.require_version("Gst", "1.0")
99
from gi.repository import GLib, Gst
1010

11+
from src.deepstream.helpers.load_class_labels import load_class_labels
1112
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
1213
from src.deepstream.helpers.softmax_topk_classifier import (
1314
SoftmaxTopKClassifier,
1415
)
1516
from src.deepstream.helpers.pipeline_runner import run_pipeline
1617

17-
CLASS_LABELS_FILE = "/workspace/configs/crop_and_weed_83_classes.txt"
18-
19-
20-
def load_class_labels() -> List[str]:
21-
try:
22-
with open(CLASS_LABELS_FILE) as f:
23-
return [line.strip() for line in f]
24-
except Exception:
25-
return [f"class_{i}" for i in range(83)]
26-
2718

2819
def list_image_files(folder: str) -> List[str]:
2920
exts = (".jpg", ".jpeg", ".png")

src/deepstream/pipelines/pipeline_onnx_test.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pyds
99
from gi.repository import Gst # noqa: E402
1010

11+
from src.deepstream.helpers.load_class_labels import load_class_labels
1112
from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor
1213
from src.deepstream.helpers.softmax_topk_classifier import (
1314
ClassificationPrediction,
@@ -16,18 +17,6 @@
1617
from src.model_conversion.onnx_to_trt import build_engine_if_missing
1718
from src.deepstream.helpers.pipeline_runner import run_pipeline
1819

19-
CLASS_LABELS_FILE = "/workspace/configs/crop_and_weed_83_classes.txt"
20-
21-
22-
def load_class_labels() -> List[str]:
23-
try:
24-
with open(CLASS_LABELS_FILE) as f:
25-
return [line.strip() for line in f]
26-
except Exception:
27-
# Fallback to dummy labels if file is missing
28-
return [f"class_{i}" for i in range(83)]
29-
30-
3120
CLASS_LABELS = load_class_labels()
3221

3322
# Reuse the tensor extractor + softmax top-k classifier from your RTSP pipeline
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import cv2
2+
import numpy as np
3+
import pyds
4+
from gi.repository import Gst
5+
6+
from src.deepstream.helpers.remove_background import remove_background
7+
8+
9+
def remove_background_probe(pad: Gst.Pad, info: Gst.PadProbeInfo) -> Gst.PadProbeReturn:
    """Pad probe that blacks out non-vegetation pixels of each buffer in place.

    Maps the buffer's NvBufSurface as an RGBA numpy view, runs the CPU
    background-removal helper on a BGR copy, and writes the masked pixels
    back into the mapped surface so downstream elements see the result.

    Returns:
        Gst.PadProbeReturn.OK so the buffer always continues downstream.

    Raises:
        RuntimeError: if the NvBufSurface cannot be fetched/mapped.
    """
    gst_buffer = info.get_buffer()
    if not gst_buffer:
        return Gst.PadProbeReturn.OK

    # Keep the try narrow: previously it wrapped the whole processing path,
    # so a failure inside remove_background/cvtColor was mislabeled as a
    # surface-mapping error. Only the pyds surface access belongs here.
    try:
        surface = pyds.get_nvds_buf_surface(hash(gst_buffer), 0)
        frame_rgba = np.array(surface, copy=False, order="C")
    except Exception as e:
        raise RuntimeError("BackgroundRemovalProbe: failed to fetch/map NvBufSurface") from e

    frame_bgr = cv2.cvtColor(frame_rgba, cv2.COLOR_RGBA2BGR)
    masked_bgr = remove_background(frame_bgr)
    frame_rgba_new = cv2.cvtColor(masked_bgr, cv2.COLOR_BGR2RGBA)

    # Copy back into the zero-copy view so the change lands in the buffer.
    np.copyto(frame_rgba, frame_rgba_new)

    return Gst.PadProbeReturn.OK

0 commit comments

Comments
 (0)