diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl
index 90b3f43153..a2925e0473 100644
--- a/ext/cuda/data_layouts_threadblock.jl
+++ b/ext/cuda/data_layouts_threadblock.jl
@@ -136,34 +136,67 @@ end
 ) = Operators.is_valid_index(space, ij, slabidx)

 ##### shmem fd kernel partition
+"""
+    fd_shmem_stencil_partition(us, n_face_levels, n_max_threads)
+
+Compute the thread/block partition for finite difference shmem kernels.
+
+Uses 3D thread blocks of shape (Nv, Ni, Nj), where:
+- Nv threads handle the vertical levels (up to `n_face_levels`)
+- Ni × Nj threads handle the horizontal nodal points within each element
+- each block processes one horizontal element (h)
+
+This yields ~1024 threads/block for the typical Nv = 64, Ni = Nj = 4
+configuration, improving GPU occupancy compared to the previous 1D (Nv,)
+layout.
+"""
 @inline function fd_shmem_stencil_partition(
     us::DataLayouts.UniversalSize,
     n_face_levels::Integer,
-    n_max_threads::Integer = 256;
+    n_max_threads::Integer = 1024;
 )
     (Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
     Nvthreads = n_face_levels
-    @assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])"
-    Nvblocks = cld(Nv, Nvthreads) # +1 may be needed to guarantee that shared memory is populated at the last cell face
+
+    # Check thread limits
+    max_threads = maximum_allowable_threads()
+    @assert Nvthreads <= max_threads[1] "Number of vertical face levels ($Nvthreads) cannot exceed $(max_threads[1])"
+    @assert Ni <= max_threads[2] "Ni ($Ni) cannot exceed $(max_threads[2])"
+    @assert Nj <= max_threads[3] "Nj ($Nj) cannot exceed $(max_threads[3])"
+
+    total_threads = Nvthreads * Ni * Nj
+    @assert total_threads <= n_max_threads "Total threads ($total_threads) exceeds max ($n_max_threads)"
+
     return (;
-        threads = (Nvthreads,),
-        blocks = (Nh, Nvblocks, Ni * Nj),
+        threads = (Nvthreads, Ni, Nj),
+        blocks = (Nh,),
         Nvthreads,
     )
 end
+"""
+    fd_shmem_stencil_universal_index(space, us)
+
+Compute the universal CartesianIndex for the current thread in the 3D
+thread-block layout.
+
+Thread layout: (tv, ti, tj), where tv is the vertical level and ti/tj are the
+horizontal nodal indices. Block layout: (h,), where h is the horizontal
+element index.
+
+Returns CartesianIndex((i, j, 1, v, h)) for valid threads.
+"""
 @inline function fd_shmem_stencil_universal_index(
     space::Spaces.AbstractSpace,
     us,
 )
-    (tv,) = CUDA.threadIdx()
-    (h, bv, ij) = CUDA.blockIdx()
-    v = tv + (bv - 1) * CUDA.blockDim().x
-    (Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
-    if Ni * Nj < ij
-        return CartesianIndex((-1, -1, 1, -1, -1))
-    end
-    @inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
+    # 3D thread indexing: (v, i, j)
+    tv = CUDA.threadIdx().x # vertical level within the block
+    ti = CUDA.threadIdx().y # horizontal nodal point i
+    tj = CUDA.threadIdx().z # horizontal nodal point j
+    h = CUDA.blockIdx().x   # horizontal element
+
+    v = tv # direct mapping: thread index = vertical level
+    i = ti
+    j = tj
     return CartesianIndex((i, j, 1, v, h))
 end
 @inline fd_shmem_stencil_is_valid_index(I::CI5, us::UniversalSize) =
-    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
+    1 <= I[5] <= DataLayouts.get_Nh(us)
diff --git a/ext/cuda/operators_fd_shmem.jl b/ext/cuda/operators_fd_shmem.jl
index da1d3ffdd9..249f287808 100644
--- a/ext/cuda/operators_fd_shmem.jl
+++ b/ext/cuda/operators_fd_shmem.jl
@@ -10,17 +10,20 @@ Base.@propagate_inbounds function fd_operator_shmem(
     op::Operators.DivergenceF2C,
     args...,
 )
-    # allocate temp output
+    # allocate temp output and geometry cache
     RT = return_eltype(op, args...)
+    FT = eltype(RT) # get the float type from the return type
     Ju³ = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params))
     lJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
     rJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
-    return (Ju³, lJu³, rJu³)
+    # Cache invJ to avoid repeated global memory reads
+    invJ_shmem = CUDA.CuStaticSharedArray(FT, interior_size(shmem_params))
+    return (Ju³, lJu³, rJu³, invJ_shmem)
 end

 Base.@propagate_inbounds function fd_operator_fill_shmem!(
     op::Operators.DivergenceF2C,
-    (Ju³, lJu³, rJu³),
+    (Ju³, lJu³, rJu³, invJ_shmem),
     bc_bds,
     arg_space,
     space,
@@ -29,31 +32,35 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
     arg,
 )
     @inbounds begin
-        vt = threadIdx().x
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
         lg = Geometry.LocalGeometry(space, idx, hidx)
+
+        # Cache invJ for use in evaluate: each face level stores invJ for the center below it
         if !on_boundary(idx, space, op)
             u³ = Operators.getidx(space, arg, idx, hidx)
-            Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
+            Ju³[vt, ti, tj] = Geometry.Jcontravariant3(u³, lg)
+            # Cache invJ for the center at index vt (the center below this face)
+            invJ_shmem[vt, ti, tj] = lg.invJ
         elseif on_left_boundary(idx, space, op)
             bloc = Operators.left_boundary_window(space)
             bc = Operators.get_boundary(op, bloc)
             ub = Operators.getidx(space, bc.val, nothing, hidx)
-            bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
             if bc isa Operators.SetValue
-                bJu³[1] = Geometry.Jcontravariant3(ub, lg)
+                lJu³[ti, tj] = Geometry.Jcontravariant3(ub, lg)
             elseif bc isa Operators.SetDivergence
-                bJu³[1] = ub
+                lJu³[ti, tj] = ub
             elseif bc isa Operators.Extrapolate # no shmem needed
             end
         elseif on_right_boundary(idx, space, op)
             bloc = Operators.right_boundary_window(space)
             bc = Operators.get_boundary(op, bloc)
             ub = Operators.getidx(space, bc.val, nothing, hidx)
-            bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
             if bc isa Operators.SetValue
-                bJu³[1] = Geometry.Jcontravariant3(ub, lg)
+                rJu³[ti, tj] = Geometry.Jcontravariant3(ub, lg)
             elseif bc isa Operators.SetDivergence
-                bJu³[1] = ub
+                rJu³[ti, tj] = ub
             elseif bc isa Operators.Extrapolate # no shmem needed
             end
         end
@@ -63,19 +70,22 @@ end

 Base.@propagate_inbounds function fd_operator_evaluate(
     op::Operators.DivergenceF2C,
-    (Ju³, lJu³, rJu³),
+    (Ju³, lJu³, rJu³, invJ_shmem),
     space,
     idx::Integer,
     hidx,
     arg,
 )
     @inbounds begin
-        vt = threadIdx().x
-        lg = Geometry.LocalGeometry(space, idx, hidx)
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
+        # Use cached invJ instead of reading LocalGeometry from global memory
+        invJ = invJ_shmem[vt, ti, tj]
         if !on_boundary(idx, space, op)
-            Ju³₋ = Ju³[vt] # corresponds to idx - half
-            Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
-            return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
+            Ju³₋ = Ju³[vt, ti, tj] # corresponds to idx - half
+            Ju³₊ = Ju³[vt + 1, ti, tj] # corresponds to idx + half
+            return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
         else
             bloc =
                 on_left_boundary(idx, space, op) ?
@@ -85,22 +95,21 @@ Base.@propagate_inbounds function fd_operator_evaluate(
         @assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
         if on_left_boundary(idx, space)
             if bc isa Operators.SetValue
-                Ju³₋ = lJu³[1] # corresponds to idx - half
-                Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
-                return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
+                Ju³₋ = lJu³[ti, tj] # corresponds to idx - half
+                Ju³₊ = Ju³[vt + 1, ti, tj] # corresponds to idx + half
+                return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
             else
-                # @assert bc isa Operators.SetDivergence
-                return lJu³[1]
+                return lJu³[ti, tj]
             end
         else
             @assert on_right_boundary(idx, space)
             if bc isa Operators.SetValue
-                Ju³₋ = Ju³[vt] # corresponds to idx - half
-                Ju³₊ = rJu³[1] # corresponds to idx + half
-                return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
+                Ju³₋ = Ju³[vt, ti, tj] # corresponds to idx - half
+                Ju³₊ = rJu³[ti, tj] # corresponds to idx + half
+                return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
             else
                 @assert bc isa Operators.SetDivergence
-                return rJu³[1]
+                return rJu³[ti, tj]
             end
         end
     end
@@ -133,10 +142,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
 )
     @inbounds begin
         is_out_of_bounds(idx, space) && return nothing
-        vt = threadIdx().x
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
         cov3 = Geometry.Covariant3Vector(1)
         if in_domain(idx, arg_space)
-            u[vt] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx)
+            u[vt, ti, tj] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx)
         end
         if on_any_boundary(idx, space, op)
             lloc = Operators.left_boundary_window(space)
@@ -147,12 +158,19 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
             bc = Operators.get_boundary(op, bloc)
             @assert bc isa Operators.SetValue || bc isa Operators.SetGradient
             ub = Operators.getidx(space, bc.val, nothing, hidx)
-            bu = on_left_boundary(idx, space) ? lb : rb
             if bc isa Operators.SetValue
-                bu[1] = cov3 ⊗ ub
+                if on_left_boundary(idx, space)
+                    lb[ti, tj] = cov3 ⊗ ub
+                else
+                    rb[ti, tj] = cov3 ⊗ ub
+                end
             elseif bc isa Operators.SetGradient
                 lg = Geometry.LocalGeometry(space, idx, hidx)
-                bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
+                if on_left_boundary(idx, space)
+                    lb[ti, tj] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
+                else
+                    rb[ti, tj] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
+                end
             elseif bc isa Operators.Extrapolate # no shmem needed
             end
         end
@@ -169,11 +187,13 @@ Base.@propagate_inbounds function fd_operator_evaluate(
     args...,
 )
     @inbounds begin
-        vt = threadIdx().x
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
         lg = Geometry.LocalGeometry(space, idx, hidx)
         if !on_boundary(idx, space, op)
-            u₋ = u[vt - 1] # corresponds to idx - half
-            u₊ = u[vt] # corresponds to idx + half
+            u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
+            u₊ = u[vt, ti, tj] # corresponds to idx + half
             return u₊ ⊟ u₋
         else
             bloc =
@@ -184,15 +204,15 @@ Base.@propagate_inbounds function fd_operator_evaluate(
         @assert bc isa Operators.SetValue
         if on_left_boundary(idx, space)
             if bc isa Operators.SetValue
-                u₋ = 2 * lb[1] # corresponds to idx - half
-                u₊ = 2 * u[vt] # corresponds to idx + half
+                u₋ = 2 * lb[ti, tj] # corresponds to idx - half
+                u₊ = 2 * u[vt, ti, tj] # corresponds to idx + half
                 return u₊ ⊟ u₋
             end
         else
             @assert on_right_boundary(idx, space)
             if bc isa Operators.SetValue
-                u₋ = 2 * u[vt - 1] # corresponds to idx - half
-                u₊ = 2 * rb[1] # corresponds to idx + half
+                u₋ = 2 * u[vt - 1, ti, tj] # corresponds to idx - half
+                u₊ = 2 * rb[ti, tj] # corresponds to idx + half
                 return u₊ ⊟ u₋
             end
         end
@@ -226,9 +246,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
 )
     @inbounds begin
         is_out_of_bounds(idx, space) && return nothing
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
         ᶜidx = get_cent_idx(idx)
         if in_domain(idx, arg_space)
-            u[idx] = Operators.getidx(space, arg, idx, hidx)
+            u[vt, ti, tj] = Operators.getidx(space, arg, idx, hidx)
         else
             lloc = Operators.left_boundary_window(space)
             rloc = Operators.right_boundary_window(space)
@@ -242,16 +265,23 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
                bc isa Operators.NullBoundaryCondition
             if bc isa Operators.NullBoundaryCondition ||
                bc isa Operators.Extrapolate
-                u[idx] = Operators.getidx(space, arg, idx, hidx)
+                u[vt, ti, tj] = Operators.getidx(space, arg, idx, hidx)
                 return nothing
             end
-            bu = on_left_boundary(idx, space) ? lb : rb
             ub = Operators.getidx(space, bc.val, nothing, hidx)
             if bc isa Operators.SetValue
-                bu[1] = ub
+                if on_left_boundary(idx, space)
+                    lb[ti, tj] = ub
+                else
+                    rb[ti, tj] = ub
+                end
             elseif bc isa Operators.SetGradient
                 lg = Geometry.LocalGeometry(space, idx, hidx)
-                bu[1] = Geometry.covariant3(ub, lg)
+                if on_left_boundary(idx, space)
+                    lb[ti, tj] = Geometry.covariant3(ub, lg)
+                else
+                    rb[ti, tj] = Geometry.covariant3(ub, lg)
+                end
             end
         end
     end
@@ -267,12 +297,14 @@ Base.@propagate_inbounds function fd_operator_evaluate(
     args...,
 )
     @inbounds begin
-        vt = threadIdx().x
+        vt = threadIdx().x # vertical level
+        ti = threadIdx().y # horizontal i
+        tj = threadIdx().z # horizontal j
         lg = Geometry.LocalGeometry(space, idx, hidx)
         ᶜidx = get_cent_idx(idx)
         if !on_boundary(idx, space, op)
-            u₋ = u[ᶜidx - 1] # corresponds to idx - half
-            u₊ = u[ᶜidx] # corresponds to idx + half
+            u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
+            u₊ = u[vt, ti, tj] # corresponds to idx + half
             return RecursiveApply.rdiv(u₊ ⊞ u₋, 2)
         else
             bloc =
@@ -285,26 +317,26 @@ Base.@propagate_inbounds function fd_operator_evaluate(
            bc isa Operators.Extrapolate
         if on_left_boundary(idx, space)
             if bc isa Operators.SetValue
-                return lb[1]
+                return lb[ti, tj]
             elseif bc isa Operators.SetGradient
-                u₋ = lb[1] # corresponds to idx - half
-                u₊ = u[ᶜidx] # corresponds to idx + half
+                u₋ = lb[ti, tj] # corresponds to idx - half
+                u₊ = u[vt, ti, tj] # corresponds to idx + half
                 return u₊ ⊟ RecursiveApply.rdiv(u₋, 2)
             else
                 @assert bc isa Operators.Extrapolate
-                return u[ᶜidx]
+                return u[vt, ti, tj]
             end
         else
             @assert on_right_boundary(idx, space)
             if bc isa Operators.SetValue
-                return rb[1]
+                return rb[ti, tj]
             elseif bc isa Operators.SetGradient
-                u₋ = u[ᶜidx - 1] # corresponds to idx - half
-                u₊ = rb[1] # corresponds to idx + half
+                u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
+                u₊ = rb[ti, tj] # corresponds to idx + half
                 return u₋ ⊞ RecursiveApply.rdiv(u₊, 2)
             else
                 @assert bc isa Operators.Extrapolate
-                return u[ᶜidx - 1]
+                return u[vt - 1, ti, tj]
             end
         end
     end
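The fill/evaluate pattern above is easier to see without the CUDA plumbing.
Below is a CPU analogue of the cached-invJ DivergenceF2C stencil, with
ordinary arrays standing in for `CuStaticSharedArray` and random data standing
in for the field and metric terms (illustrative only, not library code):

    # Fill a (Nv + 1, Ni, Nj) face buffer, then difference adjacent faces and
    # scale by the cached invJ -- exactly one slot per (vt, ti, tj) "thread".
    Nv, Ni, Nj = 64, 4, 4
    Ju³ = rand(Nv + 1, Ni, Nj)   # one entry per face level
    invJ = rand(Nv, Ni, Nj)      # cached metric term, one per center level
    div = zeros(Nv, Ni, Nj)
    for tj in 1:Nj, ti in 1:Ni, vt in 1:Nv
        div[vt, ti, tj] = (Ju³[vt + 1, ti, tj] - Ju³[vt, ti, tj]) * invJ[vt, ti, tj]
    end

In the kernel the loop is replaced by one thread per slot, and the point of
`invJ_shmem` is that evaluate never re-reads `LocalGeometry` from global
memory.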
+""" +@inline fd_shmem_needed_per_column(Nv, bc) = fd_shmem_needed_per_column(0, Nv, bc) +@inline fd_shmem_needed_per_column(shmem_bytes, Nv, obj) = shmem_bytes @inline fd_shmem_needed_per_column( shmem_bytes, + Nv, bc::Broadcasted{Style}, ) where {Style} = - shmem_bytes + _fd_shmem_needed_per_column(shmem_bytes, bc.args) + shmem_bytes + _fd_shmem_needed_per_column(shmem_bytes, Nv, bc.args) @inline function fd_shmem_needed_per_column( shmem_bytes, + Nv, sbc::StencilBroadcasted{Style}, ) where {Style} - shmem_bytes₀ = _fd_shmem_needed_per_column(shmem_bytes, sbc.args) + shmem_bytes₀ = _fd_shmem_needed_per_column(shmem_bytes, Nv, sbc.args) return if Operators.fd_shmem_is_supported(sbc) - sizeof(return_eltype(sbc.op, sbc.args...)) + shmem_bytes₀ + fd_operator_shmem_size(sbc.op, Nv, sbc.args...) + shmem_bytes₀ else shmem_bytes₀ end end -@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, ::Tuple{}) = +@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, ::Tuple{}) = shmem_bytes -@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, args::Tuple{Any}) = - shmem_bytes + fd_shmem_needed_per_column(shmem_bytes::Integer, args[1]) -@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, args::Tuple) = +@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, args::Tuple{Any}) = + shmem_bytes + fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, args[1]) +@inline _fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, args::Tuple) = shmem_bytes + - fd_shmem_needed_per_column(shmem_bytes::Integer, args[1]) + - _fd_shmem_needed_per_column(shmem_bytes::Integer, Base.tail(args)) + fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, args[1]) + + _fd_shmem_needed_per_column(shmem_bytes::Integer, Nv, Base.tail(args)) get_arg_space(bc::StencilBroadcasted, args::Tuple{}) = axes(bc) diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index e5a95c4884..6b473b3a9a 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -21,9 +21,15 @@ Base.Broadcast.BroadcastStyle( include("operators_fd_shmem_is_supported.jl") -struct ShmemParams{Nv} end -interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,) -boundary_size(::ShmemParams{Nv}) where {Nv} = (1,) +# ShmemParams holds dimensions for shared memory allocation +# With 3D thread blocks, we have Ni×Nj columns per block, each with Nv levels +struct ShmemParams{Nv, Ni, Nj} end + +# Interior shmem: one column buffer per (i,j) point = (Nv, Ni, Nj) +interior_size(::ShmemParams{Nv, Ni, Nj}) where {Nv, Ni, Nj} = (Nv, Ni, Nj) + +# Boundary shmem: one value per (i,j) point for each boundary = (Ni, Nj) +boundary_size(::ShmemParams{Nv, Ni, Nj}) where {Nv, Ni, Nj} = (Ni, Nj) function Base.copyto!( out::Field, @@ -44,22 +50,24 @@ function Base.copyto!( n_face_levels = Spaces.nlevels(fspace) high_resolution = !(n_face_levels ≤ 256) # https://github.com/JuliaGPU/CUDA.jl/issues/2672 - # max_shmem = 166912 # CUDA.limit(CUDA.LIMIT_SHMEM_SIZE) # max_shmem = CUDA.attribute( device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, ) - total_shmem = fd_shmem_needed_per_column(bc) - enough_shmem = total_shmem ≤ max_shmem # TODO: Use CUDA.limit(CUDA.LIMIT_SHMEM_SIZE) to determine how much shmem should be used # TODO: add shmem support for masked operations + (Ni, Nj, _, _, _) = DataLayouts.universal_size(us) + # With 3D thread blocks, we have Ni×Nj columns per block + total_shmem_per_block = fd_shmem_needed_per_column(n_face_levels, bc) * Ni * Nj + enough_shmem = 
diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl
index e5a95c4884..6b473b3a9a 100644
--- a/ext/cuda/operators_finite_difference.jl
+++ b/ext/cuda/operators_finite_difference.jl
@@ -21,9 +21,15 @@ Base.Broadcast.BroadcastStyle(

 include("operators_fd_shmem_is_supported.jl")

-struct ShmemParams{Nv} end
-interior_size(::ShmemParams{Nv}) where {Nv} = (Nv,)
-boundary_size(::ShmemParams{Nv}) where {Nv} = (1,)
+# ShmemParams holds the dimensions used for shared memory allocation.
+# With 3D thread blocks, each block covers Ni × Nj columns, each with Nv levels.
+struct ShmemParams{Nv, Ni, Nj} end
+
+# Interior shmem: one column buffer per (i, j) point = (Nv, Ni, Nj)
+interior_size(::ShmemParams{Nv, Ni, Nj}) where {Nv, Ni, Nj} = (Nv, Ni, Nj)
+
+# Boundary shmem: one value per (i, j) point for each boundary = (Ni, Nj)
+boundary_size(::ShmemParams{Nv, Ni, Nj}) where {Nv, Ni, Nj} = (Ni, Nj)

 function Base.copyto!(
     out::Field,
@@ -44,22 +50,24 @@ function Base.copyto!(
     n_face_levels = Spaces.nlevels(fspace)
     high_resolution = !(n_face_levels ≤ 256) # https://github.com/JuliaGPU/CUDA.jl/issues/2672
-    # max_shmem = 166912 # CUDA.limit(CUDA.LIMIT_SHMEM_SIZE) #
     max_shmem = CUDA.attribute(
         device(),
         CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
     )
-    total_shmem = fd_shmem_needed_per_column(bc)
-    enough_shmem = total_shmem ≤ max_shmem
     # TODO: Use CUDA.limit(CUDA.LIMIT_SHMEM_SIZE) to determine how much shmem should be used
     # TODO: add shmem support for masked operations
+    (Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
+    # With 3D thread blocks, each block covers Ni × Nj columns
+    total_shmem_per_block = fd_shmem_needed_per_column(n_face_levels, bc) * Ni * Nj
+    enough_shmem = total_shmem_per_block ≤ max_shmem
+
     if Operators.any_fd_shmem_supported(bc) &&
        !high_resolution &&
        mask isa NoMask &&
        enough_shmem &&
        Operators.use_fd_shmem()
-        shmem_params = ShmemParams{n_face_levels}()
+        shmem_params = ShmemParams{n_face_levels, Ni, Nj}()
         p = fd_shmem_stencil_partition(us, n_face_levels)
         args = (
             strip_space(out, space),
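Putting the pieces together, the block-level budget check in `Base.copyto!`
amounts to the following sketch (the 48 KiB figure is a common per-block
static-shmem default used here for illustration, not a queried device
attribute):

    Nv, Ni, Nj = 64, 4, 4
    max_shmem = 48 * 1024                        # illustrative stand-in for the CUDA attribute
    per_column = (Nv + 2 + Nv) * sizeof(Float64) # DivergenceF2C case from above: 1040 B
    total_shmem_per_block = per_column * Ni * Nj # 16_640 B
    total_shmem_per_block <= max_shmem           # true -> the shmem path is taken

so the typical 64-level, 4 × 4 configuration fits comfortably within the
default budget.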