From 845b45ae686ddc65cf1fdfad3bd0e3a9d5cf51e8 Mon Sep 17 00:00:00 2001 From: Simon Byrne Date: Thu, 5 Jun 2025 20:49:39 +0000 Subject: [PATCH] add basic quaternion support from CUTLASS --- CMakeLists.txt | 18 +++---- cmake/GetCUTLASS.cmake | 5 +- include/matx.h | 1 + include/matx/core/print.h | 18 +++++++ include/matx/core/quaternion.h | 53 ++++++++++++++++++++ include/matx/core/type_utils.h | 18 ++++++- include/matx/transforms/matmul/matmul_cuda.h | 10 ++-- test/00_operators/operator_func_test.cu | 26 +++++++++- test/00_operators/operator_test_types.hpp | 4 ++ test/include/test_types.h | 28 ++++++++++- test/include/utilities.h | 22 ++++++++ 11 files changed, 183 insertions(+), 20 deletions(-) create mode 100644 include/matx/core/quaternion.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f07e0c10..e38543be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,7 +66,7 @@ option(MATX_BUILD_DOCS "Build documentation" OFF) option(MATX_BUILD_32_BIT "Build with 32-bit indexing support" OFF) option(MATX_MULTI_GPU "Multi-GPU support" OFF) option(MATX_EN_VISUALIZATION "Enable visualization support" OFF) -#option(MATX_EN_CUTLASS OFF) +option(MATX_EN_CUTLASS OFF) option(MATX_EN_CUTENSOR OFF) option(MATX_EN_CUDSS OFF) option(MATX_EN_FILEIO OFF) @@ -181,14 +181,14 @@ endif() set(WARN_FLAGS ${WARN_FLAGS} $<$:-Werror>) # CUTLASS slows down compile times when used, so leave it as optional for now -# if (MATX_EN_CUTLASS) -# include(cmake/GetCUTLASS.cmake) -# set (CUTLASS_INC ${cutlass_SOURCE_DIR}/include/ ${cutlass_SOURCE_DIR}/tools/util/include/) -# target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=1) -# else() -# set (CUTLASS_INC "") -# target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0) -# endif() +if (MATX_EN_CUTLASS) + include(cmake/GetCUTLASS.cmake) + set (CUTLASS_INC ${cutlass_SOURCE_DIR}/include/ ${cutlass_SOURCE_DIR}/tools/util/include/) + target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=1) +else() + set (CUTLASS_INC "") + # target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0) +endif() # CUTLASS support is not maintained. Remove the option to avoid confusion diff --git a/cmake/GetCUTLASS.cmake b/cmake/GetCUTLASS.cmake index 198765e4..89a6a290 100644 --- a/cmake/GetCUTLASS.cmake +++ b/cmake/GetCUTLASS.cmake @@ -31,8 +31,9 @@ # ///////////////////////////////////////////////////////////////////////////////// function(find_and_configure_cutlass VERSION) CPMFindPackage(NAME cutlass + VERSION ${VERSION} GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git - GIT_TAG a01feb9 + GIT_TAG v${VERSION} GIT_SHALLOW TRUE DOWNLOAD_ONLY OPTIONS "CUTLASS_ENABLE_TESTS OFF" @@ -44,5 +45,5 @@ function(find_and_configure_cutlass VERSION) endif() endfunction() -set(CUDA_MATX_MIN_VERSION_cutlass "21.08.02") +set(CUDA_MATX_MIN_VERSION_cutlass "3.9.2") find_and_configure_cutlass(${CUDA_MATX_MIN_VERSION_cutlass}) diff --git a/include/matx.h b/include/matx.h index 6d42a3aa..27d48b0f 100644 --- a/include/matx.h +++ b/include/matx.h @@ -48,6 +48,7 @@ #include "matx/core/nvtx.h" #include "matx/core/print.h" #include "matx/core/pybind.h" +#include "matx/core/quaternion.h" #include "matx/core/tensor.h" #include "matx/core/sparse_tensor.h" // sparse support is experimental #include "matx/core/make_sparse_tensor.h" diff --git a/include/matx/core/print.h b/include/matx/core/print.h index 29b46361..152f8683 100644 --- a/include/matx/core/print.h +++ b/include/matx/core/print.h @@ -54,6 +54,14 @@ namespace matx { fprintf(fp, fmt_s.c_str(), static_cast(val.real()), static_cast(val.imag())); } + else if constexpr (is_quaternion_v) { + const auto prec = std::to_string(PRINT_PRECISION); + const auto fmt_s = ("% ."s + prec + "e%+." + prec + "ei%+." + prec + "ej%+." + prec + "ek "); + fprintf(fp, fmt_s.c_str(), static_cast(val.w()), + static_cast(val.x()), + static_cast(val.y()), + static_cast(val.z())); + } else if constexpr (is_matx_half_v || is_half_v) { const auto prec = std::to_string(PRINT_PRECISION); const auto fmt_s = ("% ."s + prec + "e "); @@ -130,6 +138,16 @@ namespace matx { return "complex"; if constexpr (std::is_same_v) return "complex"; +#ifdef MATX_ENABLE_CUTLASS + if constexpr (std::is_same_v>) + return "quaternion"; + if constexpr (std::is_same_v>) + return "quaternion"; + if constexpr (std::is_same_v>) + return "quaternion"; + if constexpr (std::is_same_v>) + return "quaternion"; +#endif return "unknown"; } diff --git a/include/matx/core/quaternion.h b/include/matx/core/quaternion.h new file mode 100644 index 00000000..28d7007d --- /dev/null +++ b/include/matx/core/quaternion.h @@ -0,0 +1,53 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2021, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +// #include +// #include +// #include "matx/core/defines.h" + +#ifdef MATX_ENABLE_CUTLASS +#include "cutlass/quaternion.h" +#endif + +namespace matx { + +#ifdef MATX_ENABLE_CUTLASS +template +using quaternion = cutlass::Quaternion; + + + +#endif + +}; // namespace matx diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h index 3a8cbb33..1389b191 100644 --- a/include/matx/core/type_utils.h +++ b/include/matx/core/type_utils.h @@ -45,6 +45,7 @@ #include "matx/executors/host.h" #include "matx/core/half.h" #include "matx/core/half_complex.h" +#include "matx/core/quaternion.h" /** * Defines type traits for host and device compilers. This file should be includable by @@ -469,8 +470,20 @@ template <> struct scalar_to_complex { template <> struct scalar_to_complex { using ctype = matxBf16Complex; }; -} +template struct is_quaternion : std::false_type { +}; +#ifdef MATX_ENABLE_CUTLASS +template struct is_quaternion> : std::true_type { +}; +#endif +} +/** + * @brief Determine if a type is a quaternion type (any type supported) + * + * @tparam T Type to test + */ +template inline constexpr bool is_quaternion_v = detail::is_quaternion>::value; /** * @brief Get the inner value_type of the container @@ -743,7 +756,8 @@ struct is_matx_type bool, std::is_same_v> || std::is_same_v> || std::is_same_v> || - std::is_same_v>> { + std::is_same_v> || + is_quaternion_v>> { }; } diff --git a/include/matx/transforms/matmul/matmul_cuda.h b/include/matx/transforms/matmul/matmul_cuda.h index 269ce8e3..e2003423 100644 --- a/include/matx/transforms/matmul/matmul_cuda.h +++ b/include/matx/transforms/matmul/matmul_cuda.h @@ -949,11 +949,11 @@ class MatMulCUDAHandle_t { if constexpr (RANK > MATMUL_BATCH_RANK_THRESHOLD) { if constexpr (PROV == PROVIDER_TYPE_CUTLASS) { #ifdef MATX_ENABLE_CUTLASS - for (size_t iter = 0; iter < total_iter; iter++) { + for (size_t iter = 0; iter < total_iter; iter++) { // Get pointers into A/B/C for this round - auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx); - auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx); - auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx); + auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx); + auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx); + auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx); typename CutlassGemm::Arguments args( {static_cast(params_.m), static_cast(params_.n), @@ -984,7 +984,7 @@ class MatMulCUDAHandle_t { MATX_ASSERT(status == cutlass::Status::kSuccess, matxMatMulError); // Update all but the last 2 indices - UpdateIndices(a_adj, idx, 3); + UpdateIndices(a_adj, a_idx, 3); } #else MATX_THROW(matxNotSupported, "CUTLASS not enabled!"); diff --git a/test/00_operators/operator_func_test.cu b/test/00_operators/operator_func_test.cu index 28b9547c..8817170d 100644 --- a/test/00_operators/operator_func_test.cu +++ b/test/00_operators/operator_func_test.cu @@ -446,4 +446,28 @@ TYPED_TEST(OperatorTestsComplexTypesAllExecs, OperatorFuncs) EXPECT_TRUE(MatXUtils::MatXTypeCompare(tdd0(), detail::scalar_internal_abs(c))); MATX_EXIT_HANDLER(); -} \ No newline at end of file +} + + +TYPED_TEST(OperatorTestsQuaternionTypesAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + + // example-begin exp-test-1 + (tov0 = exp(tiv0)).run(exec); + // example-end exp-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::_internal_exp(c))); + + MATX_EXIT_HANDLER(); +} diff --git a/test/00_operators/operator_test_types.hpp b/test/00_operators/operator_test_types.hpp index 90de3a40..471cbc46 100644 --- a/test/00_operators/operator_test_types.hpp +++ b/test/00_operators/operator_test_types.hpp @@ -58,6 +58,9 @@ class OperatorTestsComplexNonHalfTypesAllExecs : public ::testing::Test {}; template class OperatorTestsComplexTypesAllExecs : public ::testing::Test {}; +template +class OperatorTestsQuaternionTypesAllExecs : public ::testing::Test {}; + template class OperatorTestsAllExecs : public ::testing::Test {}; @@ -92,6 +95,7 @@ TYPED_TEST_SUITE(OperatorTestsNumericAllExecs, TYPED_TEST_SUITE(OperatorTestsNumericNoHalfAllExecs, MatXNumericNoHalfTypesAllExecs); TYPED_TEST_SUITE(OperatorTestsComplexNonHalfTypesAllExecs, MatXComplexNonHalfTypesAllExecs); TYPED_TEST_SUITE(OperatorTestsComplexTypesAllExecs, MatXComplexTypesAllExecs); +TYPED_TEST_SUITE(OperatorTestsQuaternionTypesAllExecs, MatXQuaternionTypesAllExecs); TYPED_TEST_SUITE(OperatorTestsAllExecs, MatXAllTypesAllExecs); TYPED_TEST_SUITE(OperatorTestsFloatAllExecs, MatXTypesFloatAllExecs); TYPED_TEST_SUITE(OperatorTestsIntegralAllExecs, MatXTypesIntegralAllExecs); diff --git a/test/include/test_types.h b/test/include/test_types.h index b032ece0..bb906cbf 100644 --- a/test/include/test_types.h +++ b/test/include/test_types.h @@ -72,6 +72,25 @@ template <> auto inline GenerateData>() return cuda::std::complex(1.5, -2.5); } +#ifdef MATX_ENABLE_CUTLASS +template <> auto inline GenerateData>() +{ + return matx::quaternion(1.5, -2.5, 3.5, -4.5); +} +template <> auto inline GenerateData>() +{ + return matx::quaternion(1.5, -2.5, 3.5, -4.5); +} +template <> auto inline GenerateData>() +{ + return matx::quaternion(1.5, -2.5, 3.5, -4.5); +} +template <> auto inline GenerateData>() +{ + return matx::quaternion(1.5, -2.5, 3.5, -4.5); +} +#endif + using ExecutorTypesAll = cuda::std::tuple; using ExecutorTypesCUDAOnly = cuda::std::tuple; @@ -114,7 +133,13 @@ using MatXComplexNonHalfTuple = cuda::std::tuple; using MatXComplexTuple = cuda::std::tuple, cuda::std::complex, matx::matxFp16Complex, matx::matxBf16Complex>; - + +#ifdef MATX_ENABLE_CUTLASS +using MatXQuaternionTuple = cuda::std::tuple, matx::quaternion>; +#else +using MatXQuaternionTuple = cuda::std::tuple<>; +#endif + using MatXAllTuple = cuda::std::tuple, cuda::std::complex, matx::matxFp16Complex, @@ -161,6 +186,7 @@ using MatXFloatNonComplexNonHalfTypesAllExecs = TupleToTypes::type>::type; using MatXComplexNonHalfTypesAllExecs = TupleToTypes::type>::type; using MatXComplexTypesAllExecs = TupleToTypes::type>::type; +using MatXQuaternionTypesAllExecs = TupleToTypes::type>::type; using MatXAllTypesAllExecs = TupleToTypes::type>::type; using MatXTypesFloatNonComplexAllExecs = TupleToTypes::type>::type; using MatXTypesFloatAllExecs = TupleToTypes::type>::type; diff --git a/test/include/utilities.h b/test/include/utilities.h index 8c309b0b..38589e78 100644 --- a/test/include/utilities.h +++ b/test/include/utilities.h @@ -82,6 +82,28 @@ class MatXUtils { return ::testing::AssertionFailure(); } } + else if constexpr (matx::is_quaternion_v) { + if (fabs(static_cast(a.w()) - static_cast(b.w())) > delta) { + std::cout << "Failed in match: " << static_cast(a.w()) + << " != " << static_cast(b.w()) << "\n"; + return ::testing::AssertionFailure(); + } + if (fabs(static_cast(a.x()) - static_cast(b.x())) > delta) { + std::cout << "Failed in match: " << static_cast(a.x()) + << " != " << static_cast(b.x()) << "\n"; + return ::testing::AssertionFailure(); + } + if (fabs(static_cast(a.y()) - static_cast(b.y())) > delta) { + std::cout << "Failed in match: " << static_cast(a.y()) + << " != " << static_cast(b.y()) << "\n"; + return ::testing::AssertionFailure(); + } + if (fabs(static_cast(a.z()) - static_cast(b.z())) > delta) { + std::cout << "Failed in match: " << static_cast(a.z()) + << " != " << static_cast(b.z()) << "\n"; + return ::testing::AssertionFailure(); + } + } else if constexpr (is_matx_half_v || is_half_v) { if (fabs(static_cast(a) - static_cast(b)) > delta) { std::cout << "Failed in match: " << static_cast(a)