diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index 71168af..688db24 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -36,36 +36,63 @@ __device__ __host__ T ScalarSaturateCast(U v) { constexpr bool bigToSmall = !smallToBig; if constexpr (std::is_integral_v && std::is_floating_point_v) { - // Any float -> any integral - return static_cast(std::clamp(std::round(v), static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()))); - } else if constexpr (std::is_integral_v && std::is_integral_v && std::is_signed_v && - std::is_unsigned_v && smallToBig) { - // Any integral signed -> Any integral unsigned, small -> big or equal - return v <= 0 ? 0 : static_cast(v); - } else if constexpr (std::is_integral_v && std::is_integral_v && - ((std::is_signed_v && std::is_signed_v) || - (std::is_unsigned_v && std::is_unsigned_v)) && - bigToSmall) { - // Any integral signed -> Any integral signed, big -> small - // Any integral unsigned -> Any integral unsigned, big -> small - return v <= std::numeric_limits::min() - ? std::numeric_limits::min() - : (v >= std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(v)); - } else if constexpr (std::is_integral_v && std::is_unsigned_v && std::is_integral_v && - std::is_signed_v) { - // Any integral unsigned -> Any integral signed, small -> big or equal - return v >= std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(v); - } else if constexpr (std::is_integral_v && std::is_signed_v && std::is_integral_v && - std::is_unsigned_v && bigToSmall) { - // Any integral signed -> Any integral unsigned, big -> small - return v <= static_cast(std::numeric_limits::min()) - ? std::numeric_limits::min() - : (v >= static_cast(std::numeric_limits::max()) ? std::numeric_limits::max() - : static_cast(v)); - } else { - // All other cases fall into this - return v; + // Float -> integral: clamp to [min, max] then round. + constexpr U minVal = static_cast(std::numeric_limits::lowest()); + constexpr U maxVal = static_cast(std::numeric_limits::max()); + + if constexpr (sizeof(T) <= 2) { + // 8/16 bit integer cases. These can be represented exactly in floating point. +#ifdef __HIP_DEVICE_COMPILE__ + return static_cast(rintf(fminf(fmaxf(v, minVal), maxVal))); +#else + return static_cast(std::round(std::clamp(v, minVal, maxVal))); +#endif + } else { + // 32/64 bit integer cases. +#ifdef __HIP_DEVICE_COMPILE__ + U rounded = rintf(v); +#else + U rounded = std::round(v); +#endif + + return rounded >= maxVal ? std::numeric_limits::max() + : rounded <= minVal ? std::numeric_limits::min() + : static_cast(rounded); + } + } + + else if constexpr (std::is_integral_v && std::is_integral_v && std::is_signed_v && std::is_unsigned_v && + smallToBig) { + // Signed -> unsigned, small to big: clamp negative to 0 + // Branchless: max(v, 0) handles negative values + return static_cast(max(v, U{0})); + } + + else if constexpr (std::is_integral_v && std::is_integral_v && + ((std::is_signed_v && std::is_signed_v) || + (std::is_unsigned_v && std::is_unsigned_v)) && + bigToSmall) { + // Same signedness, big -> small: clamp to [min, max] + constexpr U minVal = static_cast(std::numeric_limits::min()); + constexpr U maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(max(v, minVal), maxVal)); + } + + else if constexpr (std::is_integral_v && std::is_unsigned_v && std::is_integral_v && std::is_signed_v) { + // Unsigned -> signed: clamp to max (can't exceed min since unsigned) + constexpr U maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(v, maxVal)); + } + + else if constexpr (std::is_integral_v && std::is_signed_v && std::is_integral_v && std::is_unsigned_v && + bigToSmall) { + // Signed -> unsigned, big -> small: clamp to [0, max] + constexpr U maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(max(v, U{0}), maxVal)); + } + + else { + return static_cast(v); } } @@ -117,9 +144,19 @@ __device__ __host__ T ScalarRangeCast(U v) { else if constexpr (std::is_integral_v && std::is_floating_point_v && std::is_unsigned_v) { // float to unsigned integers - return v >= T{1} ? std::numeric_limits::max() - : v <= T{0} ? 0 - : static_cast(lrintf(static_cast(std::numeric_limits::max()) * v)); + constexpr U scale = static_cast(std::numeric_limits::max()); + + if constexpr (sizeof(T) <= 2) { + // 8/16 bit integer cases. These can be represented exactly in floating point. +#ifdef __HIP_DEVICE_COMPILE__ + return static_cast(__float2int_rn(__saturatef(v) * scale)); +#else + return static_cast(lrintf(fminf(fmaxf(v, 0.0f), 1.0f) * scale)); +#endif + } else { + // 32/64 bit integer cases. + return v >= U{1} ? std::numeric_limits::max() : v <= U{0} ? 0 : static_cast(std::round(v * scale)); + } } else if constexpr (std::is_floating_point_v && std::is_integral_v && std::is_signed_v) { diff --git a/include/core/detail/type_traits.hpp b/include/core/detail/type_traits.hpp index dcf77eb..32f14d5 100644 --- a/include/core/detail/type_traits.hpp +++ b/include/core/detail/type_traits.hpp @@ -20,6 +20,7 @@ */ #include + #include #pragma once @@ -83,6 +84,8 @@ DEFINE_TYPE_TRAITS_0_TO_4(int, signed int); DEFINE_TYPE_TRAITS_0_TO_4(short, signed short); DEFINE_TYPE_TRAITS_0_TO_4(ushort, unsigned short); DEFINE_TYPE_TRAITS_0_TO_4(double, double); +DEFINE_TYPE_TRAITS_0_TO_4(long, signed long); +DEFINE_TYPE_TRAITS_0_TO_4(ulong, unsigned long); /** * @brief Returns the number of elements in a HIP vectorized type. For example: uchar3 will return 3, int2 will diff --git a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp index c284dc8..55001a0 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp +++ b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp @@ -39,6 +39,8 @@ int main(int argc, char **argv) { TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), std::numeric_limits::min())); TEST_CASE(EXPECT_EQ(RangeCast(1.0f), std::numeric_limits::max())); TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0f), 0)); + // Test unsigned/signed integer -> float casting TEST_CASE(EXPECT_EQ(RangeCast(std::numeric_limits::max()), 1.0f)); diff --git a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp new file mode 100644 index 0000000..84be5ee --- /dev/null +++ b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +#include "test_helpers.hpp" + +using namespace roccv::detail; +using namespace roccv::tests; +using namespace roccv; + +int main(int argc, char **argv) { + TEST_CASES_BEGIN(); + + TEST_CASE(EXPECT_EQ(SaturateCast(1.0f), 1)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), -1)); + TEST_CASE(EXPECT_EQ(SaturateCast(1.0f), 1)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(1), 1.0f)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1), -1.0f)); + TEST_CASE(EXPECT_EQ(SaturateCast(1), 1.0)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1), -1.0)); + + // Test numeric limits + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::lowest()), 0UL)); + + // Test vectorized types + TEST_CASE(EXPECT_TRUE((SaturateCast(uchar4{255, 128, 0, 255}) == float4{255.0f, 128.0f, 0.0f, 255.0f}))); + TEST_CASE(EXPECT_TRUE( + (SaturateCast(char4{-128, -128, -128, -128}) == float4{-128.0f, -128.0f, -128.0f, -128.0f}))); + + TEST_CASES_END(); +} \ No newline at end of file