Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,34 +136,67 @@ end
) = Operators.is_valid_index(space, ij, slabidx)

##### shmem fd kernel partition
"""
fd_shmem_stencil_partition(us, n_face_levels, n_max_threads)

Compute thread/block partition for finite difference shmem kernels.

Uses 3D thread blocks: (Nv, Ni, Nj) where:
- Nv threads handle vertical levels (up to n_face_levels)
- Ni × Nj threads handle horizontal nodal points within each element
- Each block processes one horizontal element (h)

This achieves ~1024 threads/block for typical Nv=64, Ni=Nj=4 configurations,
improving GPU occupancy compared to the previous 1D (Nv,) layout.
"""
@inline function fd_shmem_stencil_partition(
us::DataLayouts.UniversalSize,
n_face_levels::Integer,
n_max_threads::Integer = 256;
n_max_threads::Integer = 1024;
)
(Ni, Nj, _, Nv, Nh) = DataLayouts.universal_size(us)
Nvthreads = n_face_levels
@assert Nvthreads <= maximum_allowable_threads()[1] "Number of vertical face levels cannot exceed $(maximum_allowable_threads()[1])"
Nvblocks = cld(Nv, Nvthreads) # +1 may be needed to guarantee that shared memory is populated at the last cell face

# Check thread limits
max_threads = maximum_allowable_threads()
@assert Nvthreads <= max_threads[1] "Number of vertical face levels ($Nvthreads) cannot exceed $(max_threads[1])"
@assert Ni <= max_threads[2] "Ni ($Ni) cannot exceed $(max_threads[2])"
@assert Nj <= max_threads[3] "Nj ($Nj) cannot exceed $(max_threads[3])"

total_threads = Nvthreads * Ni * Nj
@assert total_threads <= n_max_threads "Total threads ($total_threads) exceeds max ($n_max_threads)"

return (;
threads = (Nvthreads,),
blocks = (Nh, Nvblocks, Ni * Nj),
threads = (Nvthreads, Ni, Nj),
blocks = (Nh,),
Nvthreads,
)
end
"""
fd_shmem_stencil_universal_index(space, us)

Compute the universal CartesianIndex for the current thread in 3D thread block layout.

Thread layout: (tv, ti, tj) where tv=vertical, ti/tj=horizontal nodal indices.
Block layout: (h,) where h=horizontal element index.

Returns CartesianIndex((i, j, 1, v, h)) for valid threads.
"""
@inline function fd_shmem_stencil_universal_index(
space::Spaces.AbstractSpace,
us,
)
(tv,) = CUDA.threadIdx()
(h, bv, ij) = CUDA.blockIdx()
v = tv + (bv - 1) * CUDA.blockDim().x
(Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
if Ni * Nj < ij
return CartesianIndex((-1, -1, 1, -1, -1))
end
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
# 3D thread indexing: (v, i, j)
tv = CUDA.threadIdx().x # vertical level within block
ti = CUDA.threadIdx().y # horizontal nodal point i
tj = CUDA.threadIdx().z # horizontal nodal point j
h = CUDA.blockIdx().x # horizontal element

v = tv # Direct mapping: thread index = vertical level
i = ti
j = tj

return CartesianIndex((i, j, 1, v, h))
end
@inline fd_shmem_stencil_is_valid_index(I::CI5, us::UniversalSize) =
1 I[5] DataLayouts.get_Nh(us)
1 <= I[5] <= DataLayouts.get_Nh(us)
140 changes: 86 additions & 54 deletions ext/cuda/operators_fd_shmem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@ Base.@propagate_inbounds function fd_operator_shmem(
op::Operators.DivergenceF2C,
args...,
)
# allocate temp output
# allocate temp output and geometry cache
RT = return_eltype(op, args...)
FT = eltype(RT) # Get the float type from return type
Ju³ = CUDA.CuStaticSharedArray(RT, interior_size(shmem_params))
lJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
rJu³ = CUDA.CuStaticSharedArray(RT, boundary_size(shmem_params))
return (Ju³, lJu³, rJu³)
# Cache invJ to avoid repeated global memory reads
invJ_shmem = CUDA.CuStaticSharedArray(FT, interior_size(shmem_params))
return (Ju³, lJu³, rJu³, invJ_shmem)
end

Base.@propagate_inbounds function fd_operator_fill_shmem!(
op::Operators.DivergenceF2C,
(Ju³, lJu³, rJu³),
(Ju³, lJu³, rJu³, invJ_shmem),
bc_bds,
arg_space,
space,
Expand All @@ -29,31 +32,35 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
arg,
)
@inbounds begin
vt = threadIdx().x
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
lg = Geometry.LocalGeometry(space, idx, hidx)

# Cache invJ for use in evaluate - each face level stores invJ for the center below it
if !on_boundary(idx, space, op)
u³ = Operators.getidx(space, arg, idx, hidx)
Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
Ju³[vt, ti, tj] = Geometry.Jcontravariant3(u³, lg)
# Cache invJ for the center at index vt (center below this face)
invJ_shmem[vt, ti, tj] = lg.invJ
elseif on_left_boundary(idx, space, op)
bloc = Operators.left_boundary_window(space)
bc = Operators.get_boundary(op, bloc)
ub = Operators.getidx(space, bc.val, nothing, hidx)
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
if bc isa Operators.SetValue
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
lJu³[ti, tj] = Geometry.Jcontravariant3(ub, lg)
elseif bc isa Operators.SetDivergence
bJu³[1] = ub
lJu³[ti, tj] = ub
elseif bc isa Operators.Extrapolate # no shmem needed
end
elseif on_right_boundary(idx, space, op)
bloc = Operators.right_boundary_window(space)
bc = Operators.get_boundary(op, bloc)
ub = Operators.getidx(space, bc.val, nothing, hidx)
bJu³ = on_left_boundary(idx, space) ? lJu³ : rJu³
if bc isa Operators.SetValue
bJu³[1] = Geometry.Jcontravariant3(ub, lg)
rJu³[ti, tj] = Geometry.Jcontravariant3(ub, lg)
elseif bc isa Operators.SetDivergence
bJu³[1] = ub
rJu³[ti, tj] = ub
elseif bc isa Operators.Extrapolate # no shmem needed
end
end
Expand All @@ -63,19 +70,22 @@ end

Base.@propagate_inbounds function fd_operator_evaluate(
op::Operators.DivergenceF2C,
(Ju³, lJu³, rJu³),
(Ju³, lJu³, rJu³, invJ_shmem),
space,
idx::Integer,
hidx,
arg,
)
@inbounds begin
vt = threadIdx().x
lg = Geometry.LocalGeometry(space, idx, hidx)
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
# Use cached invJ instead of reading LocalGeometry from global memory
invJ = invJ_shmem[vt, ti, tj]
if !on_boundary(idx, space, op)
Ju³₋ = Ju³[vt] # corresponds to idx - half
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
Ju³₋ = Ju³[vt, ti, tj] # corresponds to idx - half
Ju³₊ = Ju³[vt + 1, ti, tj] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
else
bloc =
on_left_boundary(idx, space, op) ?
Expand All @@ -85,22 +95,21 @@ Base.@propagate_inbounds function fd_operator_evaluate(
@assert bc isa Operators.SetValue || bc isa Operators.SetDivergence
if on_left_boundary(idx, space)
if bc isa Operators.SetValue
Ju³₋ = lJu³[1] # corresponds to idx - half
Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
Ju³₋ = lJu³[ti, tj] # corresponds to idx - half
Ju³₊ = Ju³[vt + 1, ti, tj] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
else
# @assert bc isa Operators.SetDivergence
return lJu³[1]
return lJu³[ti, tj]
end
else
@assert on_right_boundary(idx, space)
if bc isa Operators.SetValue
Ju³₋ = Ju³[vt] # corresponds to idx - half
Ju³₊ = rJu³[1] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ lg.invJ
Ju³₋ = Ju³[vt, ti, tj] # corresponds to idx - half
Ju³₊ = rJu³[ti, tj] # corresponds to idx + half
return (Ju³₊ ⊟ Ju³₋) ⊠ invJ
else
@assert bc isa Operators.SetDivergence
return rJu³[1]
return rJu³[ti, tj]
end
end
end
Expand Down Expand Up @@ -133,10 +142,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
)
@inbounds begin
is_out_of_bounds(idx, space) && return nothing
vt = threadIdx().x
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
cov3 = Geometry.Covariant3Vector(1)
if in_domain(idx, arg_space)
u[vt] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx)
u[vt, ti, tj] = cov3 ⊗ Operators.getidx(space, arg, idx, hidx)
end
if on_any_boundary(idx, space, op)
lloc = Operators.left_boundary_window(space)
Expand All @@ -147,12 +158,19 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
bc = Operators.get_boundary(op, bloc)
@assert bc isa Operators.SetValue || bc isa Operators.SetGradient
ub = Operators.getidx(space, bc.val, nothing, hidx)
bu = on_left_boundary(idx, space) ? lb : rb
if bc isa Operators.SetValue
bu[1] = cov3 ⊗ ub
if on_left_boundary(idx, space)
lb[ti, tj] = cov3 ⊗ ub
else
rb[ti, tj] = cov3 ⊗ ub
end
elseif bc isa Operators.SetGradient
lg = Geometry.LocalGeometry(space, idx, hidx)
bu[1] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
if on_left_boundary(idx, space)
lb[ti, tj] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
else
rb[ti, tj] = Geometry.project(Geometry.Covariant3Axis(), ub, lg)
end
elseif bc isa Operators.Extrapolate # no shmem needed
end
end
Expand All @@ -169,11 +187,13 @@ Base.@propagate_inbounds function fd_operator_evaluate(
args...,
)
@inbounds begin
vt = threadIdx().x
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
lg = Geometry.LocalGeometry(space, idx, hidx)
if !on_boundary(idx, space, op)
u₋ = u[vt - 1] # corresponds to idx - half
u₊ = u[vt] # corresponds to idx + half
u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
u₊ = u[vt, ti, tj] # corresponds to idx + half
return u₊ ⊟ u₋
else
bloc =
Expand All @@ -184,15 +204,15 @@ Base.@propagate_inbounds function fd_operator_evaluate(
@assert bc isa Operators.SetValue
if on_left_boundary(idx, space)
if bc isa Operators.SetValue
u₋ = 2 * lb[1] # corresponds to idx - half
u₊ = 2 * u[vt] # corresponds to idx + half
u₋ = 2 * lb[ti, tj] # corresponds to idx - half
u₊ = 2 * u[vt, ti, tj] # corresponds to idx + half
return u₊ ⊟ u₋
end
else
@assert on_right_boundary(idx, space)
if bc isa Operators.SetValue
u₋ = 2 * u[vt - 1] # corresponds to idx - half
u₊ = 2 * rb[1] # corresponds to idx + half
u₋ = 2 * u[vt - 1, ti, tj] # corresponds to idx - half
u₊ = 2 * rb[ti, tj] # corresponds to idx + half
return u₊ ⊟ u₋
end
end
Expand Down Expand Up @@ -226,9 +246,12 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
)
@inbounds begin
is_out_of_bounds(idx, space) && return nothing
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
ᶜidx = get_cent_idx(idx)
if in_domain(idx, arg_space)
u[idx] = Operators.getidx(space, arg, idx, hidx)
u[vt, ti, tj] = Operators.getidx(space, arg, idx, hidx)
else
lloc = Operators.left_boundary_window(space)
rloc = Operators.right_boundary_window(space)
Expand All @@ -242,16 +265,23 @@ Base.@propagate_inbounds function fd_operator_fill_shmem!(
bc isa Operators.NullBoundaryCondition
if bc isa Operators.NullBoundaryCondition ||
bc isa Operators.Extrapolate
u[idx] = Operators.getidx(space, arg, idx, hidx)
u[vt, ti, tj] = Operators.getidx(space, arg, idx, hidx)
return nothing
end
bu = on_left_boundary(idx, space) ? lb : rb
ub = Operators.getidx(space, bc.val, nothing, hidx)
if bc isa Operators.SetValue
bu[1] = ub
if on_left_boundary(idx, space)
lb[ti, tj] = ub
else
rb[ti, tj] = ub
end
elseif bc isa Operators.SetGradient
lg = Geometry.LocalGeometry(space, idx, hidx)
bu[1] = Geometry.covariant3(ub, lg)
if on_left_boundary(idx, space)
lb[ti, tj] = Geometry.covariant3(ub, lg)
else
rb[ti, tj] = Geometry.covariant3(ub, lg)
end
end
end
end
Expand All @@ -267,12 +297,14 @@ Base.@propagate_inbounds function fd_operator_evaluate(
args...,
)
@inbounds begin
vt = threadIdx().x
vt = threadIdx().x # vertical level
ti = threadIdx().y # horizontal i
tj = threadIdx().z # horizontal j
lg = Geometry.LocalGeometry(space, idx, hidx)
ᶜidx = get_cent_idx(idx)
if !on_boundary(idx, space, op)
u₋ = u[ᶜidx - 1] # corresponds to idx - half
u₊ = u[ᶜidx] # corresponds to idx + half
u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
u₊ = u[vt, ti, tj] # corresponds to idx + half
return RecursiveApply.rdiv(u₊ ⊞ u₋, 2)
else
bloc =
Expand All @@ -285,26 +317,26 @@ Base.@propagate_inbounds function fd_operator_evaluate(
bc isa Operators.Extrapolate
if on_left_boundary(idx, space)
if bc isa Operators.SetValue
return lb[1]
return lb[ti, tj]
elseif bc isa Operators.SetGradient
u₋ = lb[1] # corresponds to idx - half
u₊ = u[ᶜidx] # corresponds to idx + half
u₋ = lb[ti, tj] # corresponds to idx - half
u₊ = u[vt, ti, tj] # corresponds to idx + half
return u₊ ⊟ RecursiveApply.rdiv(u₋, 2)
else
@assert bc isa Operators.Extrapolate
return u[ᶜidx]
return u[vt, ti, tj]
end
else
@assert on_right_boundary(idx, space)
if bc isa Operators.SetValue
return rb[1]
return rb[ti, tj]
elseif bc isa Operators.SetGradient
u₋ = u[ᶜidx - 1] # corresponds to idx - half
u₊ = rb[1] # corresponds to idx + half
u₋ = u[vt - 1, ti, tj] # corresponds to idx - half
u₊ = rb[ti, tj] # corresponds to idx + half
return u₋ ⊞ RecursiveApply.rdiv(u₊, 2)
else
@assert bc isa Operators.Extrapolate
return u[ᶜidx - 1]
return u[vt - 1, ti, tj]
end
end
end
Expand Down
Loading
Loading