diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 1ed3c136ba..398d81bb1b 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -30,6 +30,7 @@ const MODELS: ModelOption[] = [ { label: 'YOLO26M', value: objectDetection.yolo26m() }, { label: 'YOLO26L', value: objectDetection.yolo26l() }, { label: 'YOLO26X', value: objectDetection.yolo26x() }, + { label: 'BlazeFace', value: objectDetection.blazeface() }, ]; import ErrorBanner from '../../components/ErrorBanner'; diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index 99fe0b1ac7..0066b36396 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -45,6 +45,7 @@ type ModelId = | 'objectDetectionSsdlite' | 'objectDetectionRfdetr' | 'objectDetectionYolo26n' + | 'objectDetectionBlazeface' | 'segmentationDeeplabResnet50' | 'segmentationDeeplabResnet101' | 'segmentationDeeplabMobilenet' @@ -105,6 +106,7 @@ const TASKS: Task[] = [ { id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' }, { id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' }, { id: 'objectDetectionYolo26n', label: 'YOLO26N' }, + { id: 'objectDetectionBlazeface', label: 'BlazeFace' }, ], }, { @@ -270,6 +272,7 @@ export default function VisionCameraScreen() { | 'objectDetectionSsdlite' | 'objectDetectionRfdetr' | 'objectDetectionYolo26n' + | 'objectDetectionBlazeface' } /> )} diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx index e05de26105..9a2f2d6577 100644 --- a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx @@ -8,6 +8,7 @@ import { useObjectDetection, CocoLabel, CocoLabelYolo, + BlazeFaceLabel, } from 'react-native-executorch'; import BoundingBoxes from '../../BoundingBoxes'; import { FRAME_TARGET_RESOLUTION, TaskProps } from './types'; @@ -16,7 +17,8 @@ const objectDetection = models.object_detection; type ObjModelId = | 'objectDetectionSsdlite' | 'objectDetectionRfdetr' - | 'objectDetectionYolo26n'; + | 'objectDetectionYolo26n' + | 'objectDetectionBlazeface'; type Props = TaskProps & { activeModel: ObjModelId }; @@ -44,13 +46,18 @@ export default function ObjectDetectionTask({ model: objectDetection.yolo26n(), preventLoad: activeModel !== 'objectDetectionYolo26n', }); + const blazeface = useObjectDetection({ + model: objectDetection.blazeface(), + preventLoad: activeModel !== 'objectDetectionBlazeface', + }); - const active = - activeModel === 'objectDetectionSsdlite' - ? ssdlite - : activeModel === 'objectDetectionRfdetr' - ? rfdetr - : yolo26n; + const detectors = { + objectDetectionSsdlite: ssdlite, + objectDetectionRfdetr: rfdetr, + objectDetectionYolo26n: yolo26n, + objectDetectionBlazeface: blazeface, + } satisfies Record; + const active = detectors[activeModel]; type CommonDetection = Omit & { label: string }; @@ -80,7 +87,8 @@ export default function ObjectDetectionTask({ (p: { results: | Detection[] - | Detection[]; + | Detection[] + | Detection[]; imageWidth: number; imageHeight: number; }) => { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp index 24ef7a8e22..e8eca0c54d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp @@ -57,7 +57,8 @@ TensorPtr BaseInstanceSegmentation::buildInputTensor(const cv::Mat &image) { std::vector BaseInstanceSegmentation::runInference( const cv::Mat &image, double confidenceThreshold, double iouThreshold, int32_t maxInstances, const std::vector &classIndices, - bool returnMaskAtOriginalResolution, const std::string &methodName) { + bool returnMaskAtOriginalResolution, const std::string &methodName, + bool useWeightedNms) { std::scoped_lock lock(inference_mutex_); @@ -86,34 +87,37 @@ std::vector BaseInstanceSegmentation::runInference( auto instances = collectInstances( forwardResult.get(), originalSize, modelInputSize, confidenceThreshold, classIndices, returnMaskAtOriginalResolution); - return finalizeInstances(std::move(instances), iouThreshold, maxInstances); + return finalizeInstances(std::move(instances), iouThreshold, maxInstances, + useWeightedNms); } std::vector BaseInstanceSegmentation::generateFromString( std::string imageSource, double confidenceThreshold, double iouThreshold, int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, std::string methodName) { + bool returnMaskAtOriginalResolution, std::string methodName, + bool useWeightedNms) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); return runInference(imageRGB, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + classIndices, returnMaskAtOriginalResolution, methodName, + useWeightedNms); } std::vector BaseInstanceSegmentation::generateFromFrame( jsi::Runtime &runtime, const jsi::Value &frameData, double confidenceThreshold, double iouThreshold, int32_t maxInstances, std::vector classIndices, bool returnMaskAtOriginalResolution, - std::string methodName) { + std::string methodName, bool useWeightedNms) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = utils::rotateFrameForModel(frame, orient); - auto instances = - runInference(rotated, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + auto instances = runInference( + rotated, confidenceThreshold, iouThreshold, maxInstances, classIndices, + returnMaskAtOriginalResolution, methodName, useWeightedNms); for (auto &inst : instances) { utils::inverseRotateBbox(inst.bbox, orient, rotated.size()); // Inverse-rotate the mask to match the screen orientation @@ -131,11 +135,13 @@ std::vector BaseInstanceSegmentation::generateFromFrame( std::vector BaseInstanceSegmentation::generateFromPixels( JSTensorViewIn tensorView, double confidenceThreshold, double iouThreshold, int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, std::string methodName) { + bool returnMaskAtOriginalResolution, std::string methodName, + bool useWeightedNms) { cv::Mat image = extractFromPixels(tensorView); return runInference(image, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + classIndices, returnMaskAtOriginalResolution, methodName, + useWeightedNms); } std::tuple @@ -296,11 +302,14 @@ void BaseInstanceSegmentation::ensureMethodLoaded( std::vector BaseInstanceSegmentation::finalizeInstances( std::vector instances, double iouThreshold, - int32_t maxInstances) const { + int32_t maxInstances, bool useWeightedNms) const { if (applyNMS_) { - instances = - utils::computer_vision::nonMaxSuppression(instances, iouThreshold); + instances = useWeightedNms + ? utils::computer_vision::weightedNonMaxSuppression( + instances, iouThreshold) + : utils::computer_vision::nonMaxSuppression(instances, + iouThreshold); } if (std::cmp_greater(instances.size(), maxInstances)) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h index 341d0f2235..1b511d8d6e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h @@ -28,30 +28,32 @@ class BaseInstanceSegmentation : public VisionModel { double iouThreshold, int32_t maxInstances, std::vector classIndices, bool returnMaskAtOriginalResolution, - std::string methodName); + std::string methodName, bool useWeightedNms); [[nodiscard("Registered non-void function")]] std::vector generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double confidenceThreshold, double iouThreshold, int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, - std::string methodName); + bool returnMaskAtOriginalResolution, std::string methodName, + bool useWeightedNms); [[nodiscard("Registered non-void function")]] std::vector generateFromPixels(JSTensorViewIn tensorView, double confidenceThreshold, double iouThreshold, int32_t maxInstances, std::vector classIndices, bool returnMaskAtOriginalResolution, - std::string methodName); + std::string methodName, bool useWeightedNms); protected: cv::Size modelInputSize() const override; private: - std::vector runInference( - const cv::Mat &image, double confidenceThreshold, double iouThreshold, - int32_t maxInstances, const std::vector &classIndices, - bool returnMaskAtOriginalResolution, const std::string &methodName); + std::vector + runInference(const cv::Mat &image, double confidenceThreshold, + double iouThreshold, int32_t maxInstances, + const std::vector &classIndices, + bool returnMaskAtOriginalResolution, + const std::string &methodName, bool useWeightedNms); TensorPtr buildInputTensor(const cv::Mat &image); @@ -89,7 +91,7 @@ class BaseInstanceSegmentation : public VisionModel { std::vector finalizeInstances(std::vector instances, double iouThreshold, - int32_t maxInstances) const; + int32_t maxInstances, bool useWeightedNms) const; cv::Mat processMaskFromLogits( const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 24c4e1083a..6c18183ea9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -82,16 +82,10 @@ std::set ObjectDetection::prepareAllowedClasses( return allowedClasses; } -std::vector -ObjectDetection::postprocess(const std::vector &tensors, - cv::Size originalSize, double detectionThreshold, - double iouThreshold, - const std::vector &classIndices) { - const cv::Size inputSize = modelInputSize(); - float widthRatio = static_cast(originalSize.width) / inputSize.width; - float heightRatio = - static_cast(originalSize.height) / inputSize.height; - +std::vector ObjectDetection::postprocess( + const std::vector &tensors, const BoxTransform &transform, + double detectionThreshold, double iouThreshold, + const std::vector &classIndices, bool useWeightedNms) { // Prepare allowed classes set for filtering auto allowedClasses = prepareAllowedClasses(classIndices); @@ -124,10 +118,13 @@ ObjectDetection::postprocess(const std::vector &tensors, continue; } - float x1 = bboxes[i * 4] * widthRatio; - float y1 = bboxes[i * 4 + 1] * heightRatio; - float x2 = bboxes[i * 4 + 2] * widthRatio; - float y2 = bboxes[i * 4 + 3] * heightRatio; + // Map model-input pixel coords back to source-image coords. The same + // affine `x_src = x_model * scale + offset` works for stretch and + // letterbox preprocessing — offsets are zero in the stretch case. + float x1 = bboxes[i * 4] * transform.scaleX + transform.offsetX; + float y1 = bboxes[i * 4 + 1] * transform.scaleY + transform.offsetY; + float x2 = bboxes[i * 4 + 2] * transform.scaleX + transform.offsetX; + float y2 = bboxes[i * 4 + 3] * transform.scaleY + transform.offsetY; if (std::cmp_greater_equal(labelIdx, labelNames_.size())) { throw RnExecutorchError( @@ -140,12 +137,17 @@ ObjectDetection::postprocess(const std::vector &tensors, labelNames_[labelIdx], labelIdx, scores[i]); } - return utils::computer_vision::nonMaxSuppression(detections, iouThreshold); + return useWeightedNms + ? utils::computer_vision::weightedNonMaxSuppression(detections, + iouThreshold) + : utils::computer_vision::nonMaxSuppression(detections, + iouThreshold); } std::vector ObjectDetection::runInference( cv::Mat image, double detectionThreshold, double iouThreshold, - const std::vector &classIndices, const std::string &methodName) { + const std::vector &classIndices, const std::string &methodName, + bool useWeightedNms, bool useLetterbox) { if (detectionThreshold < 0.0 || detectionThreshold > 1.0) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "detectionThreshold must be in range [0, 1]"); @@ -171,7 +173,38 @@ std::vector ObjectDetection::runInference( } modelInputShape_ = inputShapes[0]; - cv::Mat preprocessed = preprocess(image); + const cv::Size inputSize = modelInputSize(); + cv::Mat preprocessed; + BoxTransform transform; + if (useLetterbox) { + // Aspect-preserving fit + center-pad with black bars. Models trained on + // natural-aspect crops (BlazeFace) need this — plain cv::resize stretches + // the face and shifts where anchors fire. + const float fitScale = + std::min(static_cast(inputSize.width) / originalSize.width, + static_cast(inputSize.height) / originalSize.height); + const int newW = + static_cast(std::round(originalSize.width * fitScale)); + const int newH = + static_cast(std::round(originalSize.height * fitScale)); + const int padX = (inputSize.width - newW) / 2; + const int padY = (inputSize.height - newH) / 2; + + cv::Mat resized; + cv::resize(image, resized, cv::Size(newW, newH), 0, 0, cv::INTER_AREA); + cv::copyMakeBorder(resized, preprocessed, padY, + inputSize.height - newH - padY, padX, + inputSize.width - newW - padX, cv::BORDER_CONSTANT, + cv::Scalar(0, 0, 0)); + + const float inv = 1.0f / fitScale; + transform = {inv, inv, -padX * inv, -padY * inv}; + } else { + preprocessed = preprocess(image); + transform = {static_cast(originalSize.width) / inputSize.width, + static_cast(originalSize.height) / inputSize.height, + 0.0f, 0.0f}; + } auto inputTensor = (normMean_ && normStd_) @@ -188,31 +221,34 @@ std::vector ObjectDetection::runInference( "Ensure the model input is correct."); } - return postprocess(executeResult.get(), originalSize, detectionThreshold, - iouThreshold, classIndices); + return postprocess(executeResult.get(), transform, detectionThreshold, + iouThreshold, classIndices, useWeightedNms); } std::vector ObjectDetection::generateFromString( std::string imageSource, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { + std::vector classIndices, std::string methodName, + bool useWeightedNms, bool useLetterbox) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices, - methodName); + methodName, useWeightedNms, useLetterbox); } std::vector ObjectDetection::generateFromFrame( jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { + std::vector classIndices, std::string methodName, + bool useWeightedNms, bool useLetterbox) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient); - auto detections = runInference(rotated, detectionThreshold, iouThreshold, - classIndices, methodName); + auto detections = + runInference(rotated, detectionThreshold, iouThreshold, classIndices, + methodName, useWeightedNms, useLetterbox); for (auto &det : detections) { ::rnexecutorch::utils::inverseRotateBbox(det.bbox, orient, rotated.size()); @@ -222,10 +258,11 @@ std::vector ObjectDetection::generateFromFrame( std::vector ObjectDetection::generateFromPixels( JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { + std::vector classIndices, std::string methodName, + bool useWeightedNms, bool useLetterbox) { cv::Mat image = extractFromPixels(pixelData); return runInference(image, detectionThreshold, iouThreshold, classIndices, - methodName); + methodName, useWeightedNms, useLetterbox); } } // namespace rnexecutorch::models::object_detection diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index 6e3c01356e..7fcb960cea 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -14,6 +14,16 @@ namespace models::object_detection { using executorch::extension::TensorPtr; using executorch::runtime::EValue; +/// Affine transform from model-input pixel coords back to source-image coords: +/// `x_src = x_model * scaleX + offsetX`. Covers both plain stretch (offsets +/// zero) and letterbox (offsets carry the centre-pad). +struct BoxTransform { + float scaleX; + float scaleY; + float offsetX; + float offsetY; +}; + /** * @brief Object detection model that detects and localises objects in images. * @@ -75,15 +85,18 @@ class ObjectDetection : public VisionModel { [[nodiscard("Registered non-void function")]] std::vector generateFromString(std::string imageSource, double detectionThreshold, double iouThreshold, std::vector classIndices, - std::string methodName); + std::string methodName, bool useWeightedNms, + bool useLetterbox); [[nodiscard("Registered non-void function")]] std::vector generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName); + std::vector classIndices, std::string methodName, + bool useWeightedNms, bool useLetterbox); [[nodiscard("Registered non-void function")]] std::vector generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold, std::vector classIndices, - std::string methodName); + std::string methodName, bool useWeightedNms, + bool useLetterbox); protected: /** @@ -99,7 +112,8 @@ class ObjectDetection : public VisionModel { std::vector runInference(cv::Mat image, double detectionThreshold, double iouThreshold, const std::vector &classIndices, - const std::string &methodName); + const std::string &methodName, bool useWeightedNms, + bool useLetterbox); private: /** @@ -121,9 +135,9 @@ class ObjectDetection : public VisionModel { * the size of @ref labelNames_. */ std::vector - postprocess(const std::vector &tensors, cv::Size originalSize, + postprocess(const std::vector &tensors, const BoxTransform &transform, double detectionThreshold, double iouThreshold, - const std::vector &classIndices); + const std::vector &classIndices, bool useWeightedNms); /** * @brief Ensures the specified method is loaded, unloading any previous diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp index ff003eb62d..255f782a0a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp @@ -33,7 +33,7 @@ template <> struct ModelTraits { static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true, - kMethodName); + kMethodName, false); } }; } // namespace model_tests @@ -51,16 +51,17 @@ TEST(InstanceSegGenerateTests, InvalidImagePathThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, - 100, {}, true, kMethodName), + 100, {}, true, kMethodName, + false), RnExecutorchError); } TEST(InstanceSegGenerateTests, EmptyImagePathThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); - EXPECT_THROW( - (void)model.generateFromString("", 0.5, 0.5, 100, {}, true, kMethodName), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, 100, {}, true, + kMethodName, false), + RnExecutorchError); } TEST(InstanceSegGenerateTests, EmptyMethodNameThrows) { @@ -75,7 +76,8 @@ TEST(InstanceSegGenerateTests, NegativeConfidenceThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, - 100, {}, true, kMethodName), + 100, {}, true, kMethodName, + false), RnExecutorchError); } @@ -83,7 +85,8 @@ TEST(InstanceSegGenerateTests, ConfidenceAboveOneThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, - 100, {}, true, kMethodName), + 100, {}, true, kMethodName, + false), RnExecutorchError); } @@ -91,7 +94,8 @@ TEST(InstanceSegGenerateTests, NegativeIouThresholdThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, - 100, {}, true, kMethodName), + 100, {}, true, kMethodName, + false), RnExecutorchError); } @@ -99,7 +103,8 @@ TEST(InstanceSegGenerateTests, IouThresholdAboveOneThrows) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, - 100, {}, true, kMethodName), + 100, {}, true, kMethodName, + false), RnExecutorchError); } @@ -107,7 +112,7 @@ TEST(InstanceSegGenerateTests, ValidImageReturnsResults) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); EXPECT_FALSE(results.empty()); } @@ -115,9 +120,9 @@ TEST(InstanceSegGenerateTests, HighThresholdReturnsFewerResults) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto lowResults = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 100, - {}, true, kMethodName); - auto highResults = model.generateFromString(kValidTestImagePath, 0.9, 0.5, - 100, {}, true, kMethodName); + {}, true, kMethodName, false); + auto highResults = model.generateFromString( + kValidTestImagePath, 0.9, 0.5, 100, {}, true, kMethodName, false); EXPECT_GE(lowResults.size(), highResults.size()); } @@ -125,7 +130,7 @@ TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 2, {}, - true, kMethodName); + true, kMethodName, false); EXPECT_LE(results.size(), 2u); } @@ -136,7 +141,7 @@ TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); for (const auto &inst : results) { EXPECT_LE(inst.bbox.p1.x, inst.bbox.p2.x); @@ -150,7 +155,7 @@ TEST(InstanceSegResultTests, InstancesHaveValidScores) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); for (const auto &inst : results) { EXPECT_GE(inst.score, 0.0f); @@ -162,7 +167,7 @@ TEST(InstanceSegResultTests, InstancesHaveValidMasks) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); for (const auto &inst : results) { EXPECT_GT(inst.maskWidth, 0); @@ -181,7 +186,7 @@ TEST(InstanceSegResultTests, InstancesHaveValidClassIndices) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); for (const auto &inst : results) { EXPECT_GE(inst.classIndex, 0); @@ -197,8 +202,9 @@ TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) { nullptr); // Filter to class index 0 (PERSON in CocoLabelYolo) std::vector classIndices = {0}; - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - classIndices, true, kMethodName); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, classIndices, + true, kMethodName, false); for (const auto &inst : results) { EXPECT_EQ(inst.classIndex, 0); @@ -209,11 +215,11 @@ TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto allResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + {}, true, kMethodName, false); EXPECT_FALSE(allResults.empty()); auto noResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {50}, true, kMethodName); + {50}, true, kMethodName, false); EXPECT_TRUE(noResults.empty()); } @@ -224,9 +230,9 @@ TEST(InstanceSegMaskTests, LowResMaskIsSmallerThanOriginal) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto hiRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, - true, kMethodName); + true, kMethodName, false); auto loRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, - false, kMethodName); + false, kMethodName, false); if (!hiRes.empty() && !loRes.empty()) { EXPECT_LE(loRes[0].mask->size(), hiRes[0].mask->size()); @@ -243,9 +249,9 @@ TEST(InstanceSegNMSTests, NMSEnabledReturnsFewerOrEqualResults) { false, nullptr); auto nmsResults = modelWithNMS.generateFromString( - kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); + kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName, false); auto noNmsResults = modelWithoutNMS.generateFromString( - kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); + kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName, false); EXPECT_LE(nmsResults.size(), noNmsResults.size()); } @@ -262,7 +268,7 @@ TEST(InstanceSegPixelTests, ValidPixelDataReturnsResults) { {height, width, channels}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(tensorView, 0.3, 0.5, 100, {}, true, - kMethodName); + kMethodName, false); EXPECT_GE(results.size(), 0u); } @@ -275,7 +281,7 @@ TEST(InstanceSegPixelTests, NegativeConfidenceThrows) { {height, width, channels}, executorch::aten::ScalarType::Byte}; EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.5, 100, {}, - true, kMethodName), + true, kMethodName, false), RnExecutorchError); } @@ -288,7 +294,7 @@ TEST(InstanceSegPixelTests, ConfidenceAboveOneThrows) { {height, width, channels}, executorch::aten::ScalarType::Byte}; EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.5, 100, {}, - true, kMethodName), + true, kMethodName, false), RnExecutorchError); } @@ -307,14 +313,14 @@ TEST(InstanceSegInheritedTests, GetInputShapeWorks) { TEST(InstanceSegInheritedTests, GetAllInputShapesWorks) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); - auto shapes = model.getAllInputShapes(kMethodName); + auto shapes = model.getAllInputShapes(kMethodName, false); EXPECT_FALSE(shapes.empty()); } TEST(InstanceSegInheritedTests, GetMethodMetaWorks) { BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); - auto result = model.getMethodMeta(kMethodName); + auto result = model.getMethodMeta(kMethodName, false); EXPECT_TRUE(result.ok()); } @@ -333,6 +339,6 @@ TEST(InstanceSegNormTests, ValidNormParamsGenerateSucceeds) { const std::vector std = {0.229f, 0.224f, 0.225f}; BaseInstanceSegmentation model(kValidInstanceSegModelPath, mean, std, true, nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, - 100, {}, true, kMethodName)); + EXPECT_NO_THROW((void)model.generateFromString( + kValidTestImagePath, 0.5, 0.5, 100, {}, true, kMethodName, false)); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index 5c5bb6e736..98fa6eff68 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -51,7 +51,7 @@ template <> struct ModelTraits { static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, - "forward"); + "forward", false, false); } }; } // namespace model_tests @@ -69,14 +69,16 @@ TEST(ObjectDetectionGenerateTests, InvalidImagePathThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, - 0.55, {}, "forward"), + 0.55, {}, "forward", false, + false), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, EmptyImagePathThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString("", 0.5, 0.55, {}, "forward"), + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.55, {}, "forward", + false, false), RnExecutorchError); } @@ -84,7 +86,8 @@ TEST(ObjectDetectionGenerateTests, MalformedURIThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, - 0.55, {}, "forward"), + 0.55, {}, "forward", false, + false), RnExecutorchError); } @@ -92,7 +95,7 @@ TEST(ObjectDetectionGenerateTests, NegativeThresholdThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.55, - {}, "forward"), + {}, "forward", false, false), RnExecutorchError); } @@ -100,33 +103,33 @@ TEST(ObjectDetectionGenerateTests, ThresholdAboveOneThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.55, - {}, "forward"), + {}, "forward", false, false), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, ValidImageReturnsResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, + "forward", false, false); EXPECT_GE(results.size(), 0u); } TEST(ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto lowThresholdResults = - model.generateFromString(kValidTestImagePath, 0.1, 0.55, {}, "forward"); - auto highThresholdResults = - model.generateFromString(kValidTestImagePath, 0.9, 0.55, {}, "forward"); + auto lowThresholdResults = model.generateFromString( + kValidTestImagePath, 0.1, 0.55, {}, "forward", false, false); + auto highThresholdResults = model.generateFromString( + kValidTestImagePath, 0.9, 0.55, {}, "forward", false, false); EXPECT_GE(lowThresholdResults.size(), highThresholdResults.size()); } TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, + "forward", false, false); for (const auto &detection : results) { EXPECT_LE(detection.bbox.p1.x, detection.bbox.p2.x); @@ -139,8 +142,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) { TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, + "forward", false, false); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -151,8 +154,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, + "forward", false, false); for (const auto &detection : results) { const auto &label = detection.label; @@ -173,7 +176,8 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - auto results = model.generateFromPixels(tensorView, 0.3, 0.55, {}, "forward"); + auto results = model.generateFromPixels(tensorView, 0.3, 0.55, {}, "forward", + false, false); EXPECT_GE(results.size(), 0u); } @@ -185,9 +189,9 @@ TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, -0.1, 0.55, {}, "forward"), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.55, {}, + "forward", false, false), + RnExecutorchError); } TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) { @@ -198,9 +202,9 @@ TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, 1.1, 0.55, {}, "forward"), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.55, {}, + "forward", false, false), + RnExecutorchError); } TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) { @@ -255,7 +259,7 @@ TEST(ObjectDetectionNormTests, ValidNormParamsGenerateSucceeds) { ObjectDetection model(kValidObjectDetectionModelPath, mean, std, kCocoLabels, nullptr); EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, - {}, "forward")); + {}, "forward", false, false)); } // ============================================================================ @@ -265,16 +269,16 @@ TEST(ObjectDetectionMethodTests, InvalidMethodNameThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, - {}, "forward_999"), + {}, "forward_999", false, false), RnExecutorchError); } TEST(ObjectDetectionMethodTests, EmptyMethodNameThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW( - (void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, ""), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, + {}, "", false, false), + RnExecutorchError); } // ============================================================================ @@ -285,8 +289,8 @@ TEST(ObjectDetectionClassFilterTests, ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); // Only request "person" class (index 0 in COCO) - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, + "forward", false, false); for (const auto &det : results) { EXPECT_EQ(det.label, "person"); } @@ -296,11 +300,11 @@ TEST(ObjectDetectionClassFilterTests, EmptyClassIndicesReturnsMoreOrEqualResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto allClasses = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + auto allClasses = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, + "forward", false, false); // person (0) only - auto filtered = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); + auto filtered = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, + "forward", false, false); EXPECT_GE(allClasses.size(), filtered.size()); } @@ -311,10 +315,10 @@ TEST(ObjectDetectionIouTests, HigherIouThresholdReturnsSameOrMoreResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); // High IoU threshold = less aggressive NMS = more boxes survive - auto highIou = - model.generateFromString(kValidTestImagePath, 0.3, 0.9, {}, "forward"); + auto highIou = model.generateFromString(kValidTestImagePath, 0.3, 0.9, {}, + "forward", false, false); // Low IoU threshold = more aggressive NMS = fewer boxes survive - auto lowIou = - model.generateFromString(kValidTestImagePath, 0.3, 0.1, {}, "forward"); + auto lowIou = model.generateFromString(kValidTestImagePath, 0.3, 0.1, {}, + "forward", false, false); EXPECT_GE(highIou.size(), lowIou.size()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h index 3bd3022d4a..5d88a202dd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h @@ -48,4 +48,69 @@ std::vector nonMaxSuppression(std::vector items, double iouThreshold) { return result; } +// Weighted (blending) NMS, used by BlazeFace-style detectors. Overlapping +// anchors of the same class are score-weighted-averaged into a single box +// instead of greedily pruned to the top scorer. The output box position is +// `sum(box_i * score_i) / sum(score_i)` (paper § 3.2). The output score is +// the *max* of the cluster, not the mean: with a low pre-NMS threshold the +// mean drifts around the cluster floor and makes the detection flicker +// in/out as low-confidence anchors enter/leave the cluster between frames. +template +std::vector weightedNonMaxSuppression(std::vector items, + double iouThreshold) { + if (items.empty()) { + return {}; + } + + std::ranges::sort(items, + [](const T &a, const T &b) { return a.score > b.score; }); + + std::vector result; + std::vector consumed(items.size(), false); + + for (size_t i = 0; i < items.size(); ++i) { + if (consumed[i]) { + continue; + } + consumed[i] = true; + + float totalScore = items[i].score; + float wx1 = items[i].bbox.p1.x * items[i].score; + float wy1 = items[i].bbox.p1.y * items[i].score; + float wx2 = items[i].bbox.p2.x * items[i].score; + float wy2 = items[i].bbox.p2.y * items[i].score; + + for (size_t j = i + 1; j < items.size(); ++j) { + if (consumed[j]) { + continue; + } + + if constexpr (requires(T t) { t.classIndex; }) { + if (items[i].classIndex != items[j].classIndex) { + continue; + } + } + + float iou = computeIoU(items[i].bbox, items[j].bbox); + if (iou > iouThreshold) { + consumed[j] = true; + totalScore += items[j].score; + wx1 += items[j].bbox.p1.x * items[j].score; + wy1 += items[j].bbox.p1.y * items[j].score; + wx2 += items[j].bbox.p2.x * items[j].score; + wy2 += items[j].bbox.p2.y * items[j].score; + } + } + + T blended = items[i]; + if (totalScore > 0.0f) { + blended.bbox.p1 = {wx1 / totalScore, wy1 / totalScore}; + blended.bbox.p2 = {wx2 / totalScore, wy2 / totalScore}; + } + result.push_back(blended); + } + + return result; +} + } // namespace rnexecutorch::utils::computer_vision diff --git a/packages/react-native-executorch/src/constants/commonVision.ts b/packages/react-native-executorch/src/constants/commonVision.ts index 6221d5701e..c70efc2f09 100644 --- a/packages/react-native-executorch/src/constants/commonVision.ts +++ b/packages/react-native-executorch/src/constants/commonVision.ts @@ -211,3 +211,14 @@ export enum CocoLabelYolo { export enum FastSAMLabel { OBJECT = 0, } + +/** + * Class label for BlazeFace face detection. + * + * BlazeFace is a single-class face detector. The exported model emits a flat + * class tensor of zeros for every anchor. + * @category Types + */ +export enum BlazeFaceLabel { + FACE = 0, +} diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts index 9c9da9c420..5645def7ad 100644 --- a/packages/react-native-executorch/src/constants/modelRegistry.ts +++ b/packages/react-native-executorch/src/constants/modelRegistry.ts @@ -518,6 +518,7 @@ export const models = { yolo26m: base(M.YOLO26M), yolo26l: base(M.YOLO26L), yolo26x: base(M.YOLO26X), + blazeface: base(M.BLAZEFACE), }, pose_estimation: { yolo26n: base(M.YOLO26N_POSE), diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 4dc8966aee..0b13c8444d 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -3,7 +3,7 @@ import { PRIVACY_FILTER_NEMOTRON_LABELS, PRIVACY_FILTER_OPENAI_LABELS, } from './privacyFilterLabels'; -import { URL_PREFIX, PREVIOUS_VERSION_TAG } from './versions'; +import { URL_PREFIX, PREVIOUS_VERSION_TAG, VERSION_TAG } from './versions'; // LLMs @@ -682,6 +682,17 @@ export const YOLO26X = { modelSource: YOLO26X_DETECTION_MODEL, } as const; +// BlazeFace — pinned to VERSION_TAG (v0.10.0) where the HF repo first publishes. +const BLAZEFACE_XNNPACK_FP32_MODEL = `${URL_PREFIX}-blazeface/${VERSION_TAG}/xnnpack/blazeface.pte`; + +/** + * @category Models - Object Detection + */ +export const BLAZEFACE = { + modelName: 'blazeface', + modelSource: BLAZEFACE_XNNPACK_FP32_MODEL, +} as const; + // YOLO26 Pose Estimation const YOLO26N_POSE_MODEL = `${URL_PREFIX}-yolo26-pose/${PREVIOUS_VERSION_TAG}/xnnpack/yolo26_pose_n_xnnpack_fp32.pte`; diff --git a/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts index e7e96f2deb..3091e85ab5 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts @@ -325,6 +325,7 @@ export class InstanceSegmentationModule< this.modelConfig.defaultConfidenceThreshold ?? 0.5; const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.5; const defaultInputSize = this.modelConfig.defaultInputSize; + const useWeightedNms = this.modelConfig.nmsMode === 'weighted'; return ( frame: Frame, @@ -360,7 +361,8 @@ export class InstanceSegmentationModule< maxInstances, classIndices, returnMaskAtOriginalResolution, - methodName + methodName, + useWeightedNms ); return nativeResults.map((inst: any) => ({ bbox: inst.bbox, @@ -417,6 +419,7 @@ export class InstanceSegmentationModule< 0.5; const iouThreshold = options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5; + const useWeightedNms = this.modelConfig.nmsMode === 'weighted'; const maxInstances = options?.maxInstances ?? 100; const returnMaskAtOriginalResolution = options?.returnMaskAtOriginalResolution ?? true; @@ -457,7 +460,8 @@ export class InstanceSegmentationModule< maxInstances, classIndices, returnMaskAtOriginalResolution, - methodName + methodName, + useWeightedNms ) : await this.nativeModule.generateFromPixels( input, @@ -466,7 +470,8 @@ export class InstanceSegmentationModule< maxInstances, classIndices, returnMaskAtOriginalResolution, - methodName + methodName, + useWeightedNms ); return nativeResult.map((inst) => ({ diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index 7274209df6..ad76114b66 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -15,6 +15,7 @@ import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError } from '../../errors/errorUtils'; import { buildLabelArray } from '../../utils/labelUtils'; import { + BlazeFaceLabel, CocoLabel, CocoLabelYolo, IMAGENET1K_MEAN, @@ -26,6 +27,8 @@ import { VisionLabeledModule, } from './VisionLabeledModule'; +const BLAZEFACE_NORM: readonly [number, number, number] = [0.5, 0.5, 0.5]; + const YOLO_DETECTION_CONFIG = { labelMap: CocoLabelYolo, preprocessorConfig: undefined, @@ -57,6 +60,16 @@ const ModelConfigs = { 'yolo26m': YOLO_DETECTION_CONFIG, 'yolo26l': YOLO_DETECTION_CONFIG, 'yolo26x': YOLO_DETECTION_CONFIG, + 'blazeface': { + labelMap: BlazeFaceLabel, + preprocessorConfig: { normMean: BLAZEFACE_NORM, normStd: BLAZEFACE_NORM }, + availableInputSizes: undefined, + defaultInputSize: undefined, + defaultDetectionThreshold: 0.5, + defaultIouThreshold: 0.3, + nmsMode: 'weighted', + useLetterbox: true, + }, } as const satisfies Record< ObjectDetectionModelName, ObjectDetectionConfig @@ -173,6 +186,8 @@ export class ObjectDetectionModule< const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.55; const defaultInputSize = this.modelConfig.defaultInputSize; const availableInputSizes = this.modelConfig.availableInputSizes; + const useWeightedNms = this.modelConfig.nmsMode === 'weighted'; + const useLetterbox = this.modelConfig.useLetterbox ?? false; return ( frame: any, @@ -215,7 +230,9 @@ export class ObjectDetectionModule< detectionThreshold, iouThreshold, classIndices, - methodName + methodName, + useWeightedNms, + useLetterbox ); }; } @@ -255,6 +272,8 @@ export class ObjectDetectionModule< const iouThreshold = options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.55; const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize; + const useWeightedNms = this.modelConfig.nmsMode === 'weighted'; + const useLetterbox = this.modelConfig.useLetterbox ?? false; // Validate inputSize against availableInputSizes if ( @@ -290,14 +309,18 @@ export class ObjectDetectionModule< detectionThreshold, iouThreshold, classIndices, - methodName + methodName, + useWeightedNms, + useLetterbox ) : await this.nativeModule.generateFromPixels( input, detectionThreshold, iouThreshold, classIndices, - methodName + methodName, + useWeightedNms, + useLetterbox ); } diff --git a/packages/react-native-executorch/src/types/common.ts b/packages/react-native-executorch/src/types/common.ts index e8afa8ff4d..777b4d5654 100644 --- a/packages/react-native-executorch/src/types/common.ts +++ b/packages/react-native-executorch/src/types/common.ts @@ -145,6 +145,7 @@ export type LabelEnum = Readonly>; * @category Types */ export type Triple = readonly [T, T, T]; + /** * Represents raw pixel data in RGB format for vision models. * diff --git a/packages/react-native-executorch/src/types/instanceSegmentation.ts b/packages/react-native-executorch/src/types/instanceSegmentation.ts index ff7f4ae314..cf57eb662b 100644 --- a/packages/react-native-executorch/src/types/instanceSegmentation.ts +++ b/packages/react-native-executorch/src/types/instanceSegmentation.ts @@ -48,8 +48,9 @@ export interface InstanceSegmentationOptions { */ confidenceThreshold?: number; /** - * IoU threshold for non-maximum suppression. - * Defaults to model's defaultIouThreshold (typically 0.5). + * IoU threshold for non-maximum suppression (0-1). Defaults to the model + * preset's `defaultIouThreshold`. Ignored for models whose preset disables + * external NMS (`postprocessorConfig.applyNMS: false`, e.g. YOLO-seg which is NMS-free) */ iouThreshold?: number; /** @@ -89,9 +90,25 @@ export type InstanceSegmentationConfig = { normMean?: Triple; normStd?: Triple; }; + /** + * `applyNMS: false` for models that produce already-deduplicated detections + * (e.g. YOLO-seg, where NMS runs inside the model graph). Property of the + * model architecture — not user-tuneable per call. + */ postprocessorConfig?: { applyNMS?: boolean }; defaultConfidenceThreshold?: number; + /** + * Default IoU threshold for non-maximum suppression (0-1). Overridable per-call + * via {@link InstanceSegmentationOptions.iouThreshold}. Has no effect when + * `postprocessorConfig.applyNMS` is `false`. + */ defaultIouThreshold?: number; + /** + * NMS algorithm baked into the model preset. Architectural — not per-call tuneable. + * - `'greedy'` (default): standard NMS, suits detectors whose anchors are independently accurate. + * - `'weighted'`: score-weighted box blending, required for ensemble-trained detectors. + */ + nmsMode?: 'greedy' | 'weighted'; } & ( | { availableInputSizes: readonly number[]; diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index 271676c439..213e7a7b60 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -36,15 +36,15 @@ export interface Detection { * Options for configuring object detection inference. * @category Types * @typeParam L - The label enum type for filtering classes of interest. - * @property {number} [detectionThreshold] - Minimum confidence score for detections (0-1). Defaults to model-specific value. - * @property {number} [iouThreshold] - IoU threshold for non-maximum suppression (0-1). Defaults to model-specific value. - * @property {number} [inputSize] - Input size for multi-method models (e.g., 384, 512, 640 for YOLO). Required for YOLO models if not using default. - * @property {(keyof L)[]} [classesOfInterest] - Optional array of class labels to filter detections. Only detections matching these classes will be returned. */ export interface ObjectDetectionOptions { + /** Minimum confidence score for detections (0-1). Defaults to the model preset's value. */ detectionThreshold?: number; + /** IoU threshold for non-maximum suppression (0-1). Defaults to the model preset's `defaultIouThreshold`. */ iouThreshold?: number; + /** Input size for multi-method models (e.g. 384/512/640 for YOLO). */ inputSize?: number; + /** Restrict output to these class labels. */ classesOfInterest?: (keyof L)[]; } @@ -60,7 +60,8 @@ export type ObjectDetectionModelSources = | { modelName: 'yolo26s'; modelSource: ResourceSource } | { modelName: 'yolo26m'; modelSource: ResourceSource } | { modelName: 'yolo26l'; modelSource: ResourceSource } - | { modelName: 'yolo26x'; modelSource: ResourceSource }; + | { modelName: 'yolo26x'; modelSource: ResourceSource } + | { modelName: 'blazeface'; modelSource: ResourceSource }; /** * Union of all built-in object detection model names. @@ -72,18 +73,28 @@ export type ObjectDetectionModelName = ObjectDetectionModelSources['modelName']; * Configuration for a custom object detection model. * @category Types * @typeParam T - The label enum type for the model. - * @property {T} labelMap - The label mapping for the model. - * @property {object} [preprocessorConfig] - Optional preprocessing configuration with normalization parameters. - * @property {number} [defaultDetectionThreshold] - Default detection confidence threshold (0-1). - * @property {number} [defaultIouThreshold] - Default IoU threshold for non-maximum suppression (0-1). - * @property {readonly number[]} [availableInputSizes] - For multi-method models, the available input sizes (e.g., [384, 512, 640]). - * @property {number} [defaultInputSize] - For multi-method models, the default input size to use. */ export type ObjectDetectionConfig = { + /** The label mapping for the model. */ labelMap: T; + /** Optional input normalisation: `(pixel - normMean) / normStd`. */ preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; + /** Default detection confidence threshold (0-1). */ defaultDetectionThreshold?: number; + /** Default IoU threshold for non-maximum suppression (0-1). Overridable per-call via {@link ObjectDetectionOptions.iouThreshold}. */ defaultIouThreshold?: number; + /** + * NMS algorithm baked into the model preset. Architectural — not per-call tuneable. + * - `'greedy'` (default): standard NMS, suits detectors whose anchors are independently accurate (YOLO, SSDLite, RF-DETR). + * - `'weighted'`: score-weighted box blending, required for ensemble-trained detectors like BlazeFace. + */ + nmsMode?: 'greedy' | 'weighted'; + /** + * Whether the model expects aspect-preserving fit + center-pad (letterbox) preprocessing + * instead of plain stretch resize. Architectural property of the model — not per-call tuneable. + * BlazeFace requires letterbox; YOLO/SSDLite/RF-DETR do not. + */ + useLetterbox?: boolean; } & ( | { availableInputSizes: readonly number[];