From c7907590aa7e6f2b1c6137966ceedd68d5a0e281 Mon Sep 17 00:00:00 2001 From: "tianxiang.wang@metax-tech.com" Date: Wed, 3 Sep 2025 09:16:39 +0000 Subject: [PATCH 1/2] =?UTF-8?q?Perf:=20batch=20FFT=20and=20surrounding=20o?= =?UTF-8?q?perators=20for=20performance=20Signed-off-by=EF=BC=9ATianxiang?= =?UTF-8?q?=20Wang,Contributed=20under=20Me?= =?UTF-8?q?taX=20Integrated=20Circuits=20(Shanghai)=20Co.,=20Ltd.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/module_basis/module_pw/fft.cpp | 189 +++++++++- source/module_basis/module_pw/fft.h | 81 ++++- .../module_pw/kernels/cuda/pw_op.cu | 140 ++++++++ .../module_basis/module_pw/kernels/pw_op.cpp | 69 ++++ source/module_basis/module_pw/kernels/pw_op.h | 128 +++++++ .../module_pw/kernels/rocm/pw_op.hip.cu | 140 ++++++++ source/module_basis/module_pw/pw_basis.h | 4 + source/module_basis/module_pw/pw_basis_k.cpp | 14 +- source/module_basis/module_pw/pw_basis_k.h | 28 +- .../module_basis/module_pw/pw_transform_k.cpp | 324 ++++++++++++++++-- source/module_elecstate/elecstate_pw.cpp | 85 +++-- source/module_elecstate/elecstate_pw.h | 3 + .../kernels/cuda/elecstate_op.cu | 73 ++++ .../module_elecstate/kernels/elecstate_op.cpp | 25 ++ .../module_elecstate/kernels/elecstate_op.h | 49 ++- .../kernels/rocm/elecstate_op.hip.cu | 73 ++++ .../hamilt_pwdft/kernels/cuda/veff_op.cu | 86 +++++ .../hamilt_pwdft/kernels/rocm/veff_op.hip.cu | 86 +++++ .../hamilt_pwdft/kernels/veff_op.cpp | 36 ++ .../hamilt_pwdft/kernels/veff_op.h | 72 ++++ .../hamilt_pwdft/operator_pw/veff_pw.cpp | 81 ++++- .../hamilt_pwdft/operator_pw/veff_pw.h | 21 +- source/module_io/read_input_item_general.cpp | 6 + source/module_parameter/input_parameter.h | 4 +- 24 files changed, 1734 insertions(+), 83 deletions(-) diff --git a/source/module_basis/module_pw/fft.cpp b/source/module_basis/module_pw/fft.cpp index 1c56f9b5af7..41a787f5cb1 100644 --- a/source/module_basis/module_pw/fft.cpp +++ b/source/module_basis/module_pw/fft.cpp @@ -3,6 +3,7 @@ #include "module_base/memory.h" #include "module_base/tool_quit.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_parameter/parameter.h" namespace ModulePW { @@ -808,7 +809,8 @@ void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex -void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, +void FFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, + std::complex* in, std::complex* out) const { #if defined(__CUDA) @@ -877,4 +879,189 @@ void FFT::set_precision(std::string precision_) this->precision = std::move(precision_); } +#if defined(__CUDA) || defined(__ROCM) +template +BatchedFFT::BatchedFFT(int nx_, int ny_, int nz_): nx(nx_), ny(ny_), nz(nz_) +{ +} + +template +BatchedFFT::BatchedFFT(): nx(0), ny(0), nz(0) +{ +} + +template +void BatchedFFT::initFFT(int nx_, int ny_, int nz_) +{ + nx = nx_; + ny = ny_; + nz = nz_; + this->cleanFFT(); + this->clear_data(); +} + +template +BatchedFFT::~BatchedFFT() +{ + this->cleanFFT(); + this->clear_data(); +} + +template +void BatchedFFT::cleanFFT() +{ + for (auto& pair : this->plans) { +#if defined(__CUDA) + CHECK_CUFFT(cufftDestroy(pair.second)); +#elif defined(__ROCM) + CHECK_CUFFT(hipfftDestroy(pair.second)); +#endif + } +} + +template +void BatchedFFT::clear_data() const +{ + if (this->auxr_3d){ + base_device::memory::delete_memory_op, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d); + this->auxr_3d = nullptr; + this->auxr_3d_size = 0; + } + + if (this->sharedWorkArea){ + base_device::memory::delete_memory_op()(gpu_ctx, this->sharedWorkArea); + this->sharedWorkArea = nullptr; + this->sharedWorkAreaSize = 0; + } +} + +template +std::complex* BatchedFFT::get_auxr_3d_data(int batchSize) const +{ + if(this->auxr_3d_size >= sizeof(std::complex) * batchSize * this->nx * this->ny * this->nz && this->auxr_3d != nullptr){ + return this->auxr_3d; + } + + base_device::memory::resize_memory_op, base_device::DEVICE_GPU>()(gpu_ctx, this->auxr_3d, this->nx * this->ny * this->nz * batchSize); + this->auxr_3d_size = sizeof(std::complex) * batchSize * this->nx * this->ny * this->nz; + return this->auxr_3d; +} + +template +void BatchedFFT::fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, std::complex* out, const int batchSize)const +{ +#if defined(__CUDA) + cufftHandle plan = this->get_plan_from_cache(batchSize); + if (this->fftType == CUFFT_C2C){ + CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast(in), reinterpret_cast(out), + CUFFT_FORWARD)); + }else{ + CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast(in), reinterpret_cast(out), + CUFFT_FORWARD)); + } +#elif defined(__ROCM) + hipfftHandle plan = this->get_plan_from_cache(batchSize); + if (this->fftType == HIPFFT_C2C){ + CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast(in), reinterpret_cast(out), + HIPFFT_FORWARD)); + }else{ + CHECK_CUFFT(hipfftExecZ2Z(plan, reinterpret_cast(in), reinterpret_cast(out), + HIPFFT_FORWARD)); + } +#endif +} + +template +void BatchedFFT::fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, std::complex* out, const int batchSize)const +{ +#if defined(__CUDA) + cufftHandle plan = this->get_plan_from_cache(batchSize); + if (this->fftType == CUFFT_C2C){ + CHECK_CUFFT(cufftExecC2C(plan, reinterpret_cast(in), reinterpret_cast(out), + CUFFT_INVERSE)); + }else{ + CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast(in), reinterpret_cast(out), + CUFFT_INVERSE)); + } +#elif defined(__ROCM) + hipfftHandle plan = this->get_plan_from_cache(batchSize); + if (this->fftType == HIPFFT_C2C){ + CHECK_CUFFT(hipfftExecC2C(plan, reinterpret_cast(in), reinterpret_cast(out), + HIPFFT_INVERSE)); + }else{ + CHECK_CUFFT(cufftExecZ2Z(plan, reinterpret_cast(in), reinterpret_cast(out), + HIPFFT_INVERSE)); + } +#endif + +} + +template +int BatchedFFT::estimate_batch_size(size_t addtional_memory) +{ + int input_batchSize = PARAM.inp.fft_batch_size; + if (input_batchSize == 0){ + size_t free_mem, total_mem; +#if defined(__CUDA) + cudaMemGetInfo(&free_mem, &total_mem); +#elif defined(__ROCM) + hipMemGetInfo(&free_mem, &total_mem); +#endif + input_batchSize = free_mem * FREE_MEM_COEFF_FFT / addtional_memory; + } + return std::max(1, std::min(MAX_BATCH_SIZE_FFT, input_batchSize)); +} + +template +typename BatchedFFT::fftHandleType BatchedFFT::get_plan_from_cache(int batchSize) const +{ + auto it = this->plans.find(batchSize); + if (it != this->plans.end()) { + return it->second; + } + + fftHandleType plan; + int rank = 3; + int n[3] = {this->nx, this->ny, this->nz}; + int *inembed = nullptr; + int *onembed = nullptr; + int istride = 1, ostride = 1; + int idist = this->nx * this->ny * this->nz; + int odist = idist; + + size_t workAreaSize; + +#if defined(__CUDA) + CHECK_CUFFT(cufftPlanMany(&plan, rank, n, inembed, istride, idist, + onembed, ostride, odist, + fftType, batchSize)); + CHECK_CUFFT(cufftGetSize(plan, &workAreaSize)); +#elif defined(__ROCM) + CHECK_CUFFT(hipfftPlanMany(&plan, rank, n, inembed, istride, idist, + onembed, ostride, odist, + fftType, batchSize)); + CHECK_CUFFT(hipfftGetSize(plan, &workAreaSize)); +#endif + + if (workAreaSize >= this->sharedWorkAreaSize && workAreaSize > 0){ + base_device::memory::resize_memory_op()(gpu_ctx, this->sharedWorkArea, workAreaSize); + this->sharedWorkAreaSize = workAreaSize; + } + + if (workAreaSize > 0){ +#if defined(__CUDA) + CHECK_CUFFT(cufftSetWorkArea(plan, this->sharedWorkArea)); +#elif defined(__ROCM) + CHECK_CUFFT(hipfftSetWorkArea(plan, this->sharedWorkArea)); +#endif + } + + this->plans[batchSize] = plan; + return plan; +} + +template class BatchedFFT; +template class BatchedFFT; + +#endif } // namespace ModulePW diff --git a/source/module_basis/module_pw/fft.h b/source/module_basis/module_pw/fft.h index 3581d01d186..1c2c0f24563 100644 --- a/source/module_basis/module_pw/fft.h +++ b/source/module_basis/module_pw/fft.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include "fftw3.h" #if defined(__FFTW3_MPI) && defined(__MPI) @@ -40,13 +42,13 @@ class FFT FFT(); ~FFT(); void clear(); //reset fft - + // init parameters of fft - void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, + void initfft(int nx_in, int ny_in, int nz_in, int lixy_in, int rixy_in, int ns_in, int nplane_in, int nproc_in, bool gamma_only_in, bool xprime_in = true, bool mpifft_in = false); //init fftw_plans - void setupFFT(); + void setupFFT(); //destroy fftw_plans void cleanFFT(); @@ -106,7 +108,7 @@ public : template std::complex* get_auxr_3d_data() const; - int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive + int fft_mode = 0; ///< fftw mode 0: estimate, 1: measure, 2: patient, 3: exhaustive private: bool gamma_only = false; @@ -167,6 +169,77 @@ public : void set_precision(std::string precision_); }; + +#if defined(__CUDA) || defined(__ROCM) +template +struct FFTTypeTraits; + +template <> +struct FFTTypeTraits { +#if defined(__CUDA) + using cuComplexType = cufftComplex; + static constexpr cufftType Type = CUFFT_C2C; +#elif defined(__ROCM) + using hipComplexType = hipfftComplex; + static constexpr hipfftType Type = HIPFFT_C2C; +#endif +}; + +template <> +struct FFTTypeTraits { +#if defined(__CUDA) + using cuComplexType = cufftDoubleComplex; + static constexpr cufftType Type = CUFFT_Z2Z; +#elif defined(__ROCM) + using hipComplexType = hipfftDoubleComplex; + static constexpr hipfftType Type = HIPFFT_Z2Z; +#endif +}; + +constexpr float FREE_MEM_COEFF_FFT = 0.8; +constexpr int MAX_BATCH_SIZE_FFT = 32; + +template // float or double +class BatchedFFT +{ +public: +#if defined(__CUDA) + using cuComplexType = typename FFTTypeTraits::cuComplexType; + static constexpr cufftType fftType = FFTTypeTraits::Type; + using fftHandleType = cufftHandle; +#elif defined(__ROCM) + using hipComplexType = typename FFTTypeTraits::hipComplexType; + static constexpr hipfftType fftType = FFTTypeTraits::Type; + using fftHandleType = hipfftHandle; +#endif + + BatchedFFT(int nx_, int ny_, int nz_); + BatchedFFT(); + ~BatchedFFT(); + void initFFT(int nx_, int ny_, int nz_); + void cleanFFT(); + void clear_data() const; +public: + std::complex* get_auxr_3d_data(int batchSize)const; + void fft3D_forward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, std::complex* out, const int batchSize)const; + void fft3D_backward(const base_device::DEVICE_GPU* /*ctx*/, std::complex* in, std::complex* out, const int batchSize)const; + static int estimate_batch_size(size_t addtional_memory); + +private: + int nx, ny, nz; + mutable std::unordered_map plans; + + mutable std::complex *auxr_3d = nullptr; // fft space + mutable ::size_t auxr_3d_size = 0; + mutable char *sharedWorkArea = nullptr; + mutable size_t sharedWorkAreaSize = 0; + + fftHandleType get_plan_from_cache(int batchSize)const; + +}; + +#endif + } #endif diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu index a9128db3187..cedb9afe371 100644 --- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu +++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu @@ -23,6 +23,25 @@ __global__ void set_3d_fft_box( } } +template +__global__ void set_3d_fft_box_batch( + const int npwk, + const int* box_index, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < npwk && batch < batchSize) + { + int xx = box_index[idx]; + out[batch * ld_out + xx] = in[batch * ld_in + idx]; + } +} + template __global__ void set_recip_to_real_output( const int nrxx, @@ -41,6 +60,28 @@ __global__ void set_recip_to_real_output( } } +template +__global__ void set_recip_to_real_output_batch( + const int nrxx, + const bool add, + const FPTYPE factor, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= nrxx || batch >= batchSize) {return;} + if(add) { + out[batch * ld_out + idx] += factor * in[batch * ld_in + idx]; + } + else { + out[batch * ld_out + idx] = in[batch * ld_in + idx]; + } +} + template __global__ void set_real_to_recip_output( const int npwk, @@ -61,6 +102,30 @@ __global__ void set_real_to_recip_output( } } +template +__global__ void set_real_to_recip_output_batch( + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= npwk || batch >= batchSize) {return;} + if(add) { + out[batch * ld_out + idx] += factor / nxyz * in[batch * ld_in + box_index[idx]]; + } + else { + out[batch * ld_out + idx] = in[batch * ld_in + box_index[idx]] / nxyz; + } +} + template void set_3d_fft_box_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -78,6 +143,26 @@ void set_3d_fft_box_op::operator()(const base_d cudaCheckOnDebug(); } +template +void set_3d_fft_box_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + dim3 block((npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + set_3d_fft_box_batch<<>>( + npwk, + box_index, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + cudaCheckOnDebug(); +} + template void set_recip_to_real_output_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int nrxx, @@ -97,6 +182,29 @@ void set_recip_to_real_output_op::operator()(co cudaCheckOnDebug(); } +template +void set_recip_to_real_output_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + + dim3 block((nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + set_recip_to_real_output_batch<<>>( + nrxx, + add, + factor, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + cudaCheckOnDebug(); +} + template void set_real_to_recip_output_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -120,11 +228,43 @@ void set_real_to_recip_output_op::operator()(co cudaCheckOnDebug(); } +template +void set_real_to_recip_output_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + dim3 block((npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + set_real_to_recip_output_batch<<>>( + npwk, + nxyz, + add, + factor, + box_index, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + cudaCheckOnDebug(); +} + template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; } // namespace ModulePW diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp index b5fb4533540..560b1a1c319 100644 --- a/source/module_basis/module_pw/kernels/pw_op.cpp +++ b/source/module_basis/module_pw/kernels/pw_op.cpp @@ -18,6 +18,25 @@ struct set_3d_fft_box_op } }; +template +struct set_3d_fft_box_batch_op +{ + void operator()(const base_device::DEVICE_CPU* dev, + const int npwk, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) + { + for (int i = 0; i < batchSize; ++i) + { + set_3d_fft_box_op()(dev, npwk, box_index, in + i * ld_in, out + i * ld_out); + } + } +}; + template struct set_recip_to_real_output_op { @@ -41,6 +60,27 @@ struct set_recip_to_real_output_op } }; +template +struct set_recip_to_real_output_batch_op +{ + void operator()(const base_device::DEVICE_CPU* dev, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) + { + for (int i = 0; i < batchSize; ++i) + { + set_recip_to_real_output_op()(dev, nrxx, add, factor, + in + i * ld_in, out + i * ld_out); + } + } +}; + template struct set_real_to_recip_output_op { @@ -66,12 +106,41 @@ struct set_real_to_recip_output_op } }; +template +struct set_real_to_recip_output_batch_op +{ + void operator()(const base_device::DEVICE_CPU* dev, + const int npw_k, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) + { + for (int i = 0; i < batchSize; ++i) + { + set_real_to_recip_output_op()(dev, npw_k, nxyz, add, factor, + box_index, in + i * ld_in, out + i * ld_out); + } + } +}; + template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; } // namespace ModulePW diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h index 8415ad96778..8f71bbdfc56 100644 --- a/source/module_basis/module_pw/kernels/pw_op.h +++ b/source/module_basis/module_pw/kernels/pw_op.h @@ -26,6 +26,33 @@ struct set_3d_fft_box_op { std::complex* out); }; +template +struct set_3d_fft_box_batch_op { + /// @brief Set the 3D fft box for fft transfrom between the recip and real space. + /// To map the 1D psi(1D continuous array) to 3D box psi(fft box) + /// + /// Input Parameters + /// @param dev - which device this function runs on + /// @param npwk - number of planwaves + /// @param box_index - the mapping function of 1D to 3D + /// @param in - input psi within a 1D array(in recip space) + /// @param ld_in - leading dimension of in array + /// @param ld_out - leading dimension of out array + /// @param batchSize - batch size for one calculation + /// + /// Output Parameters + /// @param out - output psi within the 3D box(in recip space) + void operator() ( + const Device* dev, + const int npwk, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + template struct set_recip_to_real_output_op { /// @brief Calculate the outputs after the FFT translation of recip_to_real @@ -47,6 +74,33 @@ struct set_recip_to_real_output_op { std::complex* out); }; +template +struct set_recip_to_real_output_batch_op { + /// @brief Calculate the outputs after the FFT translation of recip_to_real + /// + /// Input Parameters + /// @param dev - which device this function runs on + /// @param nrxx - size of array + /// @param add - flag to control whether to add the input itself + /// @param in - input psi within a 1D array(in real space) + /// @param ld_in - leading dimension of in array + /// @param ld_out - leading dimension of out array + /// @param batchSize - batch size for one calculation + /// + /// Output Parameters + /// @param out - output psi within the 3D box(in real space) + void operator() ( + const Device* dev, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + template struct set_real_to_recip_output_op { /// @brief Calculate the outputs after the FFT translation of real_to_recip @@ -72,6 +126,37 @@ struct set_real_to_recip_output_op { std::complex* out); }; +template +struct set_real_to_recip_output_batch_op { + /// @brief Calculate the outputs after the FFT translation of real_to_recip + /// + /// Input Parameters + /// @param dev - which device this function runs on + /// @param nxyz - size of array + /// @param add - flag to control whether to add the input itself + /// @param factor - input constant value + /// @param box_index - input box parameters + /// @param in - input psi within a 1D array(in recip space) + /// @param ld_in - leading dimension of in array + /// @param ld_out - leading dimension of out array + /// @param batchSize - batch size for one calculation + /// + /// Output Parameters + /// @param out - output psi within the 3D box(in recip space) + void operator() ( + const Device* dev, + const int npw_k, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM // Partially specialize functor for base_device::GpuDevice. template @@ -84,6 +169,19 @@ struct set_3d_fft_box_op std::complex* out); }; +template +struct set_3d_fft_box_batch_op +{ + void operator()(const base_device::DEVICE_GPU* dev, + const int npwk, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + template struct set_recip_to_real_output_op { @@ -95,6 +193,20 @@ struct set_recip_to_real_output_op std::complex* out); }; +template +struct set_recip_to_real_output_batch_op { + void operator()(const base_device::DEVICE_GPU* dev, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + + template struct set_real_to_recip_output_op { @@ -108,6 +220,22 @@ struct set_real_to_recip_output_op std::complex* out); }; +template +struct set_real_to_recip_output_batch_op { + void operator() ( + const base_device::DEVICE_GPU* dev, + const int npw_k, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize); +}; + #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM } // namespace ModulePW #endif //MODULE_PW_MULTI_DEVICE_H \ No newline at end of file diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu index a3f5fe2c2bd..1b462ac66da 100644 --- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu +++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu @@ -24,6 +24,25 @@ __global__ void set_3d_fft_box( } } +template +__global__ void set_3d_fft_box_batch( + const int npwk, + const int* box_index, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx < npwk && batch < batchSize) + { + int xx = box_index[idx]; + out[batch * ld_out + xx] = in[batch * ld_in + idx]; + } +} + template __global__ void set_recip_to_real_output( const int nrxx, @@ -42,6 +61,28 @@ __global__ void set_recip_to_real_output( } } +template +__global__ void set_recip_to_real_output_batch( + const int nrxx, + const bool add, + const FPTYPE factor, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= nrxx || batch >= batchSize) {return;} + if(add) { + out[batch * ld_out + idx] += factor * in[batch * ld_in + idx]; + } + else { + out[batch * ld_out + idx] = in[batch * ld_in + idx]; + } +} + template __global__ void set_real_to_recip_output( const int npwk, @@ -62,6 +103,30 @@ __global__ void set_real_to_recip_output( } } +template +__global__ void set_real_to_recip_output_batch( + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const thrust::complex* in, + const int ld_in, + thrust::complex* out, + const int ld_out, + const int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= npwk || batch >= batchSize) {return;} + if(add) { + out[batch * ld_out + idx] += factor / nxyz * in[batch * ld_in + box_index[idx]]; + } + else { + out[batch * ld_out + idx] = in[batch * ld_in + box_index[idx]] / nxyz; + } +} + template void set_3d_fft_box_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -79,6 +144,26 @@ void set_3d_fft_box_op::operator()(const base_d hipCheckOnDebug(); } +template +void set_3d_fft_box_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + dim3 block((npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(set_3d_fft_box_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + npwk, + box_index, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + hipCheckOnDebug(); +} + template void set_recip_to_real_output_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int nrxx, @@ -98,6 +183,29 @@ void set_recip_to_real_output_op::operator()(co hipCheckOnDebug(); } +template +void set_recip_to_real_output_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + + dim3 block((nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(set_recip_to_real_output_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + nrxx, + add, + factor, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + hipCheckOnDebug(); +} + template void set_real_to_recip_output_op::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -121,12 +229,44 @@ void set_real_to_recip_output_op::operator()(co hipCheckOnDebug(); } +template +void set_real_to_recip_output_batch_op::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int batchSize) +{ + dim3 block((npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(set_real_to_recip_output_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + npwk, + nxyz, + add, + factor, + box_index, + reinterpret_cast*>(in), ld_in, + reinterpret_cast*>(out), ld_out, batchSize); + + hipCheckOnDebug(); +} + template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; template struct set_3d_fft_box_op; +template struct set_3d_fft_box_batch_op; template struct set_recip_to_real_output_op; +template struct set_recip_to_real_output_batch_op; template struct set_real_to_recip_output_op; +template struct set_real_to_recip_output_batch_op; } // namespace ModulePW diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 6f95343b1a8..3f914803a7f 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -242,6 +242,10 @@ class PW_Basis int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx // Thus complex[nmaxgr] is able to contain either reciprocal or real data FFT ft; +#if defined(__CUDA) || defined(__ROCM) + BatchedFFT batched_ft_float; + BatchedFFT batched_ft_double; +#endif //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). template diff --git a/source/module_basis/module_pw/pw_basis_k.cpp b/source/module_basis/module_pw/pw_basis_k.cpp index 02d5614e4b4..9d2ab853b86 100644 --- a/source/module_basis/module_pw/pw_basis_k.cpp +++ b/source/module_basis/module_pw/pw_basis_k.cpp @@ -84,7 +84,7 @@ void PW_Basis_K:: initparameters( this->xprime = xprime_in; this->fftny = this->ny; this->fftnx = this->nx; - if (this->gamma_only) + if (this->gamma_only) { if(this->xprime) this->fftnx = int(this->nx / 2) + 1; else this->fftny = int(this->ny / 2) + 1; @@ -125,7 +125,7 @@ void PW_Basis_K::setupIndGk() int ng = 0; for (int ig = 0; ig < this->npw ; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { ++ng; @@ -138,7 +138,7 @@ void PW_Basis_K::setupIndGk() this->npwk_max = ng; } } - + //get igl2isz_k and igl2ig_k if(this->npwk_max <= 0) return; @@ -149,7 +149,7 @@ void PW_Basis_K::setupIndGk() int igl = 0; for (int ig = 0; ig < this->npw ; ig++) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); + const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { this->igl2isz_k[ik*npwk_max + igl] = this->ig2isz[ig]; @@ -167,7 +167,7 @@ void PW_Basis_K::setupIndGk() return; } -/// +/// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall /// set up ffts @@ -183,6 +183,10 @@ void PW_Basis_K::setuptransform() if(this->xprime) this->ft.initfft(this->nx,this->ny,this->nz,this->lix,this->rix,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); else this->ft.initfft(this->nx,this->ny,this->nz,this->liy,this->riy,this->nst,this->nplane,this->poolnproc,this->gamma_only, this->xprime); this->ft.setupFFT(); +#if defined(__CUDA) || defined(__ROCM) + this->batched_ft_float.initFFT(this->nx, this->ny, this->nz); + this->batched_ft_double.initFFT(this->nx, this->ny, this->nz); +#endif ModuleBase::timer::tick(this->classname, "setuptransform"); } diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index 83aa377e792..491f4e5e6b8 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -113,13 +113,13 @@ class PW_Basis_K : public PW_Basis public: template - void real2recip(const FPTYPE* in, + void real2recip(FPTYPE* in, std::complex* out, const int ik, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) template - void real2recip(const std::complex* in, + void real2recip(std::complex* in, std::complex* out, const int ik, const bool add = false, @@ -139,7 +139,7 @@ class PW_Basis_K : public PW_Basis template void real_to_recip(const Device* ctx, - const std::complex* in, + std::complex* in, std::complex* out, const int ik, const bool add = false, @@ -152,6 +152,28 @@ class PW_Basis_K : public PW_Basis const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) +#if defined(__CUDA) || defined(__ROCM) + template + void real_to_recip_batch(const Device* ctx, + std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template + void recip_to_real_batch(const Device* ctx, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add = false, + const FPTYPE factor = 1.0)const; // in:(nz, ns) ; out(nplane,nx*ny) +#endif public: //operator: //get (G+K)^2: diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 0ea362825b6..896a4a3e323 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -23,7 +23,7 @@ namespace ModulePW * @param out: (nz, ns), complex data */ template -void PW_Basis_K::real2recip(const std::complex* in, +void PW_Basis_K::real2recip(std::complex* in, std::complex* out, const int ik, const bool add, @@ -88,7 +88,7 @@ void PW_Basis_K::real2recip(const std::complex* in, * @param out: (nz, ns), complex data */ template -void PW_Basis_K::real2recip(const FPTYPE* in, +void PW_Basis_K::real2recip(FPTYPE* in, std::complex* out, const int ik, const bool add, @@ -291,7 +291,7 @@ void PW_Basis_K::recip2real(const std::complex* in, template <> void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, - const std::complex* in, + std::complex* in, std::complex* out, const int ik, const bool add, @@ -301,7 +301,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, } template <> void PW_Basis_K::real_to_recip(const base_device::DEVICE_CPU* /*dev*/, - const std::complex* in, + std::complex* in, std::complex* out, const int ik, const bool add, @@ -331,10 +331,78 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_CPU* /*dev*/, this->recip2real(in, out, ik, add, factor); } +template <> +void PW_Basis_K::real_to_recip_batch(const base_device::DEVICE_CPU* ctx, + std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const float factor) const +{ + for (int i = 0; i < batchSize; ++i) + { + this->real_to_recip(ctx, in + ld_in * i, out + ld_out * i, ik, add, factor); + } +} + +template <> +void PW_Basis_K::real_to_recip_batch(const base_device::DEVICE_CPU* ctx, + std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const double factor) const +{ + for (int i = 0; i < batchSize; ++i) + { + this->real_to_recip(ctx, in + ld_in * i, out + ld_out * i, ik, add, factor); + } +} + +template <> +void PW_Basis_K::recip_to_real_batch(const base_device::DEVICE_CPU* ctx, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const float factor)const +{ + for (int i = 0; i < batchSize; ++i) + { + this->recip_to_real(ctx, in + ld_in * i, out + ld_out * i, ik, add, factor); + } +} +template <> +void PW_Basis_K::recip_to_real_batch(const base_device::DEVICE_CPU* ctx, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const double factor)const +{ + for (int i = 0; i < batchSize; ++i) + { + this->recip_to_real(ctx, in + ld_in * i, out + ld_out * i, ik, add, factor); + } +} + + #if (defined(__CUDA) || defined(__ROCM)) template <> void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, - const std::complex* in, + std::complex* in, std::complex* out, const int ik, const bool add, @@ -344,14 +412,14 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, assert(this->gamma_only == false); assert(this->poolnproc == 1); - base_device::memory::synchronize_memory_op, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()( - ctx, - ctx, - this->ft.get_auxr_3d_data(), - in, - this->nrxx); + // base_device::memory::synchronize_memory_op, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()( + // ctx, + // ctx, + // this->ft.get_auxr_3d_data(), + // in, + // this->nrxx); - this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->ft.fft3D_forward(ctx, in, this->ft.get_auxr_3d_data()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -367,7 +435,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, } template <> void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, - const std::complex* in, + std::complex* in, std::complex* out, const int ik, const bool add, @@ -377,15 +445,15 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, assert(this->gamma_only == false); assert(this->poolnproc == 1); - base_device::memory::synchronize_memory_op, - base_device::DEVICE_GPU, - base_device::DEVICE_GPU>()(ctx, - ctx, - this->ft.get_auxr_3d_data(), - in, - this->nrxx); + // base_device::memory::synchronize_memory_op, + // base_device::DEVICE_GPU, + // base_device::DEVICE_GPU>()(ctx, + // ctx, + // this->ft.get_auxr_3d_data(), + // in, + // this->nrxx); - this->ft.fft3D_forward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + this->ft.fft3D_forward(ctx, in, this->ft.get_auxr_3d_data()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -426,14 +494,19 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, this->ig2ixyz_k + startig, in, this->ft.get_auxr_3d_data()); - this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + if (add){ + this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + + set_recip_to_real_output_op()(ctx, + this->nrxx, + add, + factor, + this->ft.get_auxr_3d_data(), + out); + }else{ + this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), out); + } - set_recip_to_real_output_op()(ctx, - this->nrxx, - add, - factor, - this->ft.get_auxr_3d_data(), - out); ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); } @@ -463,25 +536,202 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, this->ig2ixyz_k + startig, in, this->ft.get_auxr_3d_data()); - this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + if (add){ + this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), this->ft.get_auxr_3d_data()); + + set_recip_to_real_output_op()(ctx, + this->nrxx, + add, + factor, + this->ft.get_auxr_3d_data(), + out); + }else{ + this->ft.fft3D_backward(ctx, this->ft.get_auxr_3d_data(), out); + } + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} + +template <> +void PW_Basis_K::real_to_recip_batch(const base_device::DEVICE_GPU* ctx, + std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const float factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip_batch gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + std::complex *fft_data = this->batched_ft_float.get_auxr_3d_data(batchSize); + + // base_device::memory::synchronize_memory_op, + // base_device::DEVICE_GPU, + // base_device::DEVICE_GPU>()(ctx, + // ctx, + // fft_data, + // in, + // this->nrxx * batchSize); + + this->batched_ft_float.fft3D_forward(ctx, in, fft_data, batchSize); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + set_real_to_recip_output_batch_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k + startig, + fft_data, this->nxyz, + out, ld_out, batchSize); + ModuleBase::timer::tick(this->classname, "real_to_recip_batch gpu"); +} + +template <> +void PW_Basis_K::real_to_recip_batch(const base_device::DEVICE_GPU* ctx, + std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip_batch gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + std::complex *fft_data = this->batched_ft_double.get_auxr_3d_data(batchSize); + + // base_device::memory::synchronize_memory_op, + // base_device::DEVICE_GPU, + // base_device::DEVICE_GPU>()(ctx, + // ctx, + // fft_data, + // in, + // this->nrxx * batchSize); + + this->batched_ft_double.fft3D_forward(ctx, in, fft_data, batchSize); - set_recip_to_real_output_op()(ctx, + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + set_real_to_recip_output_batch_op()(ctx, + npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k + startig, + fft_data, this->nxyz, + out, ld_out, batchSize); + ModuleBase::timer::tick(this->classname, "real_to_recip_batch gpu"); + +} + +template <> +void PW_Basis_K::recip_to_real_batch(const base_device::DEVICE_GPU* ctx, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const float factor)const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real_batch gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + // ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data(), this->nxyz); + std::complex *fft_data = this->batched_ft_float.get_auxr_3d_data(batchSize); + base_device::memory::set_memory_op, base_device::DEVICE_GPU>()( + ctx, + fft_data, + 0, + this->nxyz * batchSize); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + + set_3d_fft_box_batch_op()(ctx, + npw_k, + this->ig2ixyz_k + startig, + in, ld_in, + fft_data, this->nxyz, batchSize); + + if (add){ + this->batched_ft_float.fft3D_backward(ctx, fft_data, fft_data, batchSize); + set_recip_to_real_output_batch_op()(ctx, this->nrxx, add, factor, - this->ft.get_auxr_3d_data(), - out); + fft_data, this->nxyz, + out, ld_out, batchSize); + }else{ + this->batched_ft_float.fft3D_backward(ctx, fft_data, out, batchSize); + } - ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + + ModuleBase::timer::tick(this->classname, "recip_to_real_batch gpu"); } + +template <> +void PW_Basis_K::recip_to_real_batch(const base_device::DEVICE_GPU* ctx, + const std::complex* in, + const int ld_in, + std::complex* out, + const int ld_out, + const int ik, + const int batchSize, + const bool add, + const double factor)const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real_batch gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + // ModuleBase::GlobalFunc::ZEROS(ft.get_auxr_3d_data(), this->nxyz); + std::complex *fft_data = this->batched_ft_double.get_auxr_3d_data(batchSize); + base_device::memory::set_memory_op, base_device::DEVICE_GPU>()( + ctx, + fft_data, + 0, + this->nxyz * batchSize); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + + set_3d_fft_box_batch_op()(ctx, + npw_k, + this->ig2ixyz_k + startig, + in, ld_in, + fft_data, this->nxyz, batchSize); + + if(add){ + this->batched_ft_double.fft3D_backward(ctx, fft_data, fft_data, batchSize); + set_recip_to_real_output_batch_op()(ctx, + this->nrxx, + add, + factor, + fft_data, this->nxyz, + out, ld_out, batchSize); + }else{ + this->batched_ft_double.fft3D_backward(ctx, fft_data, out, batchSize); + } + + + ModuleBase::timer::tick(this->classname, "recip_to_real_batch gpu"); +} + #endif -template void PW_Basis_K::real2recip(const float* in, +template void PW_Basis_K::real2recip(float* in, std::complex* out, const int ik, const bool add, const float factor) const; // in:(nplane,nx*ny) ; out(nz, ns) -template void PW_Basis_K::real2recip(const std::complex* in, +template void PW_Basis_K::real2recip(std::complex* in, std::complex* out, const int ik, const bool add, @@ -497,12 +747,12 @@ template void PW_Basis_K::recip2real(const std::complex* in, const bool add, const float factor) const; // in:(nz, ns) ; out(nplane,nx*ny) -template void PW_Basis_K::real2recip(const double* in, +template void PW_Basis_K::real2recip(double* in, std::complex* out, const int ik, const bool add, const double factor) const; // in:(nplane,nx*ny) ; out(nz, ns) -template void PW_Basis_K::real2recip(const std::complex* in, +template void PW_Basis_K::real2recip(std::complex* in, std::complex* out, const int ik, const bool add, diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 0b6488092ea..e4b33fd2361 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -31,7 +31,7 @@ ElecStatePW::ElecStatePW(ModulePW::PW_Basis_K* wfc_basis_in, } template -ElecStatePW::~ElecStatePW() +ElecStatePW::~ElecStatePW() { if (base_device::get_device_type(this->ctx) == base_device::GpuDevice) { @@ -47,7 +47,7 @@ ElecStatePW::~ElecStatePW() } template -void ElecStatePW::init_rho_data() +void ElecStatePW::init_rho_data() { if (GlobalV::device_flag == "gpu" || GlobalV::precision_flag == "single") { this->rho = new Real*[this->charge->nspin]; @@ -171,30 +171,54 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) if (GlobalV::NSPIN == 4) { int npwx = npw / 2; - for (int ibnd = 0; ibnd < nbands; ibnd++) + + // additional memeory : wfcr, wfcr_another_spin, fft data, fft workarea + int batchSize = ModulePW::BatchedFFT::estimate_batch_size(4 * this->basis->nmaxgr * sizeof(T)); + if (std::is_same::value && batchSize > 1 && nbands > 1) { - /// - /// only occupied band should be calculated. - /// be care of when smearing_sigma is large, wg would less than 0 - /// + base_device::DEVICE_CPU *cpu_ctx; + double *wg_gpu = nullptr; + base_device::memory::resize_memory_op()(this->ctx, wg_gpu, nbands); + base_device::memory::synchronize_memory_op()(this->ctx, cpu_ctx, wg_gpu, &this->wg(ik, 0), nbands); - this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); + resmem_complex_op()(this->ctx, this->wfcr, this->basis->nmaxgr * batchSize, "ElecSPW::wfcr"); + resmem_complex_op()(this->ctx, this->wfcr_another_spin, this->basis->nrxx * batchSize, "ElecSPW::wfcr_a"); - this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); + for (int i = 0; i < nbands; i += batchSize) + { + int remaining = nbands - i; + int current_batch = std::min(remaining, batchSize); + this->rhoBandK_spin4_batch(psi, wg_gpu + i, i, current_batch); + } + base_device::memory::delete_memory_op()(this->ctx, wg_gpu); + } + else + { + for (int ibnd = 0; ibnd < nbands; ibnd++) + { + /// + /// only occupied band should be calculated. + /// be care of when smearing_sigma is large, wg would less than 0 + /// - const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); + this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); - if (w1 != 0.0) - { - // replaced by denghui at 20221110 - elecstate_pw_op()(this->ctx, - GlobalV::DOMAG, - GlobalV::DOMAG_Z, - this->basis->nrxx, - w1, - this->rho, - this->wfcr, - this->wfcr_another_spin); + this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); + + const auto w1 = static_cast(this->wg(ik, ibnd) / get_ucell_omega()); + + if (w1 != 0.0) + { + // replaced by denghui at 20221110 + elecstate_pw_op()(this->ctx, + GlobalV::DOMAG, + GlobalV::DOMAG_Z, + this->basis->nrxx, + w1, + this->rho, + this->wfcr, + this->wfcr_another_spin); + } } } } @@ -243,6 +267,23 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) } } +template +void ElecStatePW::rhoBandK_spin4_batch(const psi::Psi& psi, + double* wg_gpu, + const int current_band, + const int batchSize) +{ + int ik = psi.get_current_k(); + int npw = psi.get_current_nbas(); + int npwx = npw / 2; + + this->basis->recip_to_real_batch(this->ctx, &psi(current_band,0), psi.get_nbasis(), this->wfcr, this->basis->nmaxgr, ik, batchSize); + this->basis->recip_to_real_batch(this->ctx, &psi(current_band,npwx), psi.get_nbasis(), this->wfcr_another_spin, this->basis->nrxx, ik, batchSize); + + elecstate_pw_batch_op()(this->ctx, GlobalV::DOMAG, GlobalV::DOMAG_Z, this->basis->nrxx, wg_gpu, get_ucell_omega(), this->rho, + this->wfcr, this->wfcr_another_spin, this->basis->nmaxgr, batchSize); +} + template void ElecStatePW::add_usrho(const psi::Psi& psi) { @@ -536,6 +577,6 @@ template class ElecStatePW, base_device::DEVICE_CPU>; #if ((defined __CUDA) || (defined __ROCM)) template class ElecStatePW, base_device::DEVICE_GPU>; template class ElecStatePW, base_device::DEVICE_GPU>; -#endif +#endif } // namespace elecstate diff --git a/source/module_elecstate/elecstate_pw.h b/source/module_elecstate/elecstate_pw.h index 0df2222763b..1ec0580308a 100644 --- a/source/module_elecstate/elecstate_pw.h +++ b/source/module_elecstate/elecstate_pw.h @@ -61,6 +61,8 @@ class ElecStatePW : public ElecState void init_rho_data(); + void rhoBandK_spin4_batch(const psi::Psi& psi, double *wg_gpu, const int current_band, const int batchSize); + Device * ctx = {}; bool init_rho = false; mutable T* vkb = nullptr; @@ -70,6 +72,7 @@ class ElecStatePW : public ElecState using meta_op = hamilt::meta_pw_op; using elecstate_pw_op = elecstate::elecstate_pw_op; + using elecstate_pw_batch_op = elecstate::elecstate_pw_batch_op; using setmem_var_op = base_device::memory::set_memory_op; using resmem_var_op = base_device::memory::resize_memory_op; diff --git a/source/module_elecstate/kernels/cuda/elecstate_op.cu b/source/module_elecstate/kernels/cuda/elecstate_op.cu index 59aa479b128..37ebe7d4b85 100644 --- a/source/module_elecstate/kernels/cuda/elecstate_op.cu +++ b/source/module_elecstate/kernels/cuda/elecstate_op.cu @@ -57,6 +57,52 @@ __global__ void elecstate_pw( } } +template +__global__ void elecstate_pw_batch( + const bool DOMAG, + const bool DOMAG_Z, + const int nrxx, + const double* weight, + const double volume, + FPTYPE* rho, + const thrust::complex* wfcr_batch, + const thrust::complex* wfcr_another_spin_batch, + const int ld_wfcr, + const int batchSize) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int batch = blockIdx.z; + if(idx >= nrxx || batch >= batchSize) {return;} + + FPTYPE w1 = weight[batch] / volume; + const thrust::complex* wfcr = wfcr_batch + batch * ld_wfcr; + const thrust::complex* wfcr_another_spin = wfcr_another_spin_batch + batch * ld_wfcr; + if (w1 == 0.0) return; + + atomicAdd(&rho[0 * nrxx + idx], w1 * (norm(wfcr[idx]) + norm(wfcr_another_spin[idx]))); + + if (DOMAG) { + atomicAdd(&rho[1 * nrxx + idx], w1 * 2.0 + * (wfcr[idx].real() * wfcr_another_spin[idx].real() + + wfcr[idx].imag() * wfcr_another_spin[idx].imag())); + atomicAdd(&rho[2 * nrxx + idx], w1 * 2.0 + * (wfcr[idx].real() * wfcr_another_spin[idx].imag() + - wfcr_another_spin[idx].real() * wfcr[idx].imag())); + atomicAdd(&rho[3 * nrxx + idx], w1 * (norm(wfcr[idx]) - norm(wfcr_another_spin[idx]))); + } + else if(DOMAG_Z) { + rho[1 * nrxx + idx] = 0; + rho[2 * nrxx + idx] = 0; + atomicAdd(&rho[3 * nrxx + idx], w1 * (norm(wfcr[idx]) - norm(wfcr_another_spin[idx]))); + } + else { + rho[0 * nrxx + idx] = 0; + rho[1 * nrxx + idx] = 0; + rho[2 * nrxx + idx] = 0; + rho[3 * nrxx + idx] = 0; + } +} + template void elecstate_pw_op::operator()(const base_device::DEVICE_GPU* ctx, const int& spin, @@ -94,7 +140,34 @@ void elecstate_pw_op::operator()(const base_dev cudaCheckOnDebug(); } +template +void elecstate_pw_batch_op::operator()(const base_device::DEVICE_GPU* ctx, + const bool& DOMAG, + const bool& DOMAG_Z, + const int& nrxx, + const double* w1, + const double volume, + FPTYPE** rho, + const std::complex* wfcr, + const std::complex* wfcr_another_spin, + const int ld_wfcr, + const int batchSize) +{ + dim3 block((nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + elecstate_pw_batch<<>>( + DOMAG, DOMAG_Z, nrxx, w1, volume, rho[0], + reinterpret_cast*>(wfcr), + reinterpret_cast*>(wfcr_another_spin), + ld_wfcr, batchSize + ); + + cudaCheckOnDebug(); +} + template struct elecstate_pw_op; template struct elecstate_pw_op; +template struct elecstate_pw_batch_op; +template struct elecstate_pw_batch_op; + } // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/kernels/elecstate_op.cpp b/source/module_elecstate/kernels/elecstate_op.cpp index a1a48134bd3..9d89b85c6f9 100644 --- a/source/module_elecstate/kernels/elecstate_op.cpp +++ b/source/module_elecstate/kernels/elecstate_op.cpp @@ -81,6 +81,31 @@ struct elecstate_pw_op } }; +template +struct elecstate_pw_batch_op +{ + void operator()(const base_device::DEVICE_CPU* ctx, + const bool& DOMAG, + const bool& DOMAG_Z, + const int& nrxx, + const double* weight, + const double volume, + FPTYPE** rho, + const std::complex* wfcr, + const std::complex* wfcr_another_spin, + const int ld_wfcr, + const int batchSize) + { + for(int i = 0; i < batchSize; ++i) + { + elecstate_pw_op()(ctx, DOMAG, DOMAG_Z, nrxx, weight[i] / volume, rho, wfcr + ld_wfcr * i, + wfcr_another_spin + ld_wfcr * i); + } + } +}; + template struct elecstate_pw_op; template struct elecstate_pw_op; +template struct elecstate_pw_batch_op; +template struct elecstate_pw_batch_op; } // namespace elecstate \ No newline at end of file diff --git a/source/module_elecstate/kernels/elecstate_op.h b/source/module_elecstate/kernels/elecstate_op.h index c1e1550f238..2ba9a4225b8 100644 --- a/source/module_elecstate/kernels/elecstate_op.h +++ b/source/module_elecstate/kernels/elecstate_op.h @@ -7,7 +7,7 @@ namespace elecstate{ -template +template struct elecstate_pw_op { /// @brief Calculate psiToRho output within the band-by-band loop, NSPIN != 4 /// @@ -52,6 +52,37 @@ struct elecstate_pw_op { const std::complex* wfcr_another_spin); }; +template +struct elecstate_pw_batch_op { + /// @brief Calculate psiToRho output within the band-by-band loop, NSPIN == 4 + /// + /// Input Parameters + /// @param ctx - which device this function runs on + /// @param DOMAG - GlobalV::DOMAG + /// @param DOMAG_Z - GlobalV::DOMAG_Z + /// @param nrxx - number of planewaves + /// @param weight - weight vector + /// @param wfcr - input array, psi in real space + /// @param wfcr_another_spin - input array, psi in real space + /// @param ld_wfcr - leading dimension of wfcr and wfcr_another_spin + /// + /// Output Parameters + /// @param rho - electronic densities + void operator() ( + const Device* ctx, + const bool& DOMAG, + const bool& DOMAG_Z, + const int& nrxx, + const double* weight, + const double volume, + FPTYPE** rho, + const std::complex* wfcr, + const std::complex* wfcr_another_spin, + const int ld_wfcr, + const int batchSize); +}; + + #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM template struct elecstate_pw_op @@ -72,6 +103,22 @@ struct elecstate_pw_op const std::complex* wfcr, const std::complex* wfcr_another_spin); }; + +template +struct elecstate_pw_batch_op +{ + void operator()(const base_device::DEVICE_GPU* ctx, + const bool& DOMAG, + const bool& DOMAG_Z, + const int& nrxx, + const double* w1, + const double volume, + FPTYPE** rho, + const std::complex* wfcr, + const std::complex* wfcr_another_spin, + const int ld_wfcr, + const int batchSize); +}; #endif } // namespace elecstate diff --git a/source/module_elecstate/kernels/rocm/elecstate_op.hip.cu b/source/module_elecstate/kernels/rocm/elecstate_op.hip.cu index 80cb38100d2..76d236c1275 100644 --- a/source/module_elecstate/kernels/rocm/elecstate_op.hip.cu +++ b/source/module_elecstate/kernels/rocm/elecstate_op.hip.cu @@ -57,6 +57,52 @@ __global__ void elecstate_pw( } } +template +__global__ void elecstate_pw_batch( + const bool DOMAG, + const bool DOMAG_Z, + const int nrxx, + const double* weight, + const double volume, + FPTYPE* rho, + const thrust::complex* wfcr_batch, + const thrust::complex* wfcr_another_spin_batch, + const int ld_wfcr, + const int batchSize) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int batch = blockIdx.z; + if(idx >= nrxx || batch >= batchSize) {return;} + + FPTYPE w1 = weight[batch] / volume; + const thrust::complex* wfcr = wfcr_batch + batch * ld_wfcr; + const thrust::complex* wfcr_another_spin = wfcr_another_spin_batch + batch * ld_wfcr; + if (w1 == 0.0) return; + + atomicAdd(&rho[0 * nrxx + idx], w1 * (norm(wfcr[idx]) + norm(wfcr_another_spin[idx]))); + + if (DOMAG) { + atomicAdd(&rho[1 * nrxx + idx], w1 * 2.0 + * (wfcr[idx].real() * wfcr_another_spin[idx].real() + + wfcr[idx].imag() * wfcr_another_spin[idx].imag())); + atomicAdd(&rho[2 * nrxx + idx], w1 * 2.0 + * (wfcr[idx].real() * wfcr_another_spin[idx].imag() + - wfcr_another_spin[idx].real() * wfcr[idx].imag())); + atomicAdd(&rho[3 * nrxx + idx], w1 * (norm(wfcr[idx]) - norm(wfcr_another_spin[idx]))); + } + else if(DOMAG_Z) { + rho[1 * nrxx + idx] = 0; + rho[2 * nrxx + idx] = 0; + atomicAdd(&rho[3 * nrxx + idx], w1 * (norm(wfcr[idx]) - norm(wfcr_another_spin[idx]))); + } + else { + rho[0 * nrxx + idx] = 0; + rho[1 * nrxx + idx] = 0; + rho[2 * nrxx + idx] = 0; + rho[3 * nrxx + idx] = 0; + } +} + template void elecstate_pw_op::operator()(const base_device::DEVICE_GPU* ctx, const int& spin, @@ -94,6 +140,33 @@ void elecstate_pw_op::operator()(const base_dev hipCheckOnDebug(); } +template +void elecstate_pw_batch_op::operator()(const base_device::DEVICE_GPU* ctx, + const bool& DOMAG, + const bool& DOMAG_Z, + const int& nrxx, + const double* w1, + const double volume, + FPTYPE** rho, + const std::complex* wfcr, + const std::complex* wfcr_another_spin, + const int ld_wfcr, + const int batchSize) +{ + dim3 block((nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(elecstate_pw_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + DOMAG, DOMAG_Z, nrxx, w1, volume, rho[0], + reinterpret_cast*>(wfcr), + reinterpret_cast*>(wfcr_another_spin), + ld_wfcr, batchSize + ); + + hipCheckOnDebug(); +} + template struct elecstate_pw_op; template struct elecstate_pw_op; + +template struct elecstate_pw_batch_op; +template struct elecstate_pw_batch_op; } \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu index 46071424157..d6e6557d0eb 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu @@ -21,6 +21,20 @@ __global__ void veff_pw( out[idx] *= in[idx]; } +template +__global__ void veff_pw_batch( + const int size, + thrust::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= size || batch >= batchSize) {return;} + out[batch * ld_out + idx] *= in[idx]; +} + template __global__ void veff_pw( const int size, @@ -40,6 +54,32 @@ __global__ void veff_pw( out1[idx] = sdown; } +template +__global__ void veff_pw_batch( + const int size, + thrust::complex* out, + int ld_out, + thrust::complex* out1, + int ld_out1, + const FPTYPE* in, + int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= size || batch >= batchSize) {return;} + + thrust::complex *out_batch = out + batch * ld_out; + thrust::complex *out1_batch = out1 + batch * ld_out1; + thrust::complex sup = + out_batch[idx] * (in[0 * size + idx] + in[3 * size + idx]) + + out1_batch[idx] * (in[1 * size + idx] - thrust::complex(0.0, 1.0) * in[2 * size + idx]); + thrust::complex sdown = + out1_batch[idx] * (in[0 * size + idx] - in[3 * size + idx]) + + out_batch[idx] * (in[1 * size + idx] + thrust::complex(0.0, 1.0) * in[2 * size + idx]); + out_batch[idx] = sup; + out1_batch[idx] = sdown; +} + template void veff_pw_op::operator()(const base_device::DEVICE_GPU* dev, const int& size, @@ -72,7 +112,53 @@ void veff_pw_op::operator()(const base_device:: cudaCheckOnDebug(); } +template +void veff_pw_batch_op::operator()(const base_device::DEVICE_GPU* dev, + const int& size, + std::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize) +{ + dim3 block((size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + veff_pw_batch<<>>( + size, // control params + reinterpret_cast*>(out), // array of data + ld_out, + in,// array of data + batchSize + ); + + cudaCheckOnDebug(); +} + +template +void veff_pw_batch_op::operator()(const base_device::DEVICE_GPU* dev, + const int& size, + std::complex* out, + int ld_out, + std::complex* out1, + int ld_out1, + const FPTYPE** in, + int batchSize) +{ + dim3 block((size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + veff_pw_batch<<>>( + size, // control params + reinterpret_cast*>(out), // array of data + ld_out, + reinterpret_cast*>(out1), // array of data + ld_out1, + in[0],// array of data + batchSize + ); + + cudaCheckOnDebug(); +} + template struct veff_pw_op; template struct veff_pw_op; +template struct veff_pw_batch_op; +template struct veff_pw_batch_op; } // namespace hamilt \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu index 58dd964fe5f..da77cbf1402 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu @@ -21,6 +21,20 @@ __global__ void veff_pw( out[idx] *= in[idx]; } +template +__global__ void veff_pw_batch( + const int size, + thrust::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= size || batch >= batchSize) {return;} + out[batch * ld_out + idx] *= in[idx]; +} + template __global__ void veff_pw( const int size, @@ -40,6 +54,32 @@ __global__ void veff_pw( out1[idx] = sdown; } +template +__global__ void veff_pw_batch( + const int size, + thrust::complex* out, + int ld_out, + thrust::complex* out1, + int ld_out1, + const FPTYPE* in, + int batchSize) +{ + int batch = blockIdx.z; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= size || batch >= batchSize) {return;} + + thrust::complex *out_batch = out + batch * ld_out; + thrust::complex *out1_batch = out1 + batch * ld_out1; + thrust::complex sup = + out_batch[idx] * (in[0 * size + idx] + in[3 * size + idx]) + + out1_batch[idx] * (in[1 * size + idx] - thrust::complex(0.0, 1.0) * in[2 * size + idx]); + thrust::complex sdown = + out1_batch[idx] * (in[0 * size + idx] - in[3 * size + idx]) + + out_batch[idx] * (in[1 * size + idx] + thrust::complex(0.0, 1.0) * in[2 * size + idx]); + out_batch[idx] = sup; + out1_batch[idx] = sdown; +} + template void veff_pw_op::operator()(const base_device::DEVICE_GPU* dev, const int& size, @@ -72,7 +112,53 @@ void veff_pw_op::operator()(const base_device:: hipCheckOnDebug(); } +template +void veff_pw_batch_op::operator()(const base_device::DEVICE_GPU* dev, + const int& size, + std::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize) +{ + dim3 block((size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(veff_pw_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + size, // control params + reinterpret_cast*>(out), // array of data + ld_out, + in,// array of data + batchSize + ); + + hipCheckOnDebug(); +} + +template +void veff_pw_batch_op::operator()(const base_device::DEVICE_GPU* dev, + const int& size, + std::complex* out, + int ld_out, + std::complex* out1, + int ld_out1, + const FPTYPE** in, + int batchSize) +{ + dim3 block((size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 1, batchSize); + hipLaunchKernelGGL(HIP_KERNEL_NAME(veff_pw_batch), block, dim3(THREADS_PER_BLOCK), 0, 0, + size, // control params + reinterpret_cast*>(out), // array of data + ld_out, + reinterpret_cast*>(out1), // array of data + ld_out1, + in[0],// array of data + batchSize + ); + + hipCheckOnDebug(); +} + template struct veff_pw_op; template struct veff_pw_op; +template struct veff_pw_batch_op; +template struct veff_pw_batch_op; } // namespace hamilt \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp b/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp index ce7e9888456..57c523ae6ef 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp @@ -40,8 +40,44 @@ struct veff_pw_op } }; +template +struct veff_pw_batch_op +{ + void operator()( + const base_device::DEVICE_CPU* dev, + const int& size, + std::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize) + { + for(int i = 0; i < batchSize; ++i) + { + veff_pw_op()(dev, size, out + ld_out * i, in); + } + } + + void operator() ( + const base_device::DEVICE_CPU* dev, + const int& size, + std::complex* out, + int ld_out, + std::complex* out1, + int ld_out1, + const FPTYPE** in, + int batchSize) + { + for(int i = 0; i < batchSize; ++i) + { + veff_pw_op()(dev, size, out + ld_out * i, out1 + ld_out1 * i, in); + } + } +}; + template struct veff_pw_op; template struct veff_pw_op; +template struct veff_pw_batch_op; +template struct veff_pw_batch_op; } // namespace hamilt diff --git a/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.h b/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.h index 4480fc5388a..596ba356dd6 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.h +++ b/source/module_hamilt_pw/hamilt_pwdft/kernels/veff_op.h @@ -51,6 +51,62 @@ struct veff_pw_op { const FPTYPE** in); }; +template +struct veff_pw_batch_op { + /// @brief Compute the effective potential of hPsi in real space, + /// out[ir] *= in[ir]; + /// + /// Input Parameters + /// \param dev : the type of computing device + /// \param size : array size + /// \param ld_out : leading dimension of out + /// \param in : input array, elecstate::Potential::v_effective + /// \param batchSize : batch size for once calcution + /// + /// Output Parameters + /// \param out : output array + void operator() ( + const Device* dev, + const int& size, + std::complex* out, + int ld_out, + const FPTYPE* in, + int batchSize); + + /// @brief Compute the effective potential of hPsi in real space with NSPIN > 2, + /// + /// out[ir] = out[ir] * (in[0][ir] + in[3][ir]) + /// + out1[ir] + /// * (in[1][ir] + /// - std::complex(0.0, 1.0) * in[2][ir]); + /// + /// out1[ir] = out1[ir] * (in[0][ir] - in[3][ir]) + /// + out[ir] + /// * (in[1][ir] + /// + std::complex(0.0, 1.0) * in[2][ir]); + /// + /// Input Parameters + /// \param dev : the type of computing device + /// \param size : array size + /// \param ld_out : leading dimension of out + /// \param ld_out1 : leading dimension of out1 + /// \param in : input array, elecstate::Potential::v_effective + /// \param batchSize : batch size for once calcution + /// + /// Output Parameters + /// \param out : output array 1 + /// \param out1 : output array 2 + void operator() ( + const Device* dev, + const int& size, + std::complex* out, + int ld_out, + std::complex* out1, + int ld_out1, + const FPTYPE** in, + int batchSize); +}; + #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM // Partially specialize functor for base_device::GpuDevice. template @@ -64,6 +120,22 @@ struct veff_pw_op std::complex* out1, const FPTYPE** in); }; + +template +struct veff_pw_batch_op +{ + void operator()(const base_device::DEVICE_GPU* dev, const int& size, std::complex* out, int ld_out, const FPTYPE* in, int batchSize); + + void operator()(const base_device::DEVICE_GPU* dev, + const int& size, + std::complex* out, + int ld_out, + std::complex* out1, + int ld_out1, + const FPTYPE** in, + int batchSize); +}; + #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM } // namespace hamilt #endif //MODULE_HAMILT_VEFF_H \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp index 8fed260c1c5..7026b22c0c2 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp @@ -16,12 +16,13 @@ Veff>::Veff(const int* isk_in, this->cal_type = calculation_type::pw_veff; this->isk = isk_in; this->veff = veff_in; - //note: "veff = nullptr" means that this core does not treat potential but still treats wf. + //note: "veff = nullptr" means that this core does not treat potential but still treats wf. this->veff_row = veff_row; this->veff_col = veff_col; this->wfcpw = wfcpw_in; - resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr, "Veff::porter"); - resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr, "Veff::porter1"); + this->malloc_porter(this->wfcpw->nmaxgr); + // resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr, "Veff::porter"); + // resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr, "Veff::porter1"); if (this->isk == nullptr || this->wfcpw == nullptr) { ModuleBase::WARNING_QUIT("VeffPW", "Constuctor of Operator::VeffPW is failed, please check your code!"); } @@ -32,6 +33,7 @@ Veff>::~Veff() { delmem_complex_op()(this->ctx, this->porter); delmem_complex_op()(this->ctx, this->porter1); + this->porter_length = 0; } template @@ -44,11 +46,33 @@ void Veff>::act( const int ngk_ik)const { ModuleBase::timer::tick("Operator", "VeffPW"); + int loop_batches = (nbands + npol - 1) / npol; +#if defined(__CUDA) || defined(__ROCM) + // addtional memory: porter, porter1, fft data, fft workarea + int batchSize = ModulePW::BatchedFFT::estimate_batch_size(4 * this->wfcpw->nmaxgr * sizeof(T)); +#else + int batchSize = 1; +#endif + + if (std::is_same::value && batchSize > 1 && loop_batches > 1) // this->device = AbacusDevice_t::UnKnown ? + { + int max_npw = nbasis / npol; + int ld_tmp = max_npw * npol; + + for (int i = 0; i < loop_batches; i += batchSize) + { + int remaining = loop_batches - i; + int current_batch = std::min(remaining, batchSize); + this->act_batch(nbands, nbasis, npol, tmpsi_in + i * ld_tmp, tmhpsi + i * ld_tmp, ngk_ik, current_batch); + } + ModuleBase::timer::tick("Operator", "VeffPW"); + return; + } int max_npw = nbasis / npol; const int current_spin = this->isk[this->ik]; - - // T *porter = new T[wfcpw->nmaxgr]; + + for (int ib = 0; ib < nbands; ib += npol) { if (npol == 1) @@ -108,6 +132,48 @@ void Veff>::act( ModuleBase::timer::tick("Operator", "VeffPW"); } +template +void Veff>::act_batch( + const int nbands, + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik, + const int batchSize)const +{ + int max_npw = nbasis / npol; + const int current_spin = this->isk[this->ik]; + + this->malloc_porter(this->wfcpw->nmaxgr * batchSize); + + if (npol == 1) + { + wfcpw->recip_to_real_batch(this->ctx, tmpsi_in, max_npw * npol, this->porter, this->wfcpw->nmaxgr, this->ik, batchSize); + if (this->veff_col != 0) + { + veff_batch_op()(this->ctx, this->veff_col, this->porter, this->wfcpw->nmaxgr, this->veff + current_spin * this->veff_col, batchSize); + } + wfcpw->real_to_recip_batch(this->ctx, this->porter, this->wfcpw->nmaxgr, tmhpsi, max_npw * npol, this->ik, batchSize, true); + } + else + { + const Real* current_veff[4]; + for(int is = 0; is < 4; is++) { + current_veff[is] = this->veff + is * this->veff_col ; // for CPU device + } + wfcpw->recip_to_real_batch(this->ctx, tmpsi_in, max_npw * npol, this->porter, this->wfcpw->nmaxgr, this->ik, batchSize); + wfcpw->recip_to_real_batch(this->ctx, tmpsi_in + max_npw, max_npw * npol, this->porter1, this->wfcpw->nmaxgr, this->ik, batchSize); + if(this->veff_col != 0) + { + veff_batch_op()(this->ctx, this->veff_col, this->porter, this->wfcpw->nmaxgr, this->porter1, this->wfcpw->nmaxgr, current_veff, batchSize); + } + wfcpw->real_to_recip_batch(this->ctx, this->porter, this->wfcpw->nmaxgr, tmhpsi, max_npw * npol, this->ik, batchSize, true); + wfcpw->real_to_recip_batch(this->ctx, this->porter1, this->wfcpw->nmaxgr, tmhpsi + max_npw, max_npw * npol, this->ik, batchSize, true); + + } +} + template template hamilt::Veff>::Veff(const Veff> *veff) { @@ -118,8 +184,9 @@ hamilt::Veff>::Veff(const Veff this->veff_col = veff->get_veff_col(); this->veff_row = veff->get_veff_row(); this->wfcpw = veff->get_wfcpw(); - resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr); - resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr); + this->malloc_porter(this->wfcpw->nmaxgr); + // resmem_complex_op()(this->ctx, this->porter, this->wfcpw->nmaxgr); + // resmem_complex_op()(this->ctx, this->porter1, this->wfcpw->nmaxgr); this->veff = veff->get_veff(); if (this->isk == nullptr || this->veff == nullptr || this->wfcpw == nullptr) { ModuleBase::WARNING_QUIT("VeffPW", "Constuctor of Operator::VeffPW is failed, please check your code!"); diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h index c8b19047f4f..5d6ca26bf8f 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.h @@ -43,6 +43,14 @@ class Veff> : public OperatorPW T* tmhpsi, const int ngk_ik = 0)const override; + void act_batch(const int nbands, + const int nbasis, + const int npol, + const T* tmpsi_in, + T* tmhpsi, + const int ngk_ik = 0, + const int batchSize = 1)const; + // denghui added for copy constructor at 20221105 const Real *get_veff() const {return this->veff;} int get_veff_col() const {return this->veff_col;} @@ -65,13 +73,22 @@ class Veff> : public OperatorPW int veff_col = 0; int veff_row = 0; const Real *veff = nullptr, *h_veff = nullptr, *d_veff = nullptr; - T *porter = nullptr; - T *porter1 = nullptr; + mutable T *porter = nullptr; + mutable T *porter1 = nullptr; + mutable int porter_length = 0; base_device::AbacusDevice_t device = {}; using veff_op = veff_pw_op; + using veff_batch_op = veff_pw_batch_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; + + void malloc_porter(int size) const{ + if (size <= this->porter_length) return; + resmem_complex_op()(this->ctx, this->porter, size, "Veff::porter"); + resmem_complex_op()(this->ctx, this->porter1, size, "Veff::porter1"); + this->porter_length = size; + } }; } // namespace hamilt diff --git a/source/module_io/read_input_item_general.cpp b/source/module_io/read_input_item_general.cpp index 91e4cab3c36..3982850143b 100644 --- a/source/module_io/read_input_item_general.cpp +++ b/source/module_io/read_input_item_general.cpp @@ -583,6 +583,12 @@ void ReadInput::item_general() read_sync_string(input.device); this->add_item(item); } + { + Input_Item item("fft_batch_size"); + item.annotation = "the fft batch size for ABACUS GPU"; + read_sync_int(input.fft_batch_size); + this->add_item(item); + } { Input_Item item("precision"); item.annotation = "the computing precision for ABACUS"; diff --git a/source/module_parameter/input_parameter.h b/source/module_parameter/input_parameter.h index 17ef09ed165..8799c9e2747 100644 --- a/source/module_parameter/input_parameter.h +++ b/source/module_parameter/input_parameter.h @@ -74,6 +74,7 @@ struct Input_para ///< and wavefunction. 0: output only when ion steps are finished int elpa_num_thread = -1; ///< Number of threads need to use in elpa std::string device = "cpu"; + int fft_batch_size = 0; // for gpu : 0: auto-detect the batch size. 1: loop calcaulate. >1 batch calcaulate. std::string precision = "double"; // ============== #Parameters (2.PW) =========================== @@ -576,7 +577,8 @@ struct Input_para std::string lr_solver = "dav"; ///< the eigensolver for LR-TDDFT double lr_thr = 1e-2; ///< convergence threshold of the LR-TDDFT eigensolver bool out_wfc_lr = false; ///< whether to output the eigenvectors (excitation amplitudes) in the particle-hole basis - std::vector abs_wavelen_range = { 0., 0. }; ///< the range of wavelength(nm) to output the absorption spectrum + std::vector abs_wavelen_range = { 0., 0. }; ///< the range of wavelength(nm) to output the absorption spectrum double abs_broadening = 0.01; ///< the broadening (eta) for LR-TDDFT absorption spectrum + }; #endif \ No newline at end of file From fa37449a1bc9f7e64a2b52ca246065d90179ab52 Mon Sep 17 00:00:00 2001 From: Tianxiang Wang Date: Thu, 16 Oct 2025 09:59:59 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Fix=20compile=20error=20Signed-off-by?= =?UTF-8?q?=EF=BC=9ATianxiang=20Wang,Contri?= =?UTF-8?q?buted=20under=20MetaX=20Integrated=20Circuits=20(Shanghai)=20Co?= =?UTF-8?q?.,=20Ltd.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/module_basis/module_pw/pw_basis_k.h | 2 -- source/module_elecstate/elecstate_pw.cpp | 2 ++ source/module_elecstate/fp_energy.cpp | 4 ++++ source/module_esolver/esolver_ks.cpp | 2 ++ source/module_hsolver/diago_dav_subspace.cpp | 3 ++- 5 files changed, 10 insertions(+), 3 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index 491f4e5e6b8..18a8639a658 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -152,7 +152,6 @@ class PW_Basis_K : public PW_Basis const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) -#if defined(__CUDA) || defined(__ROCM) template void real_to_recip_batch(const Device* ctx, std::complex* in, @@ -173,7 +172,6 @@ class PW_Basis_K : public PW_Basis const int batchSize, const bool add = false, const FPTYPE factor = 1.0)const; // in:(nz, ns) ; out(nplane,nx*ny) -#endif public: //operator: //get (G+K)^2: diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index e4b33fd2361..69dd3fd26df 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -172,6 +172,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) { int npwx = npw / 2; +#if defined(__CUDA) || defined(__ROCM) // additional memeory : wfcr, wfcr_another_spin, fft data, fft workarea int batchSize = ModulePW::BatchedFFT::estimate_batch_size(4 * this->basis->nmaxgr * sizeof(T)); if (std::is_same::value && batchSize > 1 && nbands > 1) @@ -193,6 +194,7 @@ void ElecStatePW::rhoBandK(const psi::Psi& psi) base_device::memory::delete_memory_op()(this->ctx, wg_gpu); } else +#endif { for (int ibnd = 0; ibnd < nbands; ibnd++) { diff --git a/source/module_elecstate/fp_energy.cpp b/source/module_elecstate/fp_energy.cpp index 6b8d4cbbbde..646fd8b4ae2 100644 --- a/source/module_elecstate/fp_energy.cpp +++ b/source/module_elecstate/fp_energy.cpp @@ -36,7 +36,9 @@ double fenergy::calculate_etot() etot += (ecore + epawdc); } #endif +#ifdef __MPI Parallel_Common::bcast_double(etot); +#endif return etot; } @@ -61,7 +63,9 @@ double fenergy::calculate_harris() etot_harris += (ecore + epawdc); } #endif +#ifdef __MPI Parallel_Common::bcast_double(etot_harris); +#endif return etot_harris; } diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 989a39428d5..d99d63d4fcc 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -535,7 +535,9 @@ void ESolver_KS::runner(const int istep, UnitCell& ucell) // EState should be used after it is constructed. drho = p_chgmix->get_drho(pelec->charge, GlobalV::nelec); +#ifdef __MPI Parallel_Common::bcast_double(drho); +#endif double hsolver_error = 0.0; if (firstscf) { diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 711015db26b..b14527c9ff3 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -502,8 +502,9 @@ bool Diago_DavSubspace::diag_zhegvx(const int& nbase, } } } - +#ifdef __MPI MPI_Bcast(&fail_info, 1, MPI_INT, 0, this->diag_comm.comm); +#endif if(fail_info != 0) { return false;