diff --git a/src/model/diffusion/anima.hpp b/src/model/diffusion/anima.hpp index 6042516a9..504904d41 100644 --- a/src/model/diffusion/anima.hpp +++ b/src/model/diffusion/anima.hpp @@ -227,6 +227,7 @@ namespace Anima { k4 = k_norm->forward(ctx, k4); ggml_tensor* attn_out = nullptr; + float scale = (sd_backend_is(ctx->backend, "Vulkan") && ctx->flash_attn_enabled) ? 1.0f / 32.0f : 1.0f; if (pe_q != nullptr || pe_k != nullptr) { if (pe_q == nullptr) { pe_q = pe_k; @@ -244,7 +245,8 @@ namespace Anima { num_heads, nullptr, true, - ctx->flash_attn_enabled); + ctx->flash_attn_enabled, + scale); } else { auto q_flat = ggml_reshape_3d(ctx->ggml_ctx, q4, head_dim * num_heads, L_q, N); auto k_flat = ggml_reshape_3d(ctx->ggml_ctx, k4, head_dim * num_heads, L_k, N); @@ -256,7 +258,8 @@ namespace Anima { num_heads, nullptr, false, - ctx->flash_attn_enabled); + ctx->flash_attn_enabled, + scale); } return out_proj->forward(ctx, attn_out);