diff --git a/source/source_basis/module_pw/pw_transform.cpp b/source/source_basis/module_pw/pw_transform.cpp index 220b353e9d4..06eecba1c9f 100644 --- a/source/source_basis/module_pw/pw_transform.cpp +++ b/source/source_basis/module_pw/pw_transform.cpp @@ -5,11 +5,22 @@ #include "pw_basis.h" #include "pw_gatherscatter.h" +#include #include #include namespace ModulePW { +namespace +{ +constexpr int pw_transform_cache_block = 1024; + +inline int block_end(const int begin, const int size) +{ + return std::min(begin + pw_transform_cache_block, size); +} +} // namespace + // const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() { // static const base_device::DEVICE_CPU* default_device_cpu; // return default_device_cpu; @@ -34,28 +45,48 @@ void PW_Basis::real2recip(const std::complex* in, const int npw_ = this->npw; const int nxyz_ = this->nxyz; const int* ig2isz_ = this->ig2isz; + const std::complex* in_ = in; + std::complex* auxr = this->fft_bundle.get_auxr_data(); + std::complex* auxg = this->fft_bundle.get_auxg_data(); + ModuleBase::timer::start(this->classname, "real2recip_copy_r"); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - this->fft_bundle.get_auxr_data()[ir] = in[ir]; + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + auxr[ir] = in_[ir]; + } } - this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); + ModuleBase::timer::end(this->classname, "real2recip_copy_r"); + this->fft_bundle.fftxyfor(auxr, auxr); - this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); + this->gatherp_scatters(auxr, auxg); - this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); + this->fft_bundle.fftzfor(auxg, auxg); + ModuleBase::timer::start(this->classname, "real2recip_copy_g"); if (add) { FPTYPE tmpfac = factor / FPTYPE(nxyz_); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[ig2isz_[ig]]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + out[ig] += tmpfac * auxg[ig2isz_[ig]]; + } } } else @@ -64,11 +95,19 @@ void PW_Basis::real2recip(const std::complex* in, #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[ig2isz_[ig]]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + out[ig] = tmpfac * auxg[ig2isz_[ig]]; + } } } + ModuleBase::timer::end(this->classname, "real2recip_copy_g"); ModuleBase::timer::end(this->classname, "real2recip"); } @@ -90,46 +129,73 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo const int nx_ = this->nx; const int ny_ = this->ny; const int nplane_ = this->nplane; + const FPTYPE* in_ = in; + std::complex* auxr = this->fft_bundle.get_auxr_data(); + std::complex* auxg = this->fft_bundle.get_auxg_data(); + FPTYPE* rspace = this->fft_bundle.get_rspace_data(); + ModuleBase::timer::start(this->classname, "real2recip_copy_r"); if (this->gamma_only) { const int npy = ny_ * nplane_; + const int nreal = nx_ * npy; #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static) +#pragma omp parallel for schedule(static) #endif - for (int ix = 0; ix < nx_; ++ix) + for (int ib = 0; ib < nreal; ib += pw_transform_cache_block) { - for (int ipy = 0; ipy < npy; ++ipy) + const int iend = block_end(ib, nreal); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) { - this->fft_bundle.get_rspace_data()[ix * npy + ipy] = in[ix * npy + ipy]; + rspace[ir] = in_[ir]; } } - this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data(), fft_bundle.get_auxr_data()); + ModuleBase::timer::end(this->classname, "real2recip_copy_r"); + this->fft_bundle.fftxyr2c(rspace, auxr); } else { #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - this->fft_bundle.get_auxr_data()[ir] = std::complex(in[ir], 0); + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + auxr[ir] = std::complex(in_[ir], 0); + } } - this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); + ModuleBase::timer::end(this->classname, "real2recip_copy_r"); + this->fft_bundle.fftxyfor(auxr, auxr); } - this->gatherp_scatters(this->fft_bundle.get_auxr_data(), this->fft_bundle.get_auxg_data()); + this->gatherp_scatters(auxr, auxg); - this->fft_bundle.fftzfor(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); + this->fft_bundle.fftzfor(auxg, auxg); + ModuleBase::timer::start(this->classname, "real2recip_copy_g"); if (add) { FPTYPE tmpfac = factor / FPTYPE(nxyz_); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - out[ig] += tmpfac * this->fft_bundle.get_auxg_data()[ig2isz_[ig]]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + out[ig] += tmpfac * auxg[ig2isz_[ig]]; + } } } else @@ -138,11 +204,19 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex* out, const boo #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - out[ig] = tmpfac * this->fft_bundle.get_auxg_data()[ig2isz_[ig]]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + out[ig] = tmpfac * auxg[ig2isz_[ig]]; + } } } + ModuleBase::timer::end(this->classname, "real2recip_copy_g"); ModuleBase::timer::end(this->classname, "real2recip"); } @@ -166,35 +240,62 @@ void PW_Basis::recip2real(const std::complex* in, const int npw_ = this->npw; const int nrxx_ = this->nrxx; const int* ig2isz_ = this->ig2isz; + const int nstnz_ = nst_ * nz_; + std::complex* auxg = this->fft_bundle.get_auxg_data(); + std::complex* auxr = this->fft_bundle.get_auxr_data(); + ModuleBase::timer::start(this->classname, "recip2real_copy_g"); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int i = 0; i < nst_ * nz_; ++i) + for (int ib = 0; ib < nstnz_; ib += pw_transform_cache_block) { - fft_bundle.get_auxg_data()[i] = std::complex(0, 0); + const int iend = block_end(ib, nstnz_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int i = ib; i < iend; ++i) + { + auxg[i] = std::complex(0, 0); + } } #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - this->fft_bundle.get_auxg_data()[ig2isz_[ig]] = in[ig]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + auxg[ig2isz_[ig]] = in[ig]; + } } - this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); + ModuleBase::timer::end(this->classname, "recip2real_copy_g"); + this->fft_bundle.fftzbac(auxg, auxg); - this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); + this->gathers_scatterp(auxg, auxr); - this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(), this->fft_bundle.get_auxr_data()); + this->fft_bundle.fftxybac(auxr, auxr); + ModuleBase::timer::start(this->classname, "recip2real_copy_r"); if (add) { #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - out[ir] += factor * this->fft_bundle.get_auxr_data()[ir]; + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + out[ir] += factor * auxr[ir]; + } } } else @@ -202,11 +303,19 @@ void PW_Basis::recip2real(const std::complex* in, #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - out[ir] = this->fft_bundle.get_auxr_data()[ir]; + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + out[ir] = auxr[ir]; + } } } + ModuleBase::timer::end(this->classname, "recip2real_copy_r"); ModuleBase::timer::end(this->classname, "recip2real"); } @@ -229,69 +338,108 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo const int ny_ = this->ny; const int nplane_ = this->nplane; const int* ig2isz_ = this->ig2isz; + const int nstnz_ = nst_ * nz_; + std::complex* auxg = this->fft_bundle.get_auxg_data(); + std::complex* auxr = this->fft_bundle.get_auxr_data(); + FPTYPE* rspace = this->fft_bundle.get_rspace_data(); + ModuleBase::timer::start(this->classname, "recip2real_copy_g"); #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int i = 0; i < nst_ * nz_; ++i) + for (int ib = 0; ib < nstnz_; ib += pw_transform_cache_block) { - fft_bundle.get_auxg_data()[i] = std::complex(0, 0); + const int iend = block_end(ib, nstnz_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int i = ib; i < iend; ++i) + { + auxg[i] = std::complex(0, 0); + } } #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ig = 0; ig < npw_; ++ig) + for (int ib = 0; ib < npw_; ib += pw_transform_cache_block) { - this->fft_bundle.get_auxg_data()[ig2isz_[ig]] = in[ig]; + const int iend = block_end(ib, npw_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ig = ib; ig < iend; ++ig) + { + auxg[ig2isz_[ig]] = in[ig]; + } } - this->fft_bundle.fftzbac(fft_bundle.get_auxg_data(), fft_bundle.get_auxg_data()); + ModuleBase::timer::end(this->classname, "recip2real_copy_g"); + this->fft_bundle.fftzbac(auxg, auxg); - this->gathers_scatterp(this->fft_bundle.get_auxg_data(), this->fft_bundle.get_auxr_data()); + this->gathers_scatterp(auxg, auxr); if (this->gamma_only) { - this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data(), fft_bundle.get_rspace_data()); + this->fft_bundle.fftxyc2r(auxr, rspace); const int npy = ny_ * nplane_; + const int nreal = nx_ * npy; + ModuleBase::timer::start(this->classname, "recip2real_copy_r"); if (add) { #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static) +#pragma omp parallel for schedule(static) #endif - for (int ix = 0; ix < nx_; ++ix) + for (int ib = 0; ib < nreal; ib += pw_transform_cache_block) { - for (int ipy = 0; ipy < npy; ++ipy) + const int iend = block_end(ib, nreal); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) { - out[ix * npy + ipy] += factor * this->fft_bundle.get_rspace_data()[ix * npy + ipy]; + out[ir] += factor * rspace[ir]; } } } else { #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static) +#pragma omp parallel for schedule(static) #endif - for (int ix = 0; ix < nx_; ++ix) + for (int ib = 0; ib < nreal; ib += pw_transform_cache_block) { - for (int ipy = 0; ipy < npy; ++ipy) + const int iend = block_end(ib, nreal); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) { - out[ix * npy + ipy] = this->fft_bundle.get_rspace_data()[ix * npy + ipy]; + out[ir] = rspace[ir]; } } } + ModuleBase::timer::end(this->classname, "recip2real_copy_r"); } else { - this->fft_bundle.fftxybac(fft_bundle.get_auxr_data(), fft_bundle.get_auxr_data()); + this->fft_bundle.fftxybac(auxr, auxr); + ModuleBase::timer::start(this->classname, "recip2real_copy_r"); if (add) { #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - out[ir] += factor * this->fft_bundle.get_auxr_data()[ir].real(); + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + out[ir] += factor * auxr[ir].real(); + } } } else @@ -299,11 +447,19 @@ void PW_Basis::recip2real(const std::complex* in, FPTYPE* out, const boo #ifdef _OPENMP #pragma omp parallel for schedule(static) #endif - for (int ir = 0; ir < nrxx_; ++ir) + for (int ib = 0; ib < nrxx_; ib += pw_transform_cache_block) { - out[ir] = this->fft_bundle.get_auxr_data()[ir].real(); + const int iend = block_end(ib, nrxx_); +#ifdef _OPENMP +#pragma omp simd +#endif + for (int ir = ib; ir < iend; ++ir) + { + out[ir] = auxr[ir].real(); + } } } + ModuleBase::timer::end(this->classname, "recip2real_copy_r"); } ModuleBase::timer::end(this->classname, "recip2real"); } @@ -340,4 +496,4 @@ template void PW_Basis::recip2real(const std::complex* in, std::complex* out, const bool add, const double factor) const; -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW