Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3125,8 +3125,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;

for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);

if (device->architecture == AMD_GCN) {
const uint32_t max_subgroup_threads = std::min(128u, subgroup_size * 2);
const uint32_t max_subgroup16_threads = std::min(128u, subgroup_size16 * 2);

wg_size_subgroup = std::min(wg_size_subgroup, max_subgroup_threads);
wg_size_subgroup16 = std::min(wg_size_subgroup16, max_subgroup16_threads);
}

const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
(use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
Expand Down