Skip to content

Commit 14595ae

Browse files
authored
Project cleanup (#174)
* Add linting and formatting checks * Remove unused imports
1 parent e171d3c commit 14595ae

46 files changed

Lines changed: 256 additions & 261 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

scripts/export_to_onnx.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
import os
2-
import sys
31
import argparse
42
from pathlib import Path
53

64
import torch
75

8-
from src.inference.utils.inference_factory import InferenceFactory, InferenceConfig, Backend, ModelArch
6+
from src.inference.utils.inference_factory import Backend, InferenceConfig, InferenceFactory, ModelArch
97
from src.path_utils import ensure_clean_directory
108

119
NUM_CLASSES = 83
1210
INPUT_SIZE = (256, 256)
1311
DEVICE = "cpu"
1412
MODELS_DIR_PATH = Path("models")
1513

14+
1615
def main(model_name: str):
1716
arch = ModelArch.RESNET if "resnet" in model_name.lower() else ModelArch.MOBILENET
1817
backend = Backend.PYTORCH

scripts/onnx_validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import os
2-
import argparse
32
from pathlib import Path
43

54
import numpy as np
65
import torch
76
from PIL import Image
87

98
from dataset.optimal_class_mapping import MODEL_NAMES as ID_TO_NAME
10-
from src.inference.utils.inference_factory import InferenceFactory, InferenceConfig, Backend, ModelArch
119
from src.inference.base.classifier_inference_base import ClassifierInferenceBase
1210
from src.inference.base.classifier_inference_base_onnx import OnnxClassifierInferenceBase
11+
from src.inference.utils.inference_factory import Backend, InferenceConfig, InferenceFactory, ModelArch
1312

1413
NUM_CLASSES = 83
1514
INPUT_SIZE = (256, 256)

services/consumer/consumer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
1-
import paho.mqtt.client as mqtt
2-
import pymongo
31
import json
42
from datetime import datetime
53

4+
import paho.mqtt.client as mqtt
5+
import pymongo
6+
7+
68
def on_connect(client, userdata, flags, rc):
79
print(f"🔗 Connected: {rc}")
810
client.subscribe("deepstream/predictions")
911

12+
1013
def on_message(client, userdata, msg):
1114
try:
1215
payload = json.loads(msg.payload.decode())
1316
payload["received_at"] = datetime.now().isoformat()
14-
17+
1518
mongo_client = pymongo.MongoClient("mongodb://agstream_mongo:27017/")
1619
db = mongo_client["agstream"]
1720
collection = db["predictions"]
18-
21+
1922
collection.insert_one(payload)
20-
print(f"💾 Saved to MongoDB!")
23+
print("💾 Saved to MongoDB!")
2124
mongo_client.close()
22-
25+
2326
except Exception as e:
2427
print(f"❌ Error: {e}")
2528

29+
2630
client = mqtt.Client()
2731
client.on_connect = on_connect
2832
client.on_message = on_message

services/consumer/mqtt_consumer.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,60 @@
11
#!/usr/bin/env python3
2-
import paho.mqtt.client as mqtt
3-
import pymongo
42
import json
53
from datetime import datetime
64

5+
import paho.mqtt.client as mqtt
6+
import pymongo
7+
78
MQTT_BROKER = "agstream_mosquitto"
89
MQTT_PORT = 1883
910
MQTT_TOPIC = "deepstream/predictions"
1011
MONGO_URI = "mongodb://agstream_mongo:27017/"
1112
MONGO_DB = "agstream"
1213
MONGO_COLLECTION = "predictions"
1314

15+
1416
def on_connect(client, userdata, flags, rc):
1517
print(f"✅ Connected to MQTT broker with result code {rc}")
1618
client.subscribe(MQTT_TOPIC)
1719

20+
1821
def on_message(client, userdata, msg):
1922
try:
2023
payload = json.loads(msg.payload.decode())
2124
payload["received_at"] = datetime.now().isoformat()
22-
25+
2326
# Extract classification data from nvmsgbroker format
2427
obj = payload.get("object", {})
2528
if obj.get("id") != "0" and obj.get("id"):
2629
class_id = int(obj["id"])
2730
confidence = obj.get("confidence", 0)
28-
31+
2932
# Add extracted classification
30-
payload["classification"] = {
31-
"class_id": class_id,
32-
"confidence": confidence
33-
}
33+
payload["classification"] = {"class_id": class_id, "confidence": confidence}
3434
print(f"🌱 FOUND CLASSIFICATION: ID {class_id}, confidence {confidence:.3f}")
35-
35+
3636
mongo_client = pymongo.MongoClient(MONGO_URI)
3737
db = mongo_client[MONGO_DB]
3838
collection = db[MONGO_COLLECTION]
39-
39+
4040
result = collection.insert_one(payload)
41-
41+
4242
if "classification" in payload:
43-
print(f"✅ Saved classification to MongoDB!")
43+
print("✅ Saved classification to MongoDB!")
4444
else:
45-
print(f"✅ Saved: No classification")
46-
45+
print("✅ Saved: No classification")
46+
4747
mongo_client.close()
48-
48+
4949
except Exception as e:
5050
print(f"❌ Error: {e}")
5151

52+
5253
if __name__ == "__main__":
5354
client = mqtt.Client()
5455
client.on_connect = on_connect
5556
client.on_message = on_message
56-
57+
5758
print("🚀 Starting Enhanced MQTT Consumer...")
5859
client.connect(MQTT_BROKER, MQTT_PORT, 60)
5960
client.loop_forever()

src/deepstream/helpers/load_class_labels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
PLANT_LABELS = "/workspace/configs/crop_and_weed_83_classes.txt"
44

5+
56
# Load class labels
67
def load_class_labels() -> List[str]:
78
try:

src/deepstream/helpers/meta_tensor_extractor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import ctypes
2-
import numpy.typing as npt
2+
33
import numpy as np
4+
import numpy.typing as npt
45
import pyds
56

7+
68
class TensorExtractor:
79
def extract_logits(self, tensor_meta) -> npt.NDArray[np.float32]:
810
"""
@@ -13,10 +15,7 @@ def extract_logits(self, tensor_meta) -> npt.NDArray[np.float32]:
1315
dims = [layer.dims.d[i] for i in range(layer.dims.numDims)]
1416
numel = int(np.prod(dims))
1517

16-
ptr = ctypes.cast(
17-
pyds.get_ptr(layer.buffer),
18-
ctypes.POINTER(ctypes.c_float)
19-
)
18+
ptr = ctypes.cast(pyds.get_ptr(layer.buffer), ctypes.POINTER(ctypes.c_float))
2019
logits = np.ctypeslib.as_array(ptr, shape=(numel,))
2120

2221
# Copy so we are not tied to DeepStream's memory lifetime

src/deepstream/helpers/plant_msg_meta_builder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import sys
2-
from pydantic import BaseModel
32

43
import gi
5-
gi.require_version("Gst", "1.0")
6-
7-
from gi.repository import Gst
84
import pyds
5+
from pydantic import BaseModel
96

107
from src.deepstream.helpers.softmax_topk_classifier import ClassificationPrediction
118

9+
1210
class PlantEvent(BaseModel):
1311
frame_id: int
1412
plant_id: str

src/deepstream/helpers/remove_background.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import numpy as np
21
import cv2
2+
import numpy as np
33

44
MORPH_KERNEL = 34
55

6+
67
def remove_background(frame_bgr: np.ndarray) -> np.ndarray:
78
rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
89
R, G, B = rgb[..., 0], rgb[..., 1], rgb[..., 2]
@@ -12,10 +13,7 @@ def remove_background(frame_bgr: np.ndarray) -> np.ndarray:
1213

1314
_, mask = cv2.threshold(exg_norm, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
1415

15-
kernel = cv2.getStructuringElement(
16-
cv2.MORPH_ELLIPSE,
17-
(MORPH_KERNEL, MORPH_KERNEL)
18-
)
16+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (MORPH_KERNEL, MORPH_KERNEL))
1917
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
2018

2119
out = np.zeros_like(frame_bgr)

src/deepstream/helpers/should_skip_frame.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1-
from typing import Any
21
from enum import Enum
2+
from typing import Any
3+
4+
import cv2
35
import gi
46
import numpy as np
57
import pyds
6-
import cv2
78

89
gi.require_version("Gst", "1.0")
910
from gi.repository import Gst
1011

1112
from src.frame_comparison.frame_change_detector import FrameChangeDetector
1213

14+
1315
class FrameProcessDecision(str, Enum):
1416
PROCESS = "process"
1517
SKIP = "skip"
1618

19+
1720
def should_skip_frame(info: Any, frame_meta: Any, batch_meta: Any, frame_change_detector: FrameChangeDetector) -> int:
1821
"""Pad probe to drop frames based on frame difference analysis."""
1922
gst_buffer = info.get_buffer()
@@ -25,12 +28,8 @@ def should_skip_frame(info: Any, frame_meta: Any, batch_meta: Any, frame_change_
2528
frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
2629

2730
should_process, metrics = frame_change_detector.should_process(frame_bgr)
28-
29-
decision = (
30-
FrameProcessDecision.PROCESS
31-
if should_process
32-
else FrameProcessDecision.SKIP
33-
)
31+
32+
decision = FrameProcessDecision.PROCESS if should_process else FrameProcessDecision.SKIP
3433

3534
print(
3635
f"Frame {frame_meta.frame_num:05d}: {decision.value} "

src/deepstream/helpers/softmax_topk_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import numpy as np
21
from typing import List
2+
3+
import numpy as np
34
from pydantic import BaseModel
45

56

@@ -19,7 +20,7 @@ def predict_from_logits(self, logits: np.ndarray) -> List[ClassificationPredicti
1920
exp = np.exp(logits - np.max(logits))
2021
probs = exp / np.sum(exp)
2122

22-
top_idx = np.argsort(probs)[-self.top_k:][::-1]
23+
top_idx = np.argsort(probs)[-self.top_k :][::-1]
2324

2425
results: List[ClassificationPrediction] = []
2526
for idx in top_idx:

0 commit comments

Comments
 (0)