Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6dac09e
Added 1D support of cos2pix example
kaanolgu Jan 16, 2026
f2c5643
added poisson_100 case
kaanolgu Jan 16, 2026
61f4003
Fix: 100 case is using transpose to reuse postprocess of 010 case
kaanolgu Jan 30, 2026
74658fb
fix: refactored the cos2pix example
kaanolgu Jan 30, 2026
360e175
fix: ci/cd build error fix and removed geo%BC by using periodic_x
kaanolgu Jan 30, 2026
8441de1
new: added the cos3pix test too
kaanolgu Feb 2, 2026
f18eb0e
add: 110 case initial setup
kaanolgu Feb 4, 2026
a31ed83
fix: cos2pix prints correct analytical solution for 110 case
kaanolgu Feb 4, 2026
acc143c
fix: cos2pix test case tests for poisson results and div(grad()) resu…
kaanolgu Feb 9, 2026
5c3de0a
add: cos(2pix)cos(2piy)cos(2piz)
kaanolgu Feb 9, 2026
5bf9fba
fix: fft_forward_cuda for 000 case
kaanolgu Feb 9, 2026
cccc393
fix:110 case spectral removed post process fwd/bwd x
kaanolgu Feb 9, 2026
6f6979d
fix: added VERBOSE option to hide prints
kaanolgu Feb 9, 2026
15ac987
added: input_000 and input_110
kaanolgu Feb 9, 2026
4f70faf
add: OpenMP missing functionalities error messages
kaanolgu Feb 10, 2026
9557e0c
fix: fprettify applied for styling
kaanolgu Feb 10, 2026
e216f36
fix: remove unused variable in cos2pix
kaanolgu Feb 10, 2026
175d9b1
fix:fprettify broken cuda calls
kaanolgu Feb 10, 2026
c2b1253
fix: fprettify ignore cuda chevrons
kaanolgu Feb 13, 2026
e87fdd9
new: moving from case to test
kaanolgu Feb 13, 2026
b4056c4
fix: comments for cuda poisson
kaanolgu Feb 13, 2026
dbde717
fix:remove unused enforce_periodicity_xy, undo_periodicity_xy functio…
kaanolgu Feb 16, 2026
5a953cb
Rename test_poisson.f90 to test_poisson_bc.f90
kaanolgu Feb 18, 2026
bb5c067
Update CMakeLists.txt to rename test_poisson to test_poisson_bc
kaanolgu Feb 18, 2026
c9dd785
fix: add missing x_sp_st offset to ix calculation in spectral processing
kaanolgu Feb 18, 2026
151033c
fix: multiple rank is not available for 110 case
kaanolgu Feb 18, 2026
2c3077e
fix: remove debug printouts
kaanolgu Feb 18, 2026
9190f99
fix: reuse fft_forward/backward_cuda for 010 and 110 cases
kaanolgu Feb 18, 2026
518db11
add cufft to forward and backward fft
ia267 Feb 18, 2026
be77aa2
ensure that plan3d_fw and plan3d_bw and runtime branching are always …
ia267 Feb 18, 2026
dc2e994
fix formatting issues
ia267 Feb 19, 2026
0652350
fix: the forward_cuda logic was affected by the merge
kaanolgu Feb 19, 2026
c80daab
fix: remove debug prints
kaanolgu Feb 19, 2026
a1876f4
fix: changes to cufft path to ensure boundary conditions are correct.
ia267 Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 281 additions & 7 deletions src/backend/cuda/kernels/spectral_processing.f90
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,54 @@ attributes(global) subroutine memcpy3D(dst, src, nx, ny, nz)
end if
end subroutine memcpy3D

attributes(global) subroutine memcpy3D_with_transpose(dst, src, nx, ny, nz)
  !! Transposing copy of the first two dimensions:
  !! dst(j, i, k) = src(i, j, k) for i = 1..nx and j = 1..ny.
  !! Used for the 100 case forward FFT.
  !!
  !! One thread per x index (taken from the x launch dimension); the
  !! third index comes straight from blockIdx%y. Each thread walks over
  !! all ny entries of its line. nz is part of the interface but is not
  !! read by this kernel.
  implicit none

  ! dst is (ny+2, nx, nz) but only the (ny, nx, nz) part is written
  real(dp), device, intent(out), dimension(:, :, :) :: dst
  ! src is (nx, ny, nz)
  real(dp), device, intent(in), dimension(:, :, :) :: src
  integer, value, intent(in) :: nx, ny, nz

  integer :: ix, iy, iz

  ix = (blockIdx%x - 1)*blockDim%x + threadIdx%x
  iz = blockIdx%y

  ! surplus threads of the last block in x have nothing to copy
  if (ix > nx) return

  do iy = 1, ny
    dst(iy, ix, iz) = src(ix, iy, iz)
  end do

end subroutine memcpy3D_with_transpose

attributes(global) subroutine memcpy3D_with_transpose_back( &
  dst, src, nx, ny, nz &
  )
  !! Inverse of the transposing copy:
  !! dst(i, j, k) = src(j, i, k) for i = 1..nx and j = 1..ny.
  !! Used for the 100 case backward FFT.
  !!
  !! One thread per x index (taken from the x launch dimension); the
  !! third index comes straight from blockIdx%y. Each thread walks over
  !! all ny entries of its line. nz is part of the interface but is not
  !! read by this kernel.
  implicit none

  ! dst is (nx, ny, nz)
  real(dp), device, intent(out), dimension(:, :, :) :: dst
  ! src is (ny+2, nx, nz) but only the (ny, nx, nz) part is read
  real(dp), device, intent(in), dimension(:, :, :) :: src
  integer, value, intent(in) :: nx, ny, nz

  integer :: ix, iy, iz

  ix = (blockIdx%x - 1)*blockDim%x + threadIdx%x
  iz = blockIdx%y

  ! surplus threads of the last block in x have nothing to copy
  if (ix > nx) return

  do iy = 1, ny
    dst(ix, iy, iz) = src(iy, ix, iz)
  end do

end subroutine memcpy3D_with_transpose_back

attributes(global) subroutine process_spectral_000( &
div_u, waves, nx_spec, ny_spec, y_sp_st, nx, ny, nz, &
ax, bx, ay, by, az, bz &
Expand Down Expand Up @@ -604,24 +652,245 @@ attributes(global) subroutine process_spectral_010_bw( &

end subroutine process_spectral_010_bw

attributes(global) subroutine process_spectral_110( &
  div_u, waves, nx_spec, ny_spec, x_sp_st, y_sp_st, nx, ny, nz, &
  ax, bx, ay, by, az, bz &
  )
  !! Post-processes for Dirichlet BC in X and Y, periodic in Z
  !!
  !! Works in place on div_u and runs three stages back to back:
  !! a forward pass (normalisation, z post-process, paired y
  !! combination), the division by the entries of waves, and the
  !! matching backward pass (paired y recombination, z post-process).
  !!
  !! Thread mapping: i comes from the x launch dimension and is guarded
  !! against nx_spec; k is blockIdx%y and is used unguarded. Every access
  !! is of the form div_u(i, :, k), and the y pairing only couples
  !! entries j and ny_spec - j + 2 of that same line, so a thread never
  !! touches data owned by another thread.
  !!
  !! div_u            spectral data, overwritten with the result
  !! waves            divisors; real and imaginary parts are applied
  !!                  separately to the real and imaginary parts of div_u
  !! nx_spec, ny_spec extents looped over in the first two dimensions
  !! x_sp_st, y_sp_st offsets added to i and j to form the indices used
  !!                  for the coefficient arrays and the zeroed mode
  !! nx, ny, nz       sizes used for normalisation and half-size tests
  !! ay, by, az, bz   coefficient arrays used by the y and z stages
  !! ax, bx           part of the interface, not referenced here
  implicit none

  complex(dp), device, intent(inout), dimension(:, :, :) :: div_u
  complex(dp), device, intent(in), dimension(:, :, :) :: waves
  real(dp), device, intent(in), dimension(:) :: ax, bx, ay, by, az, bz
  integer, value, intent(in) :: nx_spec, ny_spec
  integer, value, intent(in) :: x_sp_st, y_sp_st
  integer, value, intent(in) :: nx, ny, nz

  integer :: i, j, k, ix, iy, iz, iy_rev
  real(dp) :: tmp_r, tmp_c, div_r, div_c
  real(dp) :: l_r, l_c, r_r, r_c

  i = threadIdx%x + (blockIdx%x - 1)*blockDim%x
  k = blockIdx%y ! nz_spec

  ! Threads past the spectral x extent do no work in any stage. The
  ! kernel contains no barriers, so a single early exit is equivalent to
  ! guarding every stage separately.
  if (i > nx_spec) return

  ! indices that stay fixed for the whole thread
  ix = i + x_sp_st
  iz = k

  ! ================================================================
  ! FORWARD PASS
  ! ================================================================

  ! Step 1: Normalise and periodic post-process in z
  do j = 1, ny_spec
    ! normalisation
    div_r = real(div_u(i, j, k), kind=dp)/nx/ny/nz
    div_c = aimag(div_u(i, j, k))/nx/ny/nz

    ! postprocess in z (periodic)
    tmp_r = div_r
    tmp_c = div_c
    div_r = tmp_r*bz(iz) + tmp_c*az(iz)
    div_c = tmp_c*bz(iz) - tmp_r*az(iz)
    if (iz > nz/2 + 1) then
      div_r = -div_r
      div_c = -div_c
    end if

    ! update the entry
    div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
  end do

  ! Step 2: Paired even/odd splitting for y (Dirichlet direction)
  ! Entry j = 1 is left untouched. Both partners are read before either
  ! is written, so the case where j and ny_spec - j + 2 coincide still
  ! uses the pre-update values.
  do j = 2, ny_spec/2 + 1
    iy = j + y_sp_st
    iy_rev = ny_spec - j + 2 + y_sp_st

    l_r = real(div_u(i, j, k), kind=dp)
    l_c = aimag(div_u(i, j, k))
    r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
    r_c = aimag(div_u(i, ny_spec - j + 2, k))

    ! update the entry
    div_u(i, j, k) = 0.5_dp*cmplx( & !&
      l_r*by(iy) + l_c*ay(iy) + r_r*by(iy) - r_c*ay(iy), &
      -l_r*ay(iy) + l_c*by(iy) + r_r*ay(iy) + r_c*by(iy), kind=dp &
      )
    div_u(i, ny_spec - j + 2, k) = 0.5_dp*cmplx( & !&
      r_r*by(iy_rev) + r_c*ay(iy_rev) + l_r*by(iy_rev) - l_c*ay(iy_rev), &
      -r_r*ay(iy_rev) + r_c*by(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
      kind=dp &
      )
  end do

  ! NOTE: No x-direction paired splitting!
  ! For the R2C FFT output, the x-direction Dirichlet BC is handled
  ! implicitly through the waves_set.
  ! The R2C output only contains nx/2+1 values (wavenumbers 0 to nx/2),
  ! so there's no "pairing" to do

  ! ================================================================
  ! POISSON SOLVE
  ! ================================================================
  do j = 1, ny_spec
    div_r = real(div_u(i, j, k), kind=dp)
    div_c = aimag(div_u(i, j, k))

    ! Real and imaginary parts are divided independently; a component
    ! whose divisor is below the threshold is set to zero instead of
    ! being divided.
    tmp_r = real(waves(i, j, k), kind=dp)
    tmp_c = aimag(waves(i, j, k))
    if (abs(tmp_r) < 1.e-16_dp) then
      div_r = 0._dp
    else
      div_r = -div_r/tmp_r
    end if
    if (abs(tmp_c) < 1.e-16_dp) then
      div_c = 0._dp
    else
      div_c = -div_c/tmp_c
    end if

    ! update the entry
    div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
    ! Zero out the mode at (nx/2+1, *, nz/2+1) for uniqueness
    if (ix == nx/2 + 1 .and. k == nz/2 + 1) div_u(i, j, k) = 0._dp
  end do

  ! ================================================================
  ! BACKWARD PASS
  ! ================================================================

  ! Step 1: Paired even/odd recombination for y (Dirichlet direction)
  do j = 2, ny_spec/2 + 1
    iy = j + y_sp_st
    iy_rev = ny_spec - j + 2 + y_sp_st

    l_r = real(div_u(i, j, k), kind=dp)
    l_c = aimag(div_u(i, j, k))
    r_r = real(div_u(i, ny_spec - j + 2, k), kind=dp)
    r_c = aimag(div_u(i, ny_spec - j + 2, k))

    ! update the entry
    div_u(i, j, k) = cmplx( & !&
      l_r*by(iy) - l_c*ay(iy) + r_r*ay(iy) + r_c*by(iy), &
      l_r*ay(iy) + l_c*by(iy) - r_r*by(iy) + r_c*ay(iy), kind=dp &
      )
    div_u(i, ny_spec - j + 2, k) = cmplx( & !&
      r_r*by(iy_rev) - r_c*ay(iy_rev) + l_r*ay(iy_rev) + l_c*by(iy_rev), &
      r_r*ay(iy_rev) + r_c*by(iy_rev) - l_r*by(iy_rev) + l_c*ay(iy_rev), &
      kind=dp &
      )
  end do

  ! Step 2: Periodic post-process in z (undo)
  do j = 1, ny_spec
    div_r = real(div_u(i, j, k), kind=dp)
    div_c = aimag(div_u(i, j, k))

    ! post-process in z
    tmp_r = div_r
    tmp_c = div_c
    div_r = tmp_r*bz(iz) - tmp_c*az(iz)
    div_c = tmp_c*bz(iz) + tmp_r*az(iz)
    if (iz > nz/2 + 1) then
      div_r = -div_r
      div_c = -div_c
    end if

    ! update the entry
    div_u(i, j, k) = cmplx(div_r, div_c, kind=dp)
  end do

end subroutine process_spectral_110

attributes(global) subroutine enforce_periodicity_x(f_out, f_in, nx)
  !! Reorder f_in along its first dimension into f_out.
  !!
  !! The leading nx/2 entries of f_out receive the odd-numbered entries
  !! of f_in (1, 3, 5, ...) in ascending order; the trailing entries
  !! receive the even-numbered entries of f_in in descending order.
  !! When nx is odd, the last entry of f_in lands in the centre slot
  !! nx/2 + 1.
  !!
  !! The second index is threadIdx%x and the third is blockIdx%x, both
  !! used without a bounds check.
  ! NOTE(review): presumably launched with exactly one thread per
  ! second-dimension entry and one block per third-dimension entry;
  ! confirm against the caller.
  implicit none

  real(dp), device, intent(out), dimension(:, :, :) :: f_out
  real(dp), device, intent(in), dimension(:, :, :) :: f_in
  integer, value, intent(in) :: nx

  integer :: m, row, plane, half, tail_start

  row = threadIdx%x
  plane = blockIdx%x
  half = nx/2

  ! leading part: odd-numbered source entries, ascending
  do m = 1, half
    f_out(m, row, plane) = f_in(2*m - 1, row, plane)
  end do

  if (mod(nx, 2) == 1) then
    ! odd size: the centre slot takes the last source entry, and the
    ! descending part starts one slot later
    f_out(half + 1, row, plane) = f_in(nx, row, plane)
    tail_start = half + 2
  else
    tail_start = half + 1
  end if

  ! trailing part: even-numbered source entries, descending
  do m = tail_start, nx
    f_out(m, row, plane) = f_in(2*nx - 2*m + 2, row, plane)
  end do

end subroutine enforce_periodicity_x

attributes(global) subroutine undo_periodicity_x(f_out, f_in, nx)
  !! Interleave the two ends of f_in along the first dimension.
  !!
  !! For m = 1..nx/2, odd slot 2m-1 of f_out takes f_in(m) and even slot
  !! 2m takes f_in(nx - m + 1). When nx is odd, the last slot of f_out
  !! takes the centre entry f_in(nx/2 + 1).
  !!
  !! The second index is threadIdx%x and the third is blockIdx%x, both
  !! used without a bounds check.
  ! NOTE(review): presumably launched with exactly one thread per
  ! second-dimension entry and one block per third-dimension entry;
  ! confirm against the caller.
  implicit none

  real(dp), device, intent(out), dimension(:, :, :) :: f_out
  real(dp), device, intent(in), dimension(:, :, :) :: f_in
  integer, value, intent(in) :: nx

  integer :: m, row, plane, half

  row = threadIdx%x
  plane = blockIdx%x
  half = nx/2

  ! odd size: the centre source entry goes to the last slot; the loop
  ! below only writes slots 1..2*half, so the order is immaterial
  if (mod(nx, 2) == 1) f_out(nx, row, plane) = f_in(half + 1, row, plane)

  do m = 1, half
    ! front of f_in feeds the odd slots ...
    f_out(2*m - 1, row, plane) = f_in(m, row, plane)
    ! ... and the back of f_in, read in reverse, feeds the even slots
    f_out(2*m, row, plane) = f_in(nx - m + 1, row, plane)
  end do

end subroutine undo_periodicity_x

attributes(global) subroutine enforce_periodicity_y(f_out, f_in, ny)
  !! Reorder f_in along its second dimension into f_out.
  !!
  !! The leading ny/2 entries of f_out receive the odd-numbered entries
  !! of f_in (1, 3, 5, ...) in ascending order; the trailing entries
  !! receive the even-numbered entries of f_in in descending order.
  !! When ny is odd, the last entry of f_in lands in the centre slot
  !! ny/2 + 1.
  !!
  !! The first index is threadIdx%x and the third is blockIdx%x, both
  !! used without a bounds check.
  ! NOTE(review): presumably launched with exactly one thread per
  ! first-dimension entry and one block per third-dimension entry;
  ! confirm against the caller.
  implicit none

  real(dp), device, intent(out), dimension(:, :, :) :: f_out
  real(dp), device, intent(in), dimension(:, :, :) :: f_in
  integer, value, intent(in) :: ny

  integer :: i, j, k, n2

  i = threadIdx%x
  k = blockIdx%x
  n2 = ny/2

  ! leading part: odd-numbered source entries, ascending
  do j = 1, n2
    f_out(i, j, k) = f_in(i, 2*j - 1, k)
  end do
  if (mod(ny, 2) == 1) then
    ! odd-size center entry: the descending formula would give source
    ! index ny + 1 at j = n2 + 1, so that slot is filled explicitly
    f_out(i, n2 + 1, k) = f_in(i, ny, k)
    do j = n2 + 2, ny
      f_out(i, j, k) = f_in(i, 2*ny - 2*j + 2, k)
    end do
  else
    ! trailing part: even-numbered source entries, descending
    do j = n2 + 1, ny
      f_out(i, j, k) = f_in(i, 2*ny - 2*j + 2, k)
    end do
  end if

end subroutine enforce_periodicity_y

Expand All @@ -632,15 +901,20 @@ attributes(global) subroutine undo_periodicity_y(f_out, f_in, ny)
real(dp), device, intent(in), dimension(:, :, :) :: f_in
integer, value, intent(in) :: ny

integer :: i, j, k
integer :: i, j, k, n2

i = threadIdx%x
k = blockIdx%x
n2 = ny/2

do j = 1, ny/2
do j = 1, n2
f_out(i, 2*j - 1, k) = f_in(i, j, k)
f_out(i, 2*j, k) = f_in(i, ny - j + 1, k)
end do
if (mod(ny, 2) == 1) then
! odd-size center entry
f_out(i, ny, k) = f_in(i, n2 + 1, k)
end if

end subroutine undo_periodicity_y

Expand Down
Loading