diff --git a/CHANGELOG.md b/CHANGELOG.md index 066dc9aa3b..f881f46f6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM. diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index ebe17aadd6..016f7be60d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3 static constexpr bool check_data_type() { using Base = TestCkTileGemmPipeline>; - if constexpr(std::is_same_v && - std::is_same_v) - { - return false; - } - else if constexpr(std::is_same_v && - std::is_same_v) + if constexpr(std::is_same_v && + std::is_same_v) { return false; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 334e360eb5..4bef581254 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>