@@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
129129 return dest
130130end
131131
132- export boxdot, ⊡ , boxdot!
132+ export boxdot, ⊡ , ⊡ ₂, boxdot!
133133
134134"""
135135 boxdot(A,B) = A ⊡ B # \\ boxdot
@@ -177,40 +177,55 @@ Float64
177177```
178178See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
179179"""
180- function boxdot (A:: AbstractArray , B:: AbstractArray )
181- Amat = _squash_left (A)
182- Bmat = _squash_right (B)
180+ function boxdot (A:: AbstractArray , B:: AbstractArray , nth:: Val )
181+ _check_boxdot_axes (A, B, nth)
182+ Amat = _squash_left (A, nth)
183+ Bmat = _squash_right (B, nth)
183184
184185 axA, axB = axes (Amat,2 ), axes (Bmat,1 )
185186 axA == axB || _throw_dmm (axA, axB)
186187
187- return _boxdot_reshape (Amat * Bmat, A, B)
188+ return _boxdot_reshape (Amat * Bmat, A, B, nth )
188189end
189190
191+ boxdot (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (1 ))
192+ boxdot2 (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (2 ))
193+
190194const ⊡ = boxdot
195+ const ⊡ ₂ = boxdot2
191196
192197@noinline _throw_dmm (axA, axB) = throw (DimensionMismatch (" neighbouring axes of `A` and `B` must match, got $axA and $axB " ))
198+ @noinline _throw_boxdot_nth (n) = throw (ArgumentError (" boxdot order should be ≥ 1, got $n " ))
199+
200+ function _check_boxdot_axes (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,M} , :: Val{K} ) where {N,M,K}
201+ K:: Int
202+ (K >= 1 ) || _throw_boxdot_nth (K)
203+ for i in 1 : K
204+ axA, axB = axes (A)[N- K+ i], axes (B)[i]
205+ axA == axB || _throw_dmm (axA, axB)
206+ end
207+ end
193208
194- _squash_left (A:: AbstractArray ) = reshape (A, :, size (A, ndims (A)) )
195- _squash_left (A:: AbstractMatrix ) = A
209+ _squash_left (A:: AbstractArray , :: Val{N} ) where {N} = reshape (A, prod ( size (A)[ 1 : end - N]),: )
210+ _squash_left (A:: AbstractMatrix , :: Val{1} ) = A
196211
197- _squash_right (B:: AbstractArray ) = reshape (B, size (B, 1 ),: )
198- _squash_right (B:: AbstractVecOrMat ) = B
212+ _squash_right (B:: AbstractArray , :: Val{N} ) where {N} = reshape (B, :, prod ( size (B)[ 1 + N : end ]) )
213+ _squash_right (B:: AbstractVecOrMat , :: Val{1} ) = B
199214
200- function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} ) where {T,N,S,M}
201- ax = ntuple (i -> i< N ? axes (A, i) : axes (B, i- N+ 2 ), Val (N+ M- 2 ))
215+ function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} , :: Val{K} ) where {T,N,S,M,K}
216+ N- K ≥ 1 && M- K ≥ 1 && N+ M- 2 K ≤ 2 && return AB # These can skip final reshape
217+ ax = ntuple (i -> i≤ N- K ? axes (A, i) : axes (B, i- N+ 2 K), Val (N+ M- 2 K))
202218 reshape (AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
203219end
204220
205221# These can skip final reshape:
206- _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat ) = AB
222+ _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , :: Val ) = AB
207223
208224# These produce scalar output:
209- function boxdot (A:: AbstractVector , B:: AbstractVector )
210- axA, axB = axes (A,1 ), axes (B,1 )
211- axA == axB || _throw_dmm (axA, axB)
225+ function boxdot (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,N} , :: Val{N} ) where {N}
226+ _check_boxdot_axes (A, B, Val (N))
212227 if eltype (A) <: Number
213- return transpose (A) * B
228+ return transpose (vec (A)) * vec (B)
214229 else
215230 return sum (a* b for (a,b) in zip (A,B))
216231 end
@@ -224,30 +239,30 @@ boxdot(a::Number, b::Number) = a*b
224239using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
225240
226241# Adjont and Transpose, vectors or almost (returning a scalar)
227- boxdot (A:: AdjointAbsVec , B:: AbstractVector ) = A * B
228- boxdot (A:: TransposeAbsVec , B:: AbstractVector ) = A * B
242+ boxdot (A:: AdjointAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
243+ boxdot (A:: TransposeAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
229244
230- boxdot (A:: AbstractVector , B:: AdjointAbsVec ) = A ⊡ vec (B)
231- boxdot (A:: AbstractVector , B:: TransposeAbsVec ) = A ⊡ vec (B)
245+ boxdot (A:: AbstractVector , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
246+ boxdot (A:: AbstractVector , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
232247
233- boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec ) = adjoint (adjoint (B) ⊡ adjoint (A))
234- boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec ) = vec (A) ⊡ vec (B)
235- boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec ) = vec (A) ⊡ vec (B)
236- boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
248+ boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec , :: Val{1} ) = adjoint (adjoint (B) ⊡ adjoint (A))
249+ boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
250+ boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
251+ boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
237252
238253# ... with a matrix (returning another such)
239- boxdot (A:: AdjointAbsVec , B:: AbstractMatrix ) = A * B
240- boxdot (A:: TransposeAbsVec , B:: AbstractMatrix ) = A * B
254+ boxdot (A:: AdjointAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
255+ boxdot (A:: TransposeAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
241256
242- boxdot (A:: AbstractMatrix , B:: AdjointAbsVec ) = (B' ⊡ A' )'
243- boxdot (A:: AbstractMatrix , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
257+ boxdot (A:: AbstractMatrix , B:: AdjointAbsVec , :: Val{1} ) = (B' ⊡ A' )'
258+ boxdot (A:: AbstractMatrix , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
244259
245260# ... and with higher-dim (returning a plain array)
246- boxdot (A:: AdjointAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
247- boxdot (A:: TransposeAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
261+ boxdot (A:: AdjointAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
262+ boxdot (A:: TransposeAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
248263
249- boxdot (A:: AbstractArray , B:: AdjointAbsVec ) = A ⊡ vec (B)
250- boxdot (A:: AbstractArray , B:: TransposeAbsVec ) = A ⊡ vec (B)
264+ boxdot (A:: AbstractArray , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
265+ boxdot (A:: AbstractArray , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
251266
252267
253268"""
@@ -260,25 +275,30 @@ function boxdot! end
260275
261276if VERSION < v " 1.3" # Then 5-arg mul! isn't defined
262277
263- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray )
264- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
265- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B))
278+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} ) where {N}
279+ _check_boxdot_axes (A, B, Val (N))
280+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
281+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)))
266282 Y
267283 end
268284
269- boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B))
285+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray ) = boxdot! (Y, A, B, Val (1 ))
286+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B), Val (1 ))
270287
271288else
272289
273- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false )
274- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
275- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B), α, β)
290+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} , α:: Number = true , β:: Number = false ) where {N}
291+ _check_boxdot_axes (A, B, Val (N))
292+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
293+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)), α, β)
276294 Y
277295 end
278296
297+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false ) = boxdot! (Y, A, B, Val (1 ), α, β)
298+
279299 # For boxdot!, only where mul! behaves differently:
280300 boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ,
281- α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), α, β)
301+ α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), Val ( 1 ), α, β)
282302
283303end
284304
0 commit comments