Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,24 @@ enum shader_reduction_mode {
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);

static bool ggml_vk_driver_has_unstable_subgroup_arithmetic(vk_device_architecture architecture, vk::DriverId driver_id) {
if (architecture != vk_device_architecture::AMD_GCN) {
return false;
}

switch (driver_id) {
case vk::DriverId::eMesaRadv:
case vk::DriverId::eAmdOpenSource:
return true;
default:
return false;
}
}

static bool ggml_vk_supports_subgroup_reduction(vk_device_architecture architecture, vk::DriverId driver_id, bool subgroup_arithmetic) {
return subgroup_arithmetic && !ggml_vk_driver_has_unstable_subgroup_arithmetic(architecture, driver_id);
}

struct vk_device_struct {
std::recursive_mutex mutex;

Expand Down Expand Up @@ -3114,7 +3132,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
rm_stdq = 2;
uint32_t rm_iq = 2 * rm_kq;

const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
const bool use_subgroups = ggml_vk_supports_subgroup_reduction(device->architecture, device->driver_id, device->subgroup_arithmetic);
// Ensure a subgroup size >= 16 is available
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;

Expand Down Expand Up @@ -9699,6 +9717,17 @@ static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx
}

#ifdef GGML_VULKAN_RUN_TESTS
struct ggml_vk_subgroup_support_tests {
ggml_vk_subgroup_support_tests() {
GGML_ASSERT(!ggml_vk_supports_subgroup_reduction(vk_device_architecture::AMD_GCN, vk::DriverId::eMesaRadv, true));
GGML_ASSERT(!ggml_vk_supports_subgroup_reduction(vk_device_architecture::AMD_GCN, vk::DriverId::eAmdOpenSource, true));
GGML_ASSERT(ggml_vk_supports_subgroup_reduction(vk_device_architecture::AMD_GCN, vk::DriverId::eAmdProprietary, true));
GGML_ASSERT(!ggml_vk_supports_subgroup_reduction(vk_device_architecture::AMD_GCN, vk::DriverId::eAmdProprietary, false));
}
};

static ggml_vk_subgroup_support_tests ggml_vk_subgroup_support_tests_instance;

static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
return;
Expand Down