Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ option(MATX_BUILD_DOCS "Build documentation" OFF)
option(MATX_BUILD_32_BIT "Build with 32-bit indexing support" OFF)
option(MATX_MULTI_GPU "Multi-GPU support" OFF)
option(MATX_EN_VISUALIZATION "Enable visualization support" OFF)
#option(MATX_EN_CUTLASS OFF)
# Optional third-party integrations. option() takes (<variable> "<help_text>" [value]),
# so the help string must be given explicitly; otherwise "OFF" is consumed as the
# help text rather than as the default value. All of these default to OFF.
option(MATX_EN_CUTLASS "Enable CUTLASS support" OFF)
option(MATX_EN_CUTENSOR "Enable cuTENSOR support" OFF)
option(MATX_EN_CUDSS "Enable cuDSS support" OFF)
option(MATX_EN_FILEIO "Enable file I/O support" OFF)
Expand Down Expand Up @@ -181,14 +181,14 @@ endif()
set(WARN_FLAGS ${WARN_FLAGS} $<$<COMPILE_LANGUAGE:CXX>:-Werror>)

# CUTLASS slows down compile times when used, so leave it as optional for now
# if (MATX_EN_CUTLASS)
# include(cmake/GetCUTLASS.cmake)
# set (CUTLASS_INC ${cutlass_SOURCE_DIR}/include/ ${cutlass_SOURCE_DIR}/tools/util/include/)
# target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=1)
# else()
# set (CUTLASS_INC "")
# target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)
# endif()
# Optional CUTLASS integration, controlled by the MATX_EN_CUTLASS option.
if (MATX_EN_CUTLASS)
# NOTE(review): cutlass_SOURCE_DIR is presumably populated by GetCUTLASS.cmake -- confirm.
include(cmake/GetCUTLASS.cmake)
set (CUTLASS_INC ${cutlass_SOURCE_DIR}/include/ ${cutlass_SOURCE_DIR}/tools/util/include/)
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=1)
else()
set (CUTLASS_INC "")
# MATX_ENABLE_CUTLASS is intentionally left undefined in this branch: the C++
# sources guard CUTLASS-dependent code with `#ifdef MATX_ENABLE_CUTLASS`, which
# would also be satisfied by a definition of MATX_ENABLE_CUTLASS=0.
endif()

# CUTLASS support is not maintained. Remove the option to avoid confusion

Expand Down
5 changes: 3 additions & 2 deletions cmake/GetCUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
# /////////////////////////////////////////////////////////////////////////////////
function(find_and_configure_cutlass VERSION)
CPMFindPackage(NAME cutlass
VERSION ${VERSION}
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG a01feb9
GIT_TAG v${VERSION}
GIT_SHALLOW TRUE
DOWNLOAD_ONLY
OPTIONS "CUTLASS_ENABLE_TESTS OFF"
Expand All @@ -44,5 +45,5 @@ function(find_and_configure_cutlass VERSION)
endif()
endfunction()

set(CUDA_MATX_MIN_VERSION_cutlass "21.08.02")
set(CUDA_MATX_MIN_VERSION_cutlass "3.9.2")
find_and_configure_cutlass(${CUDA_MATX_MIN_VERSION_cutlass})
1 change: 1 addition & 0 deletions include/matx.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "matx/core/nvtx.h"
#include "matx/core/print.h"
#include "matx/core/pybind.h"
#include "matx/core/quaternion.h"
#include "matx/core/tensor.h"
#include "matx/core/sparse_tensor.h" // sparse support is experimental
#include "matx/core/make_sparse_tensor.h"
Expand Down
18 changes: 18 additions & 0 deletions include/matx/core/print.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ namespace matx {
fprintf(fp, fmt_s.c_str(), static_cast<float>(val.real()),
static_cast<float>(val.imag()));
}
else if constexpr (is_quaternion_v<T>) {
const auto prec = std::to_string(PRINT_PRECISION);
const auto fmt_s = ("% ."s + prec + "e%+." + prec + "ei%+." + prec + "ej%+." + prec + "ek ");
fprintf(fp, fmt_s.c_str(), static_cast<float>(val.w()),
static_cast<float>(val.x()),
static_cast<float>(val.y()),
static_cast<float>(val.z()));
}
else if constexpr (is_matx_half_v<T> || is_half_v<T>) {
const auto prec = std::to_string(PRINT_PRECISION);
const auto fmt_s = ("% ."s + prec + "e ");
Expand Down Expand Up @@ -130,6 +138,16 @@ namespace matx {
return "complex<float16>";
if constexpr (std::is_same_v<T, matxBf16Complex>)
return "complex<bfloat16>";
#ifdef MATX_ENABLE_CUTLASS
if constexpr (std::is_same_v<T, matx::quaternion<float>>)
return "quaternion<float>";
if constexpr (std::is_same_v<T, matx::quaternion<double>>)
return "quaternion<double>";
if constexpr (std::is_same_v<T, matx::quaternion<matxFp16>>)
return "quaternion<float16>";
if constexpr (std::is_same_v<T, matx::quaternion<matxBf16>>)
return "quaternion<bfloat16>";
#endif

return "unknown";
}
Expand Down
53 changes: 53 additions & 0 deletions include/matx/core/quaternion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2021, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

// #include <cuda/std/cmath>
// #include <type_traits>
// #include "matx/core/defines.h"

// Quaternion support is backed by CUTLASS, so the CUTLASS header is only pulled
// in when the build defines MATX_ENABLE_CUTLASS.
#ifdef MATX_ENABLE_CUTLASS
#include "cutlass/quaternion.h"
#endif

namespace matx {

#ifdef MATX_ENABLE_CUTLASS
/**
 * @brief Quaternion type, provided as an alias of cutlass::Quaternion
 *
 * Only declared when MATX_ENABLE_CUTLASS is defined.
 *
 * @tparam T Type forwarded as the cutlass::Quaternion template argument
 */
template <typename T>
using quaternion = cutlass::Quaternion<T>;
#endif

} // namespace matx
18 changes: 16 additions & 2 deletions include/matx/core/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "matx/executors/host.h"
#include "matx/core/half.h"
#include "matx/core/half_complex.h"
#include "matx/core/quaternion.h"

/**
* Defines type traits for host and device compilers. This file should be includable by
Expand Down Expand Up @@ -469,8 +470,20 @@ template <> struct scalar_to_complex<matxFp16> {
template <> struct scalar_to_complex<matxBf16> {
using ctype = matxBf16Complex;
};
}

/**
 * @brief Type trait backing is_quaternion_v
 *
 * The primary template is false for every type. The true specialization for
 * matx::quaternion<T> is only compiled when MATX_ENABLE_CUTLASS is defined, so
 * without it the trait is false for all types.
 *
 * @tparam T Type to test
 */
template <typename T> struct is_quaternion : std::false_type {
};
#ifdef MATX_ENABLE_CUTLASS
template <typename T> struct is_quaternion<matx::quaternion<T>> : std::true_type {
};
#endif
}
/**
* @brief Determine if a type is a quaternion type (any type supported)
*
* @tparam T Type to test
*/
template <class T> inline constexpr bool is_quaternion_v = detail::is_quaternion<remove_cvref_t<T>>::value;

/**
* @brief Get the inner value_type of the container
Expand Down Expand Up @@ -743,7 +756,8 @@ struct is_matx_type
bool, std::is_same_v<matxFp16, std::remove_cv_t<T>> ||
std::is_same_v<matxBf16, std::remove_cv_t<T>> ||
std::is_same_v<matxFp16Complex, std::remove_cv_t<T>> ||
std::is_same_v<matxBf16Complex, std::remove_cv_t<T>>> {
std::is_same_v<matxBf16Complex, std::remove_cv_t<T>> ||
is_quaternion_v<remove_cvref_t<T>>> {
};
}

Expand Down
10 changes: 5 additions & 5 deletions include/matx/transforms/matmul/matmul_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -949,11 +949,11 @@ class MatMulCUDAHandle_t {
if constexpr (RANK > MATMUL_BATCH_RANK_THRESHOLD) {
if constexpr (PROV == PROVIDER_TYPE_CUTLASS) {
#ifdef MATX_ENABLE_CUTLASS
for (size_t iter = 0; iter < total_iter; iter++) {
for (size_t iter = 0; iter < total_iter; iter++) {
// Get pointers into A/B/C for this round
auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, idx);
auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, idx);
auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, idx);
auto ap = cuda::std::apply([&a_adj](auto... param) { return a_adj.GetPointer(param...); }, a_idx);
auto bp = cuda::std::apply([&b_adj](auto... param) { return b_adj.GetPointer(param...); }, b_idx);
auto cp = cuda::std::apply([&c_adj](auto... param) { return c_adj.GetPointer(param...); }, c_idx);

typename CutlassGemm::Arguments args(
{static_cast<int>(params_.m), static_cast<int>(params_.n),
Expand Down Expand Up @@ -984,7 +984,7 @@ class MatMulCUDAHandle_t {
MATX_ASSERT(status == cutlass::Status::kSuccess, matxMatMulError);

// Update all but the last 2 indices
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, idx, 3);
UpdateIndices<TensorTypeA, shape_type, TensorTypeA::Rank()>(a_adj, a_idx, 3);
}
#else
MATX_THROW(matxNotSupported, "CUTLASS not enabled!");
Expand Down
26 changes: 25 additions & 1 deletion test/00_operators/operator_func_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -446,4 +446,28 @@ TYPED_TEST(OperatorTestsComplexTypesAllExecs, OperatorFuncs)
EXPECT_TRUE(MatXUtils::MatXTypeCompare(tdd0(), detail::scalar_internal_abs(c)));

MATX_EXIT_HANDLER();
}
}


// Checks exp() on a single quaternion value for every (type, executor) pair in
// the OperatorTestsQuaternionTypesAllExecs suite.
TYPED_TEST(OperatorTestsQuaternionTypesAllExecs, OperatorFuncs)
{
MATX_ENTER_HANDLER();
// TypeParam is a tuple: element 0 is the tensor value type, element 1 the executor type
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

// Input and output tensors created with an empty shape list (presumably rank-0 -- confirm)
auto tiv0 = make_tensor<TestType>({});
auto tov0 = make_tensor<TestType>({});

// Fixed sample value from the GenerateData specialization for TestType
TestType c = GenerateData<TestType>();
tiv0() = c;

// NOTE(review): the example tag `exp-test-1` looks copied from another exp()
// test; if the same tag exists elsewhere, doc snippet extraction may be
// ambiguous -- confirm.
// example-begin exp-test-1
(tov0 = exp(tiv0)).run(exec);
// example-end exp-test-1
// Sync the executor before reading the result back on the host
exec.sync();
// Compare the operator result against the scalar reference applied to the same input
EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::_internal_exp(c)));

MATX_EXIT_HANDLER();
}
4 changes: 4 additions & 0 deletions test/00_operators/operator_test_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class OperatorTestsComplexNonHalfTypesAllExecs : public ::testing::Test {};
template <typename TensorType>
class OperatorTestsComplexTypesAllExecs : public ::testing::Test {};

// Empty typed-test fixture for quaternion operator tests; it is registered with
// the MatXQuaternionTypesAllExecs type list via TYPED_TEST_SUITE in this file.
template <typename TensorType>
class OperatorTestsQuaternionTypesAllExecs : public ::testing::Test {};

template <typename TensorType>
class OperatorTestsAllExecs : public ::testing::Test {};

Expand Down Expand Up @@ -92,6 +95,7 @@ TYPED_TEST_SUITE(OperatorTestsNumericAllExecs,
TYPED_TEST_SUITE(OperatorTestsNumericNoHalfAllExecs, MatXNumericNoHalfTypesAllExecs);
TYPED_TEST_SUITE(OperatorTestsComplexNonHalfTypesAllExecs, MatXComplexNonHalfTypesAllExecs);
TYPED_TEST_SUITE(OperatorTestsComplexTypesAllExecs, MatXComplexTypesAllExecs);
TYPED_TEST_SUITE(OperatorTestsQuaternionTypesAllExecs, MatXQuaternionTypesAllExecs);
TYPED_TEST_SUITE(OperatorTestsAllExecs, MatXAllTypesAllExecs);
TYPED_TEST_SUITE(OperatorTestsFloatAllExecs, MatXTypesFloatAllExecs);
TYPED_TEST_SUITE(OperatorTestsIntegralAllExecs, MatXTypesIntegralAllExecs);
Expand Down
28 changes: 27 additions & 1 deletion test/include/test_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ template <> auto inline GenerateData<cuda::std::complex<double>>()
return cuda::std::complex<double>(1.5, -2.5);
}

#ifdef MATX_ENABLE_CUTLASS
// Quaternion sample data for the typed tests. Every specialization builds its
// value from the same four literals (1.5, -2.5, 3.5, -4.5), passed in that order
// to the four-argument quaternion constructor. These are only available when
// matx::quaternion exists, i.e. when MATX_ENABLE_CUTLASS is defined.
template <> inline auto GenerateData<matx::quaternion<matx::matxFp16>>()
{
  using quat_type = matx::quaternion<matx::matxFp16>;
  return quat_type(1.5, -2.5, 3.5, -4.5);
}

template <> inline auto GenerateData<matx::quaternion<matx::matxBf16>>()
{
  using quat_type = matx::quaternion<matx::matxBf16>;
  return quat_type(1.5, -2.5, 3.5, -4.5);
}

template <> inline auto GenerateData<matx::quaternion<float>>()
{
  using quat_type = matx::quaternion<float>;
  return quat_type(1.5, -2.5, 3.5, -4.5);
}

template <> inline auto GenerateData<matx::quaternion<double>>()
{
  using quat_type = matx::quaternion<double>;
  return quat_type(1.5, -2.5, 3.5, -4.5);
}
#endif

using ExecutorTypesAll = cuda::std::tuple<matx::cudaExecutor, matx::SingleThreadedHostExecutor, matx::AllThreadsHostExecutor, matx::SelectThreadsHostExecutor>;
using ExecutorTypesCUDAOnly = cuda::std::tuple<matx::cudaExecutor>;

Expand Down Expand Up @@ -114,7 +133,13 @@ using MatXComplexNonHalfTuple = cuda::std::tuple<cuda::std::compl
using MatXNumericNonComplexTuple = cuda::std::tuple<uint32_t, int32_t, uint64_t, int64_t, float, double>;
using MatXComplexTuple = cuda::std::tuple<cuda::std::complex<float>, cuda::std::complex<double>,
matx::matxFp16Complex, matx::matxBf16Complex>;


// Quaternion value types exercised by the typed tests. matx::quaternion is only
// declared when MATX_ENABLE_CUTLASS is defined, so the list is empty otherwise.
// Only the float and double variants are listed here, although GenerateData is
// also specialized for the matxFp16/matxBf16 quaternions.
#ifdef MATX_ENABLE_CUTLASS
using MatXQuaternionTuple = cuda::std::tuple<matx::quaternion<float>, matx::quaternion<double>>;
#else
using MatXQuaternionTuple = cuda::std::tuple<>;
#endif

using MatXAllTuple = cuda::std::tuple<matx::matxFp16, matx::matxBf16, bool, uint32_t, int32_t, uint64_t,
int64_t, float, double, cuda::std::complex<float>,
cuda::std::complex<double>, matx::matxFp16Complex,
Expand Down Expand Up @@ -161,6 +186,7 @@ using MatXFloatNonComplexNonHalfTypesAllExecs = TupleToTypes<TypedCartesianProdu
using MatXNumericNoHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXNumericNonHalfTuple, ExecutorTypesAll>::type>::type;
using MatXComplexNonHalfTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXComplexNonHalfTuple, ExecutorTypesAll>::type>::type;
using MatXComplexTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXComplexTuple, ExecutorTypesAll>::type>::type;
using MatXQuaternionTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXQuaternionTuple, ExecutorTypesAll>::type>::type;
using MatXAllTypesAllExecs = TupleToTypes<TypedCartesianProduct<MatXAllTuple, ExecutorTypesAll>::type>::type;
using MatXTypesFloatNonComplexAllExecs = TupleToTypes<TypedCartesianProduct<MatXFloatNonComplexTuple, ExecutorTypesAll>::type>::type;
using MatXTypesFloatAllExecs = TupleToTypes<TypedCartesianProduct<MatXFloatTuple, ExecutorTypesAll>::type>::type;
Expand Down
22 changes: 22 additions & 0 deletions test/include/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,28 @@ class MatXUtils {
return ::testing::AssertionFailure();
}
}
else if constexpr (matx::is_quaternion_v<T1>) {
if (fabs(static_cast<float>(a.w()) - static_cast<float>(b.w())) > delta) {
std::cout << "Failed in match: " << static_cast<float>(a.w())
<< " != " << static_cast<float>(b.w()) << "\n";
return ::testing::AssertionFailure();
}
if (fabs(static_cast<float>(a.x()) - static_cast<float>(b.x())) > delta) {
std::cout << "Failed in match: " << static_cast<float>(a.x())
<< " != " << static_cast<float>(b.x()) << "\n";
return ::testing::AssertionFailure();
}
if (fabs(static_cast<float>(a.y()) - static_cast<float>(b.y())) > delta) {
std::cout << "Failed in match: " << static_cast<float>(a.y())
<< " != " << static_cast<float>(b.y()) << "\n";
return ::testing::AssertionFailure();
}
if (fabs(static_cast<float>(a.z()) - static_cast<float>(b.z())) > delta) {
std::cout << "Failed in match: " << static_cast<float>(a.z())
<< " != " << static_cast<float>(b.z()) << "\n";
return ::testing::AssertionFailure();
}
}
else if constexpr (is_matx_half_v<T1> || is_half_v<T1>) {
if (fabs(static_cast<float>(a) - static_cast<float>(b)) > delta) {
std::cout << "Failed in match: " << static_cast<float>(a)
Expand Down