|
| 1 | +import os |
| 2 | +from datetime import datetime |
| 3 | +from typing import Any, Dict, List |
| 4 | +import gi |
| 5 | +import numpy as np |
| 6 | +import cv2 |
| 7 | +import pyds |
| 8 | + |
| 9 | +gi.require_version("Gst", "1.0") |
| 10 | +from gi.repository import GLib, Gst |
| 11 | + |
| 12 | +from src.frame_comparison.frame_change_detector import FrameChangeDetector |
| 13 | +from src.model_conversion.onnx_to_trt import build_engine_if_missing |
| 14 | + |
| 15 | +from src.deepstream.helpers.load_class_labels import load_class_labels |
| 16 | +from src.deepstream.helpers.meta_tensor_extractor import TensorExtractor |
| 17 | +from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction, SoftmaxTopKClassifier |
| 18 | +from src.deepstream.helpers.plant_msg_meta_builder import NvdsPlantEventBuilder, PlantEvent |
| 19 | +from src.deepstream.helpers.pipeline_runner import run_pipeline |
| 20 | + |
| 21 | +from src.deepstream.probes.background_removal.cpu.background_removal_probe import remove_background_probe |
| 22 | +from src.deepstream.probes.db_message_meta_probe import DbMessageMetaProbe |
| 23 | +from src.deepstream.probes.frame_comparison.gpu.frame_skipping_probe import frame_skipping_probe, GPUFrameChangeDetector |
| 24 | + |
# Configuration.
# Env-overridable values follow the existing RTSP_PORT pattern so deployments
# can be reconfigured without editing source; defaults preserve the previous
# hard-coded values, so behavior is unchanged when the env vars are absent.
RTSP_PORT = os.environ.get("RTSP_PORT", "8554")
RTSP_URL = f"rtsp://127.0.0.1:{RTSP_PORT}/test"

# nvinfer / nvmsgconv config files baked into the container image.
CONFIG_FILE: str = "/workspace/configs/resnet18.txt"
MSGCONV_CONFIG: str = "/workspace/configs/nvmsgbroker_msgconv_config.txt"

# MQTT broker connection string ("host;port;client-id") and topic,
# now env-overridable for consistency with RTSP_PORT.
MQTT_CONN_STR = os.environ.get("MQTT_CONN_STR", "172.17.0.1;1883;agstream-client")
MQTT_TOPIC = os.environ.get("MQTT_TOPIC", "deepstream/predictions")

# Human-readable class names consumed by SoftmaxTopKClassifier.
CLASS_LABELS = load_class_labels()

# Module-level frame counters mutated by the frame-skip probe.
stats: Dict[str, int] = {"total": 0, "skipped": 0, "processed": 0}
| 37 | + |
def gpu_frame_skip_probe(pad: Gst.Pad, info: Gst.PadProbeInfo, detector: GPUFrameChangeDetector) -> Gst.PadProbeReturn:
    """Pad probe that drops frames the GPU change detector marks as unchanged.

    Args:
        pad: Pad the probe is attached to (unused; required by the probe signature).
        info: Probe info carrying the Gst.Buffer for the current frame.
        detector: Stateful GPU-based frame change detector shared across calls.

    Returns:
        Gst.PadProbeReturn.OK to let the frame continue downstream,
        Gst.PadProbeReturn.DROP to discard it.
    """
    # hash() on a Gst.Buffer is the standard pyds idiom for obtaining the
    # underlying C buffer pointer to pass into DeepStream helpers.
    buffer_ptr: int = hash(info.get_buffer())
    batch_id: int = 0  # single-source pipeline: always batch slot 0
    # Bug fix: "total" was printed in both branches but never incremented,
    # so every log line reported "frame 0". Count every frame seen.
    stats["total"] += 1
    should_process: bool = frame_skipping_probe(buffer_ptr, batch_id, detector)
    if should_process:
        stats["processed"] += 1
        print(f"✅ PROCESSING frame {stats['total']}")
        return Gst.PadProbeReturn.OK
    stats["skipped"] += 1
    print(f"⏭️ SKIPPING frame {stats['total']}")
    return Gst.PadProbeReturn.DROP
| 50 | + |
def build_pipeline() -> Gst.Pipeline:
    """Build DeepStream pipeline with background removal, frame skipping, and message broker.

    Topology:
        rtspsrc -> rtph264depay -> h264parse -> decodebin -> videoconvert
        -> nvvideoconvert -> capsfilter(NVMM/RGBA) -> nvstreammux -> nvinfer
        -> nvmsgconv -> nvmsgbroker (MQTT)

    rtspsrc and decodebin create their source pads only after stream
    negotiation, so those two links are made from "pad-added" callbacks.

    Returns:
        The fully assembled Gst.Pipeline, ready to be set to PLAYING.

    Raises:
        RuntimeError: if any GStreamer element factory fails (e.g. a missing
            DeepStream plugin).
    """
    pipeline = Gst.Pipeline.new("final-cpu-pipeline")

    # Elements
    rtspsrc = Gst.ElementFactory.make("rtspsrc", "source")
    depay = Gst.ElementFactory.make("rtph264depay", "depay")
    parse = Gst.ElementFactory.make("h264parse", "parse")
    decode = Gst.ElementFactory.make("decodebin", "decode")
    convert = Gst.ElementFactory.make("videoconvert", "convert")
    nvvideoconvert = Gst.ElementFactory.make("nvvideoconvert", "nvvideoconvert")
    capsfilter = Gst.ElementFactory.make("capsfilter", "capsfilter")
    streammux = Gst.ElementFactory.make("nvstreammux", "streammux")
    nvinfer = Gst.ElementFactory.make("nvinfer", "nvinfer")
    nvmsgconv = Gst.ElementFactory.make("nvmsgconv", "nvmsgconv")
    nvmsgbroker = Gst.ElementFactory.make("nvmsgbroker", "nvmsgbroker")

    # Fail fast, naming the element that could not be created.
    # (The original `assert e is not None, f"... {e}"` printed "None" on
    # failure, never identifying the missing plugin; asserts are also
    # stripped under `python -O`, so raise explicitly instead.)
    named_elements = [
        ("rtspsrc", rtspsrc),
        ("rtph264depay", depay),
        ("h264parse", parse),
        ("decodebin", decode),
        ("videoconvert", convert),
        ("nvvideoconvert", nvvideoconvert),
        ("capsfilter", capsfilter),
        ("nvstreammux", streammux),
        ("nvinfer", nvinfer),
        ("nvmsgconv", nvmsgconv),
        ("nvmsgbroker", nvmsgbroker),
    ]
    for name, element in named_elements:
        if element is None:
            raise RuntimeError(f"Failed to create element {name}")
        pipeline.add(element)

    # Configure elements
    rtspsrc.set_property("location", RTSP_URL)
    rtspsrc.set_property("latency", 200)  # RTSP jitter buffer, in ms
    streammux.set_property("batch-size", 1)
    streammux.set_property("width", 256)
    streammux.set_property("height", 256)
    nvinfer.set_property("config-file-path", CONFIG_FILE)
    # Force NVMM RGBA between nvvideoconvert and the muxer.
    caps = Gst.Caps.from_string("video/x-raw(memory:NVMM), format=RGBA")
    capsfilter.set_property("caps", caps)
    nvmsgconv.set_property("config", MSGCONV_CONFIG)
    nvmsgconv.set_property("payload-type", 0)
    nvmsgbroker.set_property("proto-lib", "/opt/nvidia/deepstream/deepstream-6.4/lib/libnvds_mqtt_proto.so")
    nvmsgbroker.set_property("conn-str", MQTT_CONN_STR)
    nvmsgbroker.set_property("topic", MQTT_TOPIC)
    nvmsgbroker.set_property("sync", False)

    # Dynamic pad linking (pads appear only after caps negotiation).
    def on_pad_added_rtspsrc(src: Any, pad: Any) -> None:
        sinkpad = depay.get_static_pad("sink")
        if not sinkpad.is_linked():
            pad.link(sinkpad)

    rtspsrc.connect("pad-added", on_pad_added_rtspsrc)

    def on_pad_added_decode(src: Any, pad: Any) -> None:
        sinkpad = convert.get_static_pad("sink")
        if not sinkpad.is_linked():
            pad.link(sinkpad)

    decode.connect("pad-added", on_pad_added_decode)

    # Static links up to the muxer; nvstreammux sink pads are request pads.
    depay.link(parse)
    parse.link(decode)
    convert.link(nvvideoconvert)
    nvvideoconvert.link(capsfilter)
    srcpad = capsfilter.get_static_pad("src")
    # NOTE(review): get_request_pad is deprecated since GStreamer 1.20 in
    # favor of request_pad_simple — kept as-is pending confirmation of the
    # minimum GStreamer version this must run on.
    sinkpad = streammux.get_request_pad("sink_0")
    srcpad.link(sinkpad)

    streammux.link(nvinfer)
    nvinfer.link(nvmsgconv)
    nvmsgconv.link(nvmsgbroker)

    # Pre-inference probes on the muxed stream: frame skipping first, then
    # background removal (probes fire in the order they were added).
    detector = GPUFrameChangeDetector()
    streammux_src_pad: Gst.Pad = streammux.get_static_pad("src")
    streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, gpu_frame_skip_probe, detector)
    streammux_src_pad.add_probe(Gst.PadProbeType.BUFFER, remove_background_probe)

    # Post-inference probe attaches message metadata for nvmsgconv/nvmsgbroker.
    tensor_extractor = TensorExtractor()
    classifier = SoftmaxTopKClassifier(CLASS_LABELS)
    plant_event_builder = NvdsPlantEventBuilder()
    db_message_meta_probe = DbMessageMetaProbe(tensor_extractor, classifier, plant_event_builder)

    nvinfer_src_pad = nvinfer.get_static_pad("src")
    nvinfer_src_pad.add_probe(Gst.PadProbeType.BUFFER, db_message_meta_probe.pad_probe)

    return pipeline
| 144 | + |
if __name__ == "__main__":
    # GStreamer must be initialized before any element/pipeline creation.
    Gst.init(None)
    # Ensure the TensorRT engine referenced by the nvinfer config exists;
    # per the helper's name it is built (presumably from ONNX, given the
    # onnx_to_trt module) only when missing — confirm against the helper.
    build_engine_if_missing(CONFIG_FILE)
    pipeline = build_pipeline()
    # Hands off to the shared pipeline runner helper for the main loop.
    run_pipeline(pipeline)
0 commit comments