diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ebbb412e55f..9b11565a17b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2462,18 +2462,71 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + l_align = l_wg_denoms[0]; + m_align = m_wg_denoms[0]; + s_align = s_wg_denoms[0]; + + if (device->architecture == AMD_GCN) { + const bool disable_wave64_tiles = getenv("GGML_VK_DISABLE_GCN_WAVE64") != nullptr; + if (!disable_wave64_tiles) { + const uint32_t wave_subgroup = 16; + const std::vector gcn_l_warptile = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_l, tn_l, tk_l, wave_subgroup }; + const std::vector gcn_m_warptile = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_m, tn_m, tk_m, wave_subgroup }; + const std::vector gcn_s_warptile = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_s, tn_s, tk_s, wave_subgroup }; + + const std::vector gcn_l_warptile_id = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_l, tn_l, tk_l, wave_subgroup }; + const std::vector gcn_m_warptile_id = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_m, tn_m, tk_m, wave_subgroup }; + const std::vector gcn_s_warptile_id = { 256, 64, 64, 16, wave_subgroup, 16, 2, tm_s, tn_s, tk_s, wave_subgroup }; + + const std::array gcn_l_wg_denoms = { 128, 64, 1 }; + const std::array gcn_m_wg_denoms = { 64, 64, 1 }; + const std::array gcn_s_wg_denoms = { 32, 64, 1 }; + + const auto tiles_fit_lds = [&](const std::vector &tile, bool is_id) { + return ggml_vk_matmul_shmem_support(device, tile, is_id, GGML_TYPE_F32); + }; + + if (tiles_fit_lds(gcn_s_warptile, false) && + tiles_fit_lds(gcn_m_warptile, false) && + tiles_fit_lds(gcn_l_warptile, false) && + tiles_fit_lds(gcn_s_warptile_id, true) && + tiles_fit_lds(gcn_m_warptile_id, true) && + tiles_fit_lds(gcn_l_warptile_id, true)) { + l_warptile = gcn_l_warptile; + m_warptile = gcn_m_warptile; + s_warptile = gcn_s_warptile; + + l_warptile_id = gcn_l_warptile_id; + m_warptile_id = gcn_m_warptile_id; + s_warptile_id = gcn_s_warptile_id; + + l_wg_denoms = gcn_l_wg_denoms; + m_wg_denoms = gcn_m_wg_denoms; + s_wg_denoms = gcn_s_wg_denoms; + + l_align = gcn_l_wg_denoms[0]; + m_align = gcn_m_wg_denoms[0]; + s_align = gcn_s_wg_denoms[0]; + + VK_LOG_DEBUG("ggml_vulkan: using AMD GCN wave64 matmul tiles"); + } else { + VK_LOG_DEBUG("ggml_vulkan: AMD GCN wave64 matmul tiles exceed LDS, falling back to defaults"); + } + } + } + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; } - l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; - m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; - s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; - l_align = 128; - m_align = 64; - s_align = 32; + l_mmq_wg_denoms = l_wg_denoms; + m_mmq_wg_denoms = m_wg_denoms; + s_mmq_wg_denoms = s_wg_denoms; for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { ggml_type t = (ggml_type)i;