From b6ad3d7c468545c496b41f3ed450c73a88554bb5 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 9 Dec 2025 23:32:34 +0530 Subject: [PATCH 1/3] Fix GEMM dispatch for complex-real matmul --- src/matmul.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/matmul.jl b/src/matmul.jl index d3eabfda..ffee48d0 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -500,8 +500,18 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β) end # THE one big BLAS dispatch. This is split into two methods to improve latency -Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, +function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} + return _generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, val) +end +# This method is only useful for GEMM, as syrk and herk require A and B to have the same eltype +# In GEMM, however, we may reinetrpret a complex A as a real array before carrying out the matmul +function generic_matmatmul_wrapper!(C::StridedMatrix{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasReal} + return _generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, val) +end +Base.@constprop :aggressive function _generic_matmatmul_wrapper!(C::StridedMatrix{<:Number}, tA, tB, A::StridedVecOrMat{<:Number}, B::StridedVecOrMat{<:Number}, + α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) From c8bb59698f8c87aea8ccd6650021fdaccc424f9e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 10 Dec 2025 16:49:55 +0530 Subject: [PATCH 2/3] Update outdated flag in `generic_matmatmul_wrapper!` --- src/matmul.jl | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index ffee48d0..c4737c85 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -317,7 +317,7 @@ end BlasFlag.SYRK elseif (tA_uc == 'C' && tB_uc == 'N') || (tA_uc == 'N' && tB_uc == 'C') BlasFlag.HERK - else isntc + else BlasFlag.GEMM end else @@ -500,18 +500,8 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β) end # THE one big BLAS dispatch. This is split into two methods to improve latency -function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, +Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} - return _generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, val) -end -# This method is only useful for GEMM, as syrk and herk require A and B to have the same eltype -# In GEMM, however, we may reinetrpret a complex A as a real array before carrying out the matmul -function generic_matmatmul_wrapper!(C::StridedMatrix{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:BlasReal} - return _generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, val) -end -Base.@constprop :aggressive function _generic_matmatmul_wrapper!(C::StridedMatrix{<:Number}, tA, tB, A::StridedVecOrMat{<:Number}, B::StridedVecOrMat{<:Number}, - α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) @@ -668,13 +658,9 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{true}) where {T<:BlasReal} + α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal} gemm_wrapper!(C, tA, tB, A, B, α, β) end -Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - alpha::Number, beta::Number, ::Val{false}) where {T<:BlasReal} - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) -end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = From cbdb717777964458b62ddc1d8d8b1706ea9cca13 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 10 Dec 2025 17:00:43 +0530 Subject: [PATCH 3/3] Reorder methods --- src/matmul.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/matmul.jl b/src/matmul.jl index c4737c85..8593aecb 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -499,7 +499,7 @@ function matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α::Bool, β) return false end -# THE one big BLAS dispatch. This is split into two methods to improve latency +# THE one big BLAS dispatch. This is split into syrk/herk/gemm and symm/hemm/none methods to improve latency Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number} mA, nA = lapack_size(tA, A) @@ -511,6 +511,12 @@ Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val) return C end + +function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal} + gemm_wrapper!(C, tA, tB, A, B, α, β) +end + Base.@constprop :aggressive function _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, ::Val{BlasFlag.SYRK}) if A === B tA_uc = uppercase(tA) # potentially strip a WrapperChar @@ -657,10 +663,6 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) -function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - α::Number, β::Number, ::Val{BlasFlag.GEMM}) where {T<:BlasReal} - gemm_wrapper!(C, tA, tB, A, B, α, β) -end # legacy method Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} =