diff --git a/source/source_basis/module_pw/CMakeLists.txt b/source/source_basis/module_pw/CMakeLists.txt index 912772e0573..bdab5e9f75b 100644 --- a/source/source_basis/module_pw/CMakeLists.txt +++ b/source/source_basis/module_pw/CMakeLists.txt @@ -43,6 +43,25 @@ add_library( ${objects} ) +add_executable( + MODULE_PW_cache_bench + test_serial/pw_cache_bench.cpp +) + +target_link_libraries( + MODULE_PW_cache_bench + parameter + ${math_libs} + planewave + device + base + Threads::Threads +) + +if(USE_OPENMP) + target_link_libraries(MODULE_PW_cache_bench OpenMP::OpenMP_CXX) +endif() + if (USE_DSP) target_link_libraries(planewave PRIVATE ${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a) diff --git a/source/source_basis/module_pw/pw_basis.cpp b/source/source_basis/module_pw/pw_basis.cpp index 549fec8e5a4..77a3626994c 100644 --- a/source/source_basis/module_pw/pw_basis.cpp +++ b/source/source_basis/module_pw/pw_basis.cpp @@ -5,6 +5,7 @@ #include "source_base/timer.h" #include "source_base/global_function.h" +#include namespace ModulePW { @@ -28,17 +29,13 @@ PW_Basis:: ~PW_Basis() delete[] fftixy2ip; delete[] nst_per; delete[] npw_per; - delete[] gdirect; - delete[] gcar; - delete[] gg; delete[] startz; delete[] numz; delete[] numg; delete[] numr; delete[] startg; delete[] startr; - delete[] ig2igg; - delete[] gg_uniq; + this->clear_owned_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -48,6 +45,91 @@ PW_Basis:: ~PW_Basis() #endif } +void PW_Basis::clear_owned_cache() +{ + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); +} + +PW_Basis::CacheStats PW_Basis::get_cache_stats() const +{ + std::lock_guard guard(this->cache_mutex); + return this->get_cache_stats_unlocked(); +} + +PW_Basis::CacheStats PW_Basis::get_cache_stats_unlocked() const +{ + CacheStats stats; + stats.local_pw_hits = this->local_pw_cache_hits.load(); + stats.local_pw_misses = this->local_pw_cache_misses.load(); + stats.uniqgg_hits = this->uniqgg_cache_hits.load(); + stats.uniqgg_misses = this->uniqgg_cache_misses.load(); + const bool has_local_pw_cache = this->local_pw_cache_valid.load() + && this->npw > 0 + && this->gg != nullptr + && this->gdirect != nullptr + && this->gcar != nullptr; + const bool has_uniqgg_cache = this->uniqgg_cache_valid.load() + && this->ngg > 0 + && this->ig2igg != nullptr + && this->gg_uniq != nullptr; + if (has_local_pw_cache) + { + stats.cache_bytes += sizeof(double) * this->npw; + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npw * 2; + } + if (has_uniqgg_cache) + { + stats.cache_bytes += sizeof(int) * this->npw; + stats.cache_bytes += sizeof(double) * this->ngg; + } + return stats; +} + +void PW_Basis::reset_cache_stats() +{ + this->local_pw_cache_hits.store(0); + this->local_pw_cache_misses.store(0); + this->uniqgg_cache_hits.store(0); + this->uniqgg_cache_misses.store(0); +} + +PW_Basis::CacheSignature PW_Basis::make_cache_signature() const +{ + CacheSignature signature; + signature.lat0 = this->lat0; + signature.tpiba = this->tpiba; + signature.tpiba2 = this->tpiba2; + signature.nx = this->nx; + signature.ny = this->ny; + signature.nz = this->nz; + signature.fftnx = this->fftnx; + signature.fftny = this->fftny; + signature.fftnz = this->fftnz; + signature.npw = this->npw; + signature.G = this->G; + signature.GT = this->GT; + signature.GGT = this->GGT; + return signature; +} + +bool PW_Basis::cache_signature_matches(const CacheSignature& signature) const +{ + return signature.lat0 == this->lat0 + && signature.tpiba == this->tpiba + && signature.tpiba2 == this->tpiba2 + && signature.nx == this->nx + && signature.ny == this->ny + && signature.nz == this->nz + && signature.fftnx == this->fftnx + && signature.fftny == this->fftny + && signature.fftnz == this->fftnz + && signature.npw == this->npw + && std::memcmp(&signature.G, &this->G, sizeof(ModuleBase::Matrix3)) == 0 + && std::memcmp(&signature.GT, &this->GT, sizeof(ModuleBase::Matrix3)) == 0 + && std::memcmp(&signature.GGT, &this->GGT, sizeof(ModuleBase::Matrix3)) == 0; +} + /// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall @@ -138,10 +220,33 @@ void PW_Basis::collect_local_pw() { return; } + ModuleBase::timer::start(this->classname, "collect_local_pw"); + std::lock_guard guard(this->cache_mutex); + if (this->local_pw_cache_valid.load() + && this->cache_signature_matches(this->local_pw_cache_signature)) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_hit"); + this->local_pw_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_build"); + this->local_pw_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->gg; this->gg = new double[this->npw]; - delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3[this->npw]; - delete[] this->gcar; this->gcar = new ModuleBase::Vector3[this->npw]; + this->gg_cache_storage.reset(new double[this->npw]); + this->gdirect_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gcar_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gg = this->gg_cache_storage.get(); + this->gdirect = this->gdirect_cache_storage.get(); + this->gcar = this->gcar_cache_storage.get(); + // Unique-G data depends on gg, so rebuilding local G data invalidates it. + this->uniqgg_cache_valid.store(false); + this->ig2igg_cache_storage.reset(); + this->gg_uniq_cache_storage.reset(); + this->ig2igg = nullptr; + this->gg_uniq = nullptr; + this->ngg = 0; ModuleBase::Vector3 f; int gamma_num = 0; @@ -182,6 +287,10 @@ void PW_Basis::collect_local_pw() } } } + this->local_pw_cache_valid.store(true); + this->local_pw_cache_signature = this->make_cache_signature(); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_build"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); return; } @@ -196,45 +305,74 @@ void PW_Basis::collect_uniqgg() { return; } + ModuleBase::timer::start(this->classname, "collect_uniqgg"); + std::lock_guard guard(this->cache_mutex); + if (this->uniqgg_cache_valid.load() + && this->cache_signature_matches(this->uniqgg_cache_signature)) + { + ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_hit"); + this->uniqgg_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); + return; + } + ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_build"); + this->uniqgg_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->ig2igg; this->ig2igg = new int [this->npw]; + this->ig2igg_cache_storage.reset(new int[this->npw]); + this->ig2igg = this->ig2igg_cache_storage.get(); - int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves - double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond - double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates - ModuleBase::Vector3 f; - for(int ig = 0 ; ig < this-> npw ; ++ig) + std::vector sortindex(this->npw); // Reconstruct the plane-wave index mapping after sorting by energy. + std::vector tmpgg(this->npw); + std::vector tmpgg2(this->npw); + // Reuse gg when collect_local_pw has already built the same G^2 values. + if (this->local_pw_cache_valid.load() && this->gg != nullptr) { - int isz = this->ig2isz[ig]; - int iz = isz % this->nz; - int is = isz / this->nz; - int ixy = this->is2fftixy[is]; - int ix = ixy / this->fftny; - int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) - { - ix -= this->nx; - } - if (iy >= int(this->ny/2) + 1) - { - iy -= this->ny; - } - if (iz >= int(this->nz/2) + 1) + for(int ig = 0 ; ig < this-> npw ; ++ig) { - iz -= this->nz; + tmpgg[ig] = this->gg[ig]; + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } - f.x = ix; - f.y = iy; - f.z = iz; - tmpgg[ig] = f * (this->GGT * f); - if(tmpgg[ig] < 1e-8) + } + else + { + ModuleBase::Vector3 f; + for(int ig = 0 ; ig < this-> npw ; ++ig) { - this->ig_gge0 = ig; + int isz = this->ig2isz[ig]; + int iz = isz % this->nz; + int is = isz / this->nz; + int ixy = this->is2fftixy[is]; + int ix = ixy / this->fftny; + int iy = ixy % this->fftny; + if (ix >= int(this->nx/2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny/2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz/2) + 1) + { + iz -= this->nz; + } + f.x = ix; + f.y = iy; + f.z = iz; + tmpgg[ig] = f * (this->GGT * f); + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } } - ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw); - ModuleBase::heapsort(this->npw, tmpgg, sortindex); + ModuleBase::GlobalFunc::ZEROS(sortindex.data(), this->npw); + ModuleBase::heapsort(this->npw, tmpgg.data(), sortindex.data()); int igg = 0; @@ -261,14 +399,16 @@ void PW_Basis::collect_uniqgg() } tmpgg2[igg] = avg_gg / double(avg_n); this->ngg = igg + 1; - delete[] this->gg_uniq; this->gg_uniq = new double [this->ngg]; + this->gg_uniq_cache_storage.reset(new double[this->ngg]); + this->gg_uniq = this->gg_uniq_cache_storage.get(); for(int igg = 0 ; igg < this->ngg ; ++igg) { gg_uniq[igg] = tmpgg2[igg]; } - delete[] sortindex; - delete[] tmpgg; - delete[] tmpgg2; + this->uniqgg_cache_valid.store(true); + this->uniqgg_cache_signature = this->make_cache_signature(); + ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_build"); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); } void PW_Basis::getfftixy2is(int * fftixy2is) const @@ -295,10 +435,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) const void PW_Basis::set_device(std::string device_) { this->device = std::move(device_); + this->invalidate_cache(); } void PW_Basis::set_precision(std::string precision_) { this->precision = std::move(precision_); + this->invalidate_cache(); } } diff --git a/source/source_basis/module_pw/pw_basis.h b/source/source_basis/module_pw/pw_basis.h index b834cb0e0f4..813e93342e9 100644 --- a/source/source_basis/module_pw/pw_basis.h +++ b/source/source_basis/module_pw/pw_basis.h @@ -9,9 +9,14 @@ #include #include "source_base/module_fft/fft_bundle.h" #include +#include #ifdef __MPI #include "mpi.h" #endif +#include +#include +#include +#include namespace ModulePW { @@ -56,8 +61,21 @@ class PW_Basis { public: + struct CacheStats + { + std::uint64_t local_pw_hits = 0; + std::uint64_t local_pw_misses = 0; + std::uint64_t uniqgg_hits = 0; + std::uint64_t uniqgg_misses = 0; + std::size_t cache_bytes = 0; + }; + std::string classname; PW_Basis(); + // PW_Basis owns FFT/distribution maps through raw pointers, so copying would + // create ambiguous ownership and stale cache pointers. + PW_Basis(const PW_Basis& other) = delete; + PW_Basis& operator=(const PW_Basis& other) = delete; PW_Basis(std::string device_, std::string precision_); virtual ~PW_Basis(); //Init mpi parameters @@ -137,9 +155,74 @@ class PW_Basis //distribute plane waves and grids and set up fft void setuptransform(); + CacheStats get_cache_stats() const; + void reset_cache_stats(); + protected: int *startnsz_per=nullptr;//useless intermediate variable// startnsz_per[ip]: starting is * nz stick in the ip^th proc. + virtual void invalidate_cache() + { + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); + } + + void clear_owned_cache(); + + // Public gg/gcar/gdirect pointers are non-owning views of these cache buffers. + std::atomic local_pw_cache_valid{false}; + std::atomic uniqgg_cache_valid{false}; + mutable std::mutex cache_mutex; + std::unique_ptr gg_cache_storage; + std::unique_ptr[]> gdirect_cache_storage; + std::unique_ptr[]> gcar_cache_storage; + std::unique_ptr ig2igg_cache_storage; + std::unique_ptr gg_uniq_cache_storage; + std::atomic local_pw_cache_hits{0}; + std::atomic local_pw_cache_misses{0}; + std::atomic uniqgg_cache_hits{0}; + std::atomic uniqgg_cache_misses{0}; + + struct CacheSignature + { + double lat0 = 0.0; + double tpiba = 0.0; + double tpiba2 = 0.0; + int nx = 0; + int ny = 0; + int nz = 0; + int fftnx = 0; + int fftny = 0; + int fftnz = 0; + int npw = 0; + ModuleBase::Matrix3 G; + ModuleBase::Matrix3 GT; + ModuleBase::Matrix3 GGT; + }; + CacheSignature make_cache_signature() const; + bool cache_signature_matches(const CacheSignature& signature) const; + CacheSignature local_pw_cache_signature; + CacheSignature uniqgg_cache_signature; + + virtual void invalidate_cache_unlocked() + { + this->local_pw_cache_valid.store(false); + this->uniqgg_cache_valid.store(false); + this->gg_cache_storage.reset(); + this->gdirect_cache_storage.reset(); + this->gcar_cache_storage.reset(); + this->ig2igg_cache_storage.reset(); + this->gg_uniq_cache_storage.reset(); + this->gg = nullptr; + this->gdirect = nullptr; + this->gcar = nullptr; + this->ig2igg = nullptr; + this->gg_uniq = nullptr; + this->ngg = 0; + this->ig_gge0 = -1; + } + CacheStats get_cache_stats_unlocked() const; + //distribute plane waves to different processors void distribute_g(); diff --git a/source/source_basis/module_pw/pw_basis_k.cpp b/source/source_basis/module_pw/pw_basis_k.cpp index 2c2d02bf927..6418258d3bc 100644 --- a/source/source_basis/module_pw/pw_basis_k.cpp +++ b/source/source_basis/module_pw/pw_basis_k.cpp @@ -6,6 +6,7 @@ #include "source_base/timer.h" #include +#include namespace ModulePW { @@ -21,7 +22,6 @@ PW_Basis_K::~PW_Basis_K() delete[] npwk; delete[] igl2isz_k; delete[] igl2ig_k; - delete[] gk2; #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -44,6 +44,49 @@ PW_Basis_K::~PW_Basis_K() #if defined(__CUDA) || defined(__ROCM) } #endif + this->clear_k_cache_storage(); +} + +void PW_Basis_K::clear_k_cache_storage() +{ + std::lock_guard guard(this->cache_mutex); + this->invalidate_cache_unlocked(); +} + +PW_Basis_K::KCacheStats PW_Basis_K::get_k_cache_stats() const +{ + std::lock_guard guard(this->cache_mutex); + KCacheStats stats; + const auto base_stats = PW_Basis::get_cache_stats_unlocked(); + static_cast(stats) = base_stats; + stats.gcar_hits = this->gcar_cache_hits.load(); + stats.gcar_misses = this->gcar_cache_misses.load(); + stats.gk2_hits = this->gk2_cache_hits.load(); + stats.gk2_misses = this->gk2_cache_misses.load(); + if (this->gcar_cache_valid.load() + && this->gcar != nullptr + && this->npwk_max > 0 + && this->nks > 0) + { + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks; + } + if (this->gk_cache_valid.load() + && this->gk2 != nullptr + && this->npwk_max > 0 + && this->nks > 0) + { + stats.cache_bytes += sizeof(double) * this->npwk_max * this->nks; + } + return stats; +} + +void PW_Basis_K::reset_k_cache_stats() +{ + PW_Basis::reset_cache_stats(); + this->gcar_cache_hits.store(0); + this->gcar_cache_misses.store(0); + this->gk2_cache_hits.store(0); + this->gk2_cache_misses.store(0); } void PW_Basis_K::initparameters(const bool gamma_only_in, @@ -101,6 +144,7 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, this->fftnxy = this->fftnx * this->fftny; this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -129,6 +173,7 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, void PW_Basis_K::setupIndGk() { + this->invalidate_cache(); // count npwk this->npwk_max = 0; delete[] this->npwk; @@ -198,6 +243,34 @@ void PW_Basis_K::setupIndGk() return; } +ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +{ + int isz = this->ig2isz[ig]; + int iz = isz % this->nz; + int is = isz / this->nz; + int ix = this->is2fftixy[is] / this->fftny; + int iy = this->is2fftixy[is] % this->fftny; + if (ix >= int(this->nx / 2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny / 2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz / 2) + 1) + { + iz -= this->nz; + } + ModuleBase::Vector3 f; + f.x = ix; + f.y = iy; + f.z = iz; + f = f * this->G; + ModuleBase::Vector3 g_temp_ = this->kvec_c[ik] + f; + return g_temp_; +} + /// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall @@ -249,19 +322,60 @@ void PW_Basis_K::setuptransform() void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_height_in, const double& erf_sigma_in) { - this->erf_ecut = erf_ecut_in; - this->erf_height = erf_height_in; - this->erf_sigma = erf_sigma_in; if (this->npwk_max <= 0) { return; } - delete[] gk2; - delete[] gcar; - this->gk2 = new double[this->npwk_max * this->nks]; - this->gcar = new ModuleBase::Vector3[this->npwk_max * this->nks]; - ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); - ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + ModuleBase::timer::start(this->classname, "collect_local_pw"); + std::lock_guard guard(this->cache_mutex); + const bool locked_gcar_hit = this->gcar_cache_valid.load() && this->gcar != nullptr; + const bool locked_gk2_hit = this->gk_cache_valid.load() + && this->gk2 != nullptr + && this->erf_ecut == erf_ecut_in + && this->erf_height == erf_height_in + && this->erf_sigma == erf_sigma_in; + if (locked_gcar_hit && locked_gk2_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_cache_hit"); + this->gcar_cache_hits.fetch_add(1); + this->gk2_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw_cache_hit"); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + if (!locked_gcar_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_build_gcar"); + } + if (!locked_gk2_hit) + { + ModuleBase::timer::start(this->classname, "collect_local_pw_build_gk2"); + } + if (locked_gcar_hit) + { + this->gcar_cache_hits.fetch_add(1); + } + else + { + this->gcar_cache_misses.fetch_add(1); + this->k_gcar_cache_storage.reset(new ModuleBase::Vector3[this->npwk_max * this->nks]); + this->gcar = this->k_gcar_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + } + if (locked_gk2_hit) + { + this->gk2_cache_hits.fetch_add(1); + } + else + { + this->gk2_cache_misses.fetch_add(1); + this->k_gk2_cache_storage.reset(new double[this->npwk_max * this->nks]); + this->gk2 = this->k_gk2_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); + } + this->erf_ecut = erf_ecut_in; + this->erf_height = erf_height_in; + this->erf_sigma = erf_sigma_in; ModuleBase::Vector3 f; for (int ik = 0; ik < this->nks; ++ik) @@ -291,36 +405,55 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h f.y = iy; f.z = iz; - this->gcar[ik * npwk_max + igl] = f * this->G; - double temp_gk2 = (f + kv) * (this->GGT * (f + kv)); - if (erf_height > 0) + if (!locked_gcar_hit) { - this->gk2[ik * npwk_max + igl] - = temp_gk2 + erf_height / tpiba2 * (1.0 + std::erf((temp_gk2 * tpiba2 - erf_ecut) / erf_sigma)); + this->gcar[ik * npwk_max + igl] = f * this->G; } - else + if (!locked_gk2_hit) { - this->gk2[ik * npwk_max + igl] = temp_gk2; + const double temp_gk2 = (f + kv) * (this->GGT * (f + kv)); + if (erf_height > 0) + { + this->gk2[ik * npwk_max + igl] + = temp_gk2 + erf_height / tpiba2 * (1.0 + std::erf((temp_gk2 * tpiba2 - erf_ecut) / erf_sigma)); + } + else + { + this->gk2[ik * npwk_max + igl] = temp_gk2; + } } } } + if (!locked_gcar_hit) + { + this->sync_gcar_device_cache(); + this->gcar_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_local_pw_build_gcar"); + } + if (!locked_gk2_hit) + { + this->sync_gk2_device_cache(); + this->gk_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_local_pw_build_gk2"); + } + ModuleBase::timer::end(this->classname, "collect_local_pw"); +} + +void PW_Basis_K::sync_gcar_device_cache() +{ #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { if (this->float_data_) { - resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3); - castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); } if (this->double_data_) { - resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); resmem_dd_op()(this->d_gcar, this->npwk_max * this->nks * 3); - syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); syncmem_d2d_h2d_op()(this->d_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -331,9 +464,7 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif if (this->float_data_) { - resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); - castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -341,7 +472,6 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h if (this->double_data_) { this->d_gcar = reinterpret_cast(&this->gcar[0][0]); - this->d_gk2 = this->gk2; } // There's no need to allocate double pointers while in a CPU environment. #if defined(__CUDA) || defined(__ROCM) @@ -349,32 +479,37 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif } -ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const +void PW_Basis_K::sync_gk2_device_cache() { - int isz = this->ig2isz[ig]; - int iz = isz % this->nz; - int is = isz / this->nz; - int ix = this->is2fftixy[is] / this->fftny; - int iy = this->is2fftixy[is] % this->fftny; - if (ix >= int(this->nx / 2) + 1) - { - ix -= this->nx; - } - if (iy >= int(this->ny / 2) + 1) +#if defined(__CUDA) || defined(__ROCM) + if (this->device == "gpu") { - iy -= this->ny; + if (this->float_data_) + { + resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); + castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); + syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); + } } - if (iz >= int(this->nz / 2) + 1) + else { - iz -= this->nz; +#endif + if (this->float_data_) + { + resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); + castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + this->d_gk2 = this->gk2; + } +#if defined(__CUDA) || defined(__ROCM) } - ModuleBase::Vector3 f; - f.x = ix; - f.y = iy; - f.z = iz; - f = f * this->G; - ModuleBase::Vector3 g_temp_ = this->kvec_c[ik] + f; - return g_temp_; +#endif } double& PW_Basis_K::getgk2(const int ik, const int igl) const @@ -510,23 +645,27 @@ double* PW_Basis_K::get_kvec_c_data() const template <> float* PW_Basis_K::get_gcar_data() const { - return this->s_gcar; + std::lock_guard guard(this->cache_mutex); + return this->gcar_cache_valid.load() ? this->s_gcar : nullptr; } template <> double* PW_Basis_K::get_gcar_data() const { - return this->d_gcar; + std::lock_guard guard(this->cache_mutex); + return this->gcar_cache_valid.load() ? this->d_gcar : nullptr; } template <> float* PW_Basis_K::get_gk2_data() const { - return this->s_gk2; + std::lock_guard guard(this->cache_mutex); + return this->gk_cache_valid.load() ? this->s_gk2 : nullptr; } template <> double* PW_Basis_K::get_gk2_data() const { - return this->d_gk2; + std::lock_guard guard(this->cache_mutex); + return this->gk_cache_valid.load() ? this->d_gk2 : nullptr; } -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW diff --git a/source/source_basis/module_pw/pw_basis_k.h b/source/source_basis/module_pw/pw_basis_k.h index f633a30769d..81dac5b2e53 100644 --- a/source/source_basis/module_pw/pw_basis_k.h +++ b/source/source_basis/module_pw/pw_basis_k.h @@ -56,6 +56,14 @@ class PW_Basis_K : public PW_Basis { public: + struct KCacheStats : public PW_Basis::CacheStats + { + std::uint64_t gcar_hits = 0; + std::uint64_t gcar_misses = 0; + std::uint64_t gk2_hits = 0; + std::uint64_t gk2_misses = 0; + }; + PW_Basis_K(); PW_Basis_K(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {classname="PW_Basis_K";} ~PW_Basis_K(); @@ -99,16 +107,42 @@ class PW_Basis_K : public PW_Basis const double& erf_height_in = 0.0, const double& erf_sigma_in = 0.1); + KCacheStats get_k_cache_stats() const; + void reset_k_cache_stats(); + private: + void clear_k_cache_storage(); + void invalidate_cache_unlocked() override + { + PW_Basis::invalidate_cache_unlocked(); + this->gcar_cache_valid.store(false); + this->gk_cache_valid.store(false); + this->k_gcar_cache_storage.reset(); + this->k_gk2_cache_storage.reset(); + this->gcar = nullptr; + this->gk2 = nullptr; + this->d_gcar = nullptr; + this->d_gk2 = nullptr; + } + void sync_gcar_device_cache(); + void sync_gk2_device_cache(); + + std::atomic gcar_cache_valid{false}; + std::atomic gk_cache_valid{false}; + std::unique_ptr[]> k_gcar_cache_storage; + std::unique_ptr k_gk2_cache_storage; + std::atomic gcar_cache_hits{0}; + std::atomic gcar_cache_misses{0}; + std::atomic gk2_cache_hits{0}; + std::atomic gk2_cache_misses{0}; float * s_gk2 = nullptr; double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] //create igl2isz_k map array for fft void setupIndGk(); // get ig2ixyz_k void get_ig2ixyz_k(); - //calculate G+K, it is a private function + //calculate G+K in cartesian coordinates ModuleBase::Vector3 cal_GplusK_cartesian(const int ik, const int ig) const; - public: template void real2recip(const FPTYPE* in, @@ -280,4 +314,3 @@ class PW_Basis_K : public PW_Basis #endif //PlaneWave_K class #include "./pw_basis_k_big.h" //temporary it will be removed - diff --git a/source/source_basis/module_pw/pw_distributeg.cpp b/source/source_basis/module_pw/pw_distributeg.cpp index a13fc57f9b9..93d9b07d294 100644 --- a/source/source_basis/module_pw/pw_distributeg.cpp +++ b/source/source_basis/module_pw/pw_distributeg.cpp @@ -161,6 +161,7 @@ void PW_Basis::get_ig2isz_is2fftixy( { delete[] this->ig2isz; this->ig2isz = nullptr; // map ig to the z coordinate of this planewave. delete[] this->is2fftixy; this->is2fftixy = nullptr; // map is (index of sticks) to ixy (iy + ix * fftny). + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { delmem_int_op()(this->d_is2fftixy); @@ -226,6 +227,7 @@ void PW_Basis::get_ig2isz_is2fftixy( syncmem_int_h2d_op()(ig2ixyz_gpu, ig2ixyz.data(), this->npw); } #endif + this->invalidate_cache(); return; } } // namespace ModulePW \ No newline at end of file diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 207320f4268..2ed5044620e 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -15,8 +15,8 @@ namespace ModulePW template void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { - - if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, + + if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, { const int nst_ = this->nst; const int nz_ = this->nz; @@ -112,7 +112,10 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const template void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { - if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, + + + + if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, { const int nrxx_ = this->nrxx; const int nst_ = this->nst; diff --git a/source/source_basis/module_pw/pw_init.cpp b/source/source_basis/module_pw/pw_init.cpp index 08c676d39f3..9a7d3b06043 100644 --- a/source/source_basis/module_pw/pw_init.cpp +++ b/source/source_basis/module_pw/pw_init.cpp @@ -13,6 +13,7 @@ void PW_Basis:: initmpi( this->poolnproc = poolnproc_in; this->poolrank = poolrank_in; this->pool_world = pool_world_in; + this->invalidate_cache(); } #endif /// @@ -142,6 +143,7 @@ void PW_Basis:: initgrids( this->nz = ibox[2]; this->nxy =this->nx * this->ny; this->nxyz = this->nxy * this->nz; + this->invalidate_cache(); delete[] ibox; return; @@ -203,6 +205,7 @@ void PW_Basis:: initgrids( MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world); #endif this->gridecut_lat -= 1e-6; + this->invalidate_cache(); delete[] ibox; return; @@ -240,6 +243,7 @@ void PW_Basis:: initparameters( this->ggecut = this->gridecut_lat; } this->distribution_type = distribution_type_in; + this->invalidate_cache(); } // Set parameters about full planewave, used only in OFDFT for now. sunliang added 2022-08-30 @@ -251,5 +255,6 @@ void PW_Basis::setfullpw( this->full_pw = inpt_full_pw; this->full_pw_dim = inpt_full_pw_dim; if (!this->full_pw) this->full_pw_dim = 0; + this->invalidate_cache(); +} } -} \ No newline at end of file diff --git a/source/source_basis/module_pw/test/test-big.cpp b/source/source_basis/module_pw/test/test-big.cpp index f1c2082d0b2..1f56597ece3 100644 --- a/source/source_basis/module_pw/test/test-big.cpp +++ b/source/source_basis/module_pw/test/test-big.cpp @@ -53,7 +53,7 @@ TEST_F(PWTEST,test_big) pwktest.initgrids(lat0,latvec, pwtest.nx, pwtest.ny, pwtest.nz); pwtest.initparameters(gamma_only,wfcecut,distribution_type,xprime); pwktest.initparameters(gamma_only,wfcecut,nks,kvec_d,distribution_type, xprime); - static_cast(pwtest).setuptransform(); + pwtest.ModulePW::PW_Basis::setuptransform(); pwktest.setuptransform(); EXPECT_EQ(pwtest.nx%2, 0); EXPECT_EQ(pwtest.ny%2, 0); @@ -85,7 +85,7 @@ TEST_F(PWTEST,test_big) class TestPW_Basis_Big : public ::testing::Test { public: - ModulePW::PW_Basis_Big pwtest = ModulePW::PW_Basis_Big(); + ModulePW::PW_Basis_Big pwtest; }; // Test the function with nproc = 0 (bx and by) @@ -157,4 +157,4 @@ TEST_F(TestPW_Basis_Big, BzNprocNoResultTest) { int nproc = 5; pwtest.autoset_big_cell_size(b_size, nc_size, nproc); EXPECT_EQ(b_size, 3); -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test/test1-1-1.cpp b/source/source_basis/module_pw/test/test1-1-1.cpp index 3eb9d8fd5e4..4be297f276e 100644 --- a/source/source_basis/module_pw/test/test1-1-1.cpp +++ b/source/source_basis/module_pw/test/test1-1-1.cpp @@ -29,6 +29,36 @@ TEST_F(PWTEST,test1_1_1) pwtest.initgrids(lat0, latvec, wfcecut); pwtest.initparameters(gamma_only, wfcecut, distribution_type,xprime); pwtest.setuptransform(); + pwtest.reset_cache_stats(); + pwtest.collect_local_pw(); + pwtest.collect_uniqgg(); + auto stats_after_build = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_build.local_pw_misses, 1); + EXPECT_EQ(stats_after_build.uniqgg_misses, 1); + double* gg_ptr = pwtest.gg; + int* ig2igg_ptr = pwtest.ig2igg; + double* gguniq_ptr = pwtest.gg_uniq; + const int ngg_before = pwtest.ngg; + const double gg_sample = pwtest.gg[0]; + pwtest.collect_local_pw(); + pwtest.collect_uniqgg(); + EXPECT_EQ(pwtest.gg, gg_ptr); + EXPECT_EQ(pwtest.ig2igg, ig2igg_ptr); + EXPECT_EQ(pwtest.gg_uniq, gguniq_ptr); + EXPECT_EQ(pwtest.ngg, ngg_before); + EXPECT_DOUBLE_EQ(pwtest.gg[0], gg_sample); + auto stats_after_hit = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_hit.local_pw_hits, 1); + EXPECT_EQ(stats_after_hit.uniqgg_hits, 1); + EXPECT_GT(stats_after_hit.cache_bytes, 0); + pwtest.initparameters(gamma_only, wfcecut, distribution_type, xprime); + EXPECT_EQ(pwtest.gg, nullptr); + EXPECT_EQ(pwtest.gdirect, nullptr); + EXPECT_EQ(pwtest.gcar, nullptr); + EXPECT_EQ(pwtest.ig2igg, nullptr); + EXPECT_EQ(pwtest.gg_uniq, nullptr); + EXPECT_EQ(pwtest.get_cache_stats().cache_bytes, 0); + pwtest.setuptransform(); pwtest.collect_local_pw(); pwtest.collect_uniqgg(); ModuleBase::Matrix3 GT,G,GGT; @@ -229,4 +259,4 @@ TEST_F(PWTEST,test1_1_1) delete[] irindex; -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test_serial/CMakeLists.txt b/source/source_basis/module_pw/test_serial/CMakeLists.txt index 52e594afb99..07179b2bd1c 100644 --- a/source/source_basis/module_pw/test_serial/CMakeLists.txt +++ b/source/source_basis/module_pw/test_serial/CMakeLists.txt @@ -34,3 +34,18 @@ AddTest( LIBS parameter ${math_libs} planewave_serial device base SOURCES pw_basis_k_test.cpp ) + +add_executable( + MODULE_PW_cache_bench_serial + pw_cache_bench.cpp +) + +target_link_libraries( + MODULE_PW_cache_bench_serial + parameter + ${math_libs} + planewave_serial + device + base + Threads::Threads +) diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 84932bae2ff..4b1ec1f0aed 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include "source_base/timer.h" /************************************************ * serial unit test of functions in pw_basis.cpp @@ -183,9 +184,59 @@ TEST_F(PWBasisKTEST, CollectLocalPW) const bool xprime_in = true; basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in,kvec_d_in, distribution_type_in, xprime_in); EXPECT_NO_THROW(basis_k.setuptransform()); + basis_k.reset_k_cache_stats(); EXPECT_NO_THROW(basis_k.collect_local_pw()); + ASSERT_GT(basis_k.npwk[0], 0); + auto* gk2_ptr = basis_k.get_gk2_data(); + auto* gcar_ptr = basis_k.get_gcar_data(); + const double gk2_sample = basis_k.getgk2(0,0); + const auto stats_after_build = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_build.gcar_misses, 1); + EXPECT_EQ(stats_after_build.gk2_misses, 1); + EXPECT_NO_THROW(basis_k.collect_local_pw()); + EXPECT_EQ(basis_k.get_gk2_data(), gk2_ptr); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + EXPECT_DOUBLE_EQ(basis_k.getgk2(0,0), gk2_sample); + EXPECT_NO_THROW(basis_k.collect_local_pw(1.0, 0.5, 0.2)); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + const auto stats_after_hits = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_hits.gcar_hits, 2); + EXPECT_EQ(stats_after_hits.gcar_misses, 1); + EXPECT_EQ(stats_after_hits.gk2_hits, 1); + EXPECT_EQ(stats_after_hits.gk2_misses, 2); + EXPECT_GT(stats_after_hits.cache_bytes, 0); + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + EXPECT_EQ(basis_k.gcar, nullptr); + EXPECT_EQ(basis_k.gk2, nullptr); + EXPECT_EQ(basis_k.k_gcar_cache_storage, nullptr); + EXPECT_EQ(basis_k.k_gk2_cache_storage, nullptr); + EXPECT_EQ(basis_k.get_gcar_data(), nullptr); + EXPECT_EQ(basis_k.get_gk2_data(), nullptr); + EXPECT_EQ(basis_k.get_k_cache_stats().cache_bytes, 0); EXPECT_EQ(basis_k.npw,3695); EXPECT_EQ(basis_k.npwk_max,2721); } - +TEST_F(PWBasisKTEST, CollectLocalPWRecordsTimers) +{ + ModuleBase::timer::timer_pool.clear(); + ModulePW::PW_Basis_K basis_k(device_flag, precision_double); + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut = 10.0; + basis_k.initgrids(lat0, latvec, gridecut); + const bool gamma_only_in = true; + const double gk_ecut_in = 11.0; + const int nks_in = 3; + const ModuleBase::Vector3 kvec_d_in[3] = { {0.0, 0.0, 0.0}, {0.1, 0.2, 0.3}, {0.4, 0.5, 0.6} }; + const int distribution_type_in = 1; + const bool xprime_in = true; + basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in); + basis_k.setuptransform(); + basis_k.collect_local_pw(); + const auto& timer_pool = ModuleBase::timer::timer_pool[basis_k.classname]; + EXPECT_TRUE(timer_pool.count("collect_local_pw")); + EXPECT_GE(timer_pool.at("collect_local_pw").calls, 1u); +} diff --git a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp index ea678b9d97c..4692ed43e30 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_test.cpp @@ -2,6 +2,7 @@ #include "source_base/global_function.h" #include "source_base/constants.h" #include "source_base/matrix3.h" +#include "source_base/timer.h" /************************************************ * serial unit test of functions in pw_basis.cpp @@ -362,3 +363,91 @@ TEST_F(PWBasisTEST,CollectUniqgg) pwb.collect_uniqgg(); EXPECT_EQ(pwb.ngg,78); } + +TEST_F(PWBasisTEST, CacheStorageClearedOnParameterChange) +{ + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut=10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0,latvec,gridecut); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_NO_THROW(pwb.setuptransform()); + pwb.collect_local_pw(); + pwb.collect_uniqgg(); + EXPECT_GT(pwb.get_cache_stats().cache_bytes, 0); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_EQ(pwb.gg_cache_storage, nullptr); + EXPECT_EQ(pwb.gdirect_cache_storage, nullptr); + EXPECT_EQ(pwb.gcar_cache_storage, nullptr); + EXPECT_EQ(pwb.ig2igg_cache_storage, nullptr); + EXPECT_EQ(pwb.gg_uniq_cache_storage, nullptr); + EXPECT_EQ(pwb.get_cache_stats().cache_bytes, 0); +} + +TEST_F(PWBasisTEST, CacheSignatureRejectsChangedLattice) +{ + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut=10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0,latvec,gridecut); + pwb.initparameters(gamma_only_in,pwecut_in,distribution_type_in,xprime_in); + EXPECT_NO_THROW(pwb.setuptransform()); + pwb.collect_local_pw(); + int changed_ig = -1; + for (int ig = 0; ig < pwb.npw; ++ig) + { + if (std::abs(pwb.gdirect[ig].x) > 1e-12) + { + changed_ig = ig; + break; + } + } + ASSERT_GE(changed_ig, 0); + const double old_gg = pwb.gg[changed_ig]; + pwb.collect_local_pw(); + EXPECT_EQ(pwb.get_cache_stats().local_pw_hits, 1); + EXPECT_EQ(pwb.get_cache_stats().local_pw_misses, 1); + + pwb.G.e11 *= 1.1; + pwb.GGT = pwb.G * pwb.GT; + pwb.collect_local_pw(); + EXPECT_EQ(pwb.get_cache_stats().local_pw_hits, 1); + EXPECT_EQ(pwb.get_cache_stats().local_pw_misses, 2); + EXPECT_NE(pwb.gg[changed_ig], old_gg); +} + +TEST_F(PWBasisTEST, CacheCollectionRecordsTimers) +{ + ModuleBase::timer::timer_pool.clear(); + double lat0 = 1.8897261254578281; + ModuleBase::Matrix3 latvec(10.0,0.0,0.0, + 0.0,10.0,0.0, + 0.0,0.0,10.0); + double gridecut = 10.0; + bool gamma_only_in = true; + double pwecut_in = 11.0; + int distribution_type_in = 2; + bool xprime_in = true; + pwb.initgrids(lat0, latvec, gridecut); + pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in); + pwb.setuptransform(); + pwb.collect_local_pw(); + pwb.collect_uniqgg(); + const auto& timer_pool = ModuleBase::timer::timer_pool[pwb.classname]; + EXPECT_TRUE(timer_pool.count("collect_local_pw")); + EXPECT_TRUE(timer_pool.count("collect_uniqgg")); + EXPECT_GE(timer_pool.at("collect_local_pw").calls, 1u); + EXPECT_GE(timer_pool.at("collect_uniqgg").calls, 1u); +} diff --git a/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp b/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp new file mode 100644 index 00000000000..b1e2ae7bc8d --- /dev/null +++ b/source/source_basis/module_pw/test_serial/pw_cache_bench.cpp @@ -0,0 +1,153 @@ +#include "source_base/matrix3.h" +#include "source_base/timer.h" + +#include "../pw_basis.h" +#include "../pw_basis_k.h" + +#include +#include +#include +#include + +#ifdef __MPI +#include "mpi.h" +#endif + +namespace +{ + +using Clock = std::chrono::steady_clock; + +template +double measure_seconds(Func&& func) +{ + const auto start = Clock::now(); + func(); + const auto end = Clock::now(); + return std::chrono::duration(end - start).count(); +} + +void print_metric(const std::string& name, const double value) +{ + std::cout << "METRIC " << name << " " << std::fixed << std::setprecision(9) << value << '\n'; +} + +void print_timer_metric(const std::string& class_name, const std::string& timer_name) +{ + const auto class_it = ModuleBase::timer::timer_pool.find(class_name); + if (class_it == ModuleBase::timer::timer_pool.end()) + { + return; + } + const auto timer_it = class_it->second.find(timer_name); + if (timer_it == class_it->second.end()) + { + return; + } + print_metric("timer." + class_name + "." + timer_name + ".seconds", timer_it->second.cpu_second); + print_metric("timer." + class_name + "." + timer_name + ".calls", static_cast(timer_it->second.calls)); +} + +void bench_pw_basis() +{ + constexpr int repeat_calls = 2000; + ModuleBase::timer::timer_pool.clear(); + + ModulePW::PW_Basis basis; + const ModuleBase::Matrix3 latvec(1, 0, 0, 0, 1, 0, 0, 0, 1); + const double lat0 = 10.0; + const double wfcecut = 50.0; + const double rhoecut = 4.0 * wfcecut; + const int distribution_type = 1; + + basis.initgrids(lat0, latvec, rhoecut); + basis.initparameters(false, wfcecut, distribution_type, true); + + print_metric("PW_Basis.setuptransform.wall", measure_seconds([&]() { basis.setuptransform(); })); + print_metric("PW_Basis.collect_local_pw.first.wall", measure_seconds([&]() { basis.collect_local_pw(); })); + print_metric("PW_Basis.collect_local_pw.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(); + } + })); + print_metric("PW_Basis.collect_uniqgg.first.wall", measure_seconds([&]() { basis.collect_uniqgg(); })); + print_metric("PW_Basis.collect_uniqgg.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_uniqgg(); + } + })); + + print_timer_metric("PW_Basis", "setuptransform"); + print_timer_metric("PW_Basis", "collect_local_pw"); + print_timer_metric("PW_Basis", "collect_local_pw_cache_hit"); + print_timer_metric("PW_Basis", "collect_local_pw_cache_build"); + print_timer_metric("PW_Basis", "collect_uniqgg"); + print_timer_metric("PW_Basis", "collect_uniqgg_cache_hit"); + print_timer_metric("PW_Basis", "collect_uniqgg_cache_build"); +} + +void bench_pw_basis_k() +{ + constexpr int repeat_calls = 2000; + ModuleBase::timer::timer_pool.clear(); + + ModulePW::PW_Basis_K basis("cpu", "double"); + const ModuleBase::Matrix3 latvec(10.0, 0.0, 0.0, + 0.0, 10.0, 0.0, + 0.0, 0.0, 10.0); + const double lat0 = 1.8897261254578281; + const double gridecut = 10.0; + const bool gamma_only = true; + const double gk_ecut = 11.0; + const int nks = 3; + const ModuleBase::Vector3 kvec_d[3] = {{0.0, 0.0, 0.0}, {0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}; + const int distribution_type = 1; + const bool xprime = true; + + basis.initgrids(lat0, latvec, gridecut); + basis.initparameters(gamma_only, gk_ecut, nks, kvec_d, distribution_type, xprime); + + print_metric("PW_Basis_K.setuptransform.wall", measure_seconds([&]() { basis.setuptransform(); })); + print_metric("PW_Basis_K.collect_local_pw.first.wall", measure_seconds([&]() { basis.collect_local_pw(); })); + print_metric("PW_Basis_K.collect_local_pw.repeat.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(); + } + })); + print_metric("PW_Basis_K.collect_local_pw.gk2_rebuild.wall", + measure_seconds([&]() { + for (int i = 0; i < repeat_calls; ++i) + { + basis.collect_local_pw(1.0, 0.5, 0.2); + } + })); + + print_timer_metric("PW_Basis_K", "setuptransform"); + print_timer_metric("PW_Basis_K", "collect_local_pw"); + print_timer_metric("PW_Basis_K", "collect_local_pw_cache_hit"); + print_timer_metric("PW_Basis_K", "collect_local_pw_build_gcar"); + print_timer_metric("PW_Basis_K", "collect_local_pw_build_gk2"); +} + +} // namespace + +int main() +{ +#ifdef __MPI + int argc = 0; + char** argv = nullptr; + MPI_Init(&argc, &argv); +#endif + bench_pw_basis(); + bench_pw_basis_k(); +#ifdef __MPI + MPI_Finalize(); +#endif + return 0; +} diff --git a/work_docs/feat_cache_reuse_optimization_process_report_2026-06-06.md b/work_docs/feat_cache_reuse_optimization_process_report_2026-06-06.md new file mode 100644 index 00000000000..b8bf2b3a33a --- /dev/null +++ b/work_docs/feat_cache_reuse_optimization_process_report_2026-06-06.md @@ -0,0 +1,301 @@ +# `feat/cache-reuse` 优化过程报告 + +日期:2026-06-06 +分支:`feat/cache-reuse` +基线:`develop` +范围:`source/source_basis/module_pw` + +## 1. 为什么写这份报告 + +与其说这是”优化前后性能对比”,不如说它想回答一个问题:**这条分支上的十几个 commit 到底是怎么一步步走到今天的?** + +每个提交当时看到了什么问题、上一版哪里还不够、这次改了什么、后来又为什么继续修——把这些串起来,比只看加速比更能看出一次优化是怎么”长出来”的。 + +先放上最终数据,有个印象: + +| 重复调用路径 | 加速比(中位数) | +|---|---| +| `PW_Basis.collect_local_pw` | ~255.5x | +| `PW_Basis.collect_uniqgg` | ~2284.4x | +| `PW_Basis_K.collect_local_pw` | ~342.7x | +| `PW_Basis_K.collect_local_pw(1.0, 0.5, 0.2)` | ~463.0x | + +数值看起来很夸张,但核心逻辑其实很简单:**参数没变就别重复算**。这条分支从头到尾就在做这一件事。 + +## 2. 四个阶段,一条主线 + +回看提交历史,优化过程大体可以分成四段: + +1. **先让缓存跑起来**——在 `PW_Basis` / `PW_Basis_K` 里把重复构建的数据存下来,验证”能不能复用”这件事本身成立 +2. **再从能用变成稳**——处理生命周期、失效条件、对象所有权和测试覆盖,不让缓存”能用但不敢用” +3. **再搭好验证链路**——从单元测试到真实 benchmark,跑脚本、采数据、出报告,让优化效果可量化 +4. **最后收掉边界问题**——锁、CI 链接、晶格变化后的误命中,一个不漏 + +下面按 commit 顺序逐一展开。 + +## 3. 逐个提交看过来 + +### 3.1 `772c9ce67` — 先打好认知基础 + +这不是一个性能提交,更像是在动手改 `module_pw` 之前,先把平面波分配和数据路径的来龙去脉理清楚,写成笔记沉淀下来。后面所有改动都建立在这个理解之上——优化不是闷头硬改,而是先搞清楚自己在动什么。 + +--- + +### 3.2 `27b8d8102` — 第一版缓存,从 0 到 1 + +**当时的状态是**:`collect_local_pw()`、`collect_uniqgg()`、`PW_Basis_K::collect_local_pw()` 每次调用都从头分配、重新构建——`gg`、`gdirect`、`gcar`、`ig2igg`、`gg_uniq`、`gk2`,一个不落。但参数没变的时候,这些数据明明可以留着下次直接用的。 + +**这一版的做法很直接**: + +- `PW_Basis` 里加了 `local_pw_cache_valid`、`uniqgg_cache_valid` 和 `invalidate_cache()`——缓存有效时直接返回,不用重建 +- `collect_uniqgg()` 还能蹭 `collect_local_pw()` 已经算好的 `gg`,省掉重复劳动 +- `PW_Basis_K` 也如法炮制,加 `gk_cache_valid`,`erf_*` 参数没变就跳过重建 +- 在初始化路径上补上 `invalidate_cache()`,首批测试验证重复调用后指针和数据都不变 + +**这步解决的核心问题是**:证明了”重复调用不再重复构建”这个方向是走得通的。第一次把”每次必重建”改成了”命中就不建”。 + +**没来得及处理的**:命中条件还是粗糙的布尔值;拷贝和对象所有权的风险没动;多线程还没考虑;真实 workload 下的收益也不知道。 + +> 一句话:从”没有缓存”到”有缓存”的起点,功能原型立住了。 + +--- + +### 3.3 `3614ad20a` — 让性能可测、可复现 + +第一版缓存跑起来了,但只能靠单元测试说”嗯,指针没变”。没有真实 workload,就回答不了”到底快了多少”这个问题。 + +这一步把大量精力放在了 `homework_docs/test_cases/` 上——写脚本、采数据、搭起从真实 case 到计时解析到批量跑数的完整链路。 + +**这步的意义是**:优化不光要在代码层面成立,还得能在实验层面重复验证。后面的所有决策都建立在这套体系之上。 + +> 一句话:从”实现缓存”到”建立性能验证体系”的过渡。 + +--- + +### 3.4 `b6f6f9409` — 把 timer 还回去 + +之前为了看缓存收益,在代码里埋了不少实验性 timer。这些 timer 对调试有用,但长期看对主代码侵入太大。这个提交把 `timer.cpp` 的行为恢复到接近仓库原本的状态——开始从”实验验证模式”往”准备合入”收敛了。 + +--- + +### 3.5 `9f9ab1eba` — benchmark 脚本化 + +之前 benchmark 多少靠手工跑,偶然性大。这个提交搞了一套 task8 cache benchmark 的采集脚本、README 和汇总工具,让 baseline 和 cache 两组数据可以统一采集、统一比较。它没改 `module_pw` 一行代码,但后续所有”优化是否有效”的证据都来自这里。 + +--- + +### 3.6 `faa843d01` — benchmark 数据补齐 + +在脚本基础上,补齐了 cache-reuse 的跑数结果、日志和汇总文件。这是分支第一次有了比较完整的”缓存版本”跑数记录。 + +--- + +### 3.7 `542e31612` — 合并基线,有对比才有伤害 + +光有优化版数据不够,得和 `develop` 基线对齐才能说”快了多少”。这个提交做了这件事——从”只知道自己快不快”转向”知道相对基线快多少”。 + +--- + +### 3.8 `4d2415134` / `3.9 `faf2b03de` — 文档沉淀 + +这两提交把实验记录、阶段结果和技术说明逐步整理成文档,让优化不只是代码改动,还有了完整的过程说明。进入了”整理可交付材料”的阶段。 + +--- + +### 3.10 `a85f18044` — 做减法 + +前面积累了不少实验材料,分支历史里临时文件偏多。这个提交清理了 `homework_docs` 里不再需要的内容,控制无关文件膨胀,让提交面收敛。 + +--- + +### 3.11 `617b03286` — 从”能用”到”可靠” + +第一版缓存虽然跑通了,但稍微往深处看,问题不少: + +- `PW_Basis` 有自定义拷贝构造,复制一份对象出来,缓存指针可能悬空,所有权也说不清 +- `invalidate_cache()` 只把 valid 标志置 false,但公开指针不一定清空 +- `get_cache_stats()` 判断缓存字节数不够严格,指针都空了还可能报告”缓存存在” +- `PW_Basis_K::setupIndGk()` 里还有重复扫描,效率没拉满 + +**这个提交的改法很干脆:** + +- 删掉 `PW_Basis` 的拷贝构造和赋值操作,明确说”这对象不该被复制” +- 强化 `invalidate_cache_unlocked()`:`gg`、`gdirect`、`gcar`、`ig2igg`、`gg_uniq`、`ig_gge0` 全都清干净 +- `get_cache_stats()` 只有在缓存**确实有效且指针非空**时才计数字节数 +- `PW_Basis_K` 里明确了谁是谁的所有者——`k_gcar_cache_storage` / `k_gk2_cache_storage` 就是公开 `gcar` / `gk2` 的持有者,缓存无效时 `get_gcar_data()` / `get_gk2_data()` 返回 `nullptr` +- `setupIndGk()` 复用已选中的 `ig` 列表,减少重复 cutoff 扫描 +- 顺便修了 `test-big.cpp` 里一个危险的切片写法 + +**这一步做完了,缓存就不只是”布尔值层面的命中”了**——对象所有权清楚了,失效后状态干净了,接口返回可靠了。 + +还没收尾的:并发锁保护还需要继续收紧;`setupIndGk()` 的优化后来部分被回退过;晶格变化后缓存该不该命中,判断还不够严格。 + +> 一句话:这次改完,缓存实现才算从原型长成了工程可维护的样子。 + +--- + +### 3.12 `f25cbe766` — 跟上主线 + +同步上游 `develop`。不做这一步,后面的修复就是在过时的基线上打补丁。事实上,后续几个修复提交都发生在这之后——合流本身暴露了一些集成问题。 + +--- + +### 3.13 `4e79b28bd` — 去掉实验痕迹 + +之前为了看缓存行为,在 `collect_local_pw()`、`collect_uniqgg()`、`setupIndGk()` 等地方加了额外的 timer。实验归实验,长期留在代码里不合适。 + +这个提交把新增的 timer 全移掉了——不是说计时没用,而是实验性埋点不应该带进主逻辑。 + +> 标志性的一步:验证过了,开始收敛。 + +--- + +### 3.14 `913a31504` — 给缓存加把锁 + +**之前的做法有个竞态风险**:判断缓存是否命中的逻辑有一部分在加锁之前——先看一眼 `cache_valid`,再决定要不要进锁。另一个线程可能就在这一眼和一锁之间把缓存状态改了。 + +读接口也一样:`get_cache_stats()`、`get_gcar_data()`、`get_gk2_data()` 不加锁也可能读到不一致的状态。 + +**这个提交的做法:** + +- `invalidate_cache()` 变成:上锁 → 调 `invalidate_cache_unlocked()` +- 新增 `invalidate_cache_unlocked()` 和 `get_cache_stats_unlocked()` 两个内部接口 +- 读写路径全部纳入锁保护,去掉了之前的”乐观快速返回” +- `collect_local_pw()` 重建时,不仅要把自己的数据清掉,连 `ig2igg` 和 `gg_uniq` 的缓存 storage 也一并清空 + +**这步解决的核心问题是**:让”缓存有效位”和”缓存本体”真正同步,不再有”valid 说有效但数据是旧的”这种半吊子状态。 + +唯一没堵上的口子是:如果晶格矩阵被直接改掉(不走 `invalidate_cache()`),缓存仍然可能误命中——这留给了下一步。 + +> 一句话:缓存从”逻辑正确”进化到了”并发下也正确”。 + +--- + +### 3.15 `8c622c2db` — 不让 CI 卡住 + +有些虚函数实现写在 `.cpp` 里,在某些链接场景下可能让虚表或符号解析出问题。这个提交把 `invalidate_cache()` 和 `invalidate_cache_unlocked()` 的默认实现挪到头文件里内联定义,`PW_Basis_K` 的版本也一样处理。 + +不是性能修复,但很重要——优化做得再好,CI 过不了也合不进主线。 + +--- + +### 3.16 `4faaadcb1` — 最后一块拼图:什么时候绝不能复用 + +前面的缓存命中条件说到底就两样:`cache_valid` + 部分参数检查。但如果晶格、倒格矢矩阵、FFT 网格在对象内部变了——没走显式失效的话——旧缓存就可能被静默复用。 + +这类 bug 最危险:它不崩、不报错、编译也通过,只是默默地给你一个不对的结果。 + +**这个提交的做法是在 `PW_Basis` 里引入 `CacheSignature`:** + +- 把 `lat0`、`tpiba`、`tpiba2`、`nx/ny/nz`、`fftnx/fftny/fftnz`、`npw`、`G/GT/GGT` 这些真正决定缓存内容的状态打包成一个签名 +- 每次命中缓存前,不光看 `cache_valid`,还要拿当前状态和签名比对一遍 +- 签名不匹配?老老实实重建,然后记下新的签名 + +对应的测试也很直白:先命中一次,然后故意改掉晶格矩阵,验证缓存确实 miss、结果也确实变了。 + +> 如果说前面的提交解决的是”缓存怎么复用”,那这一提交解决的就是”什么时候绝不能复用”。性能再好,跑出错误结果也毫无意义。 + +## 4. 回过头看这条问题收敛链 + +如果把整条分支看作”问题一步步暴露、一步步补上”的过程,路线大致是这样的: + +1. **先看到重复构建很浪费** → 做了第一版 `cache_valid` 复用 +2. **然后发现”能复用”不等于”工程上安全”** → 处理失效路径、指针清空、拷贝语义、测试覆盖 +3. **再发现只看单测不够** → 搭 benchmark 脚本、baseline、汇总流程 +4. **再发现实验 timer 不该留着** → 移除额外 timer,收敛代码表面 +5. **接着暴露出锁问题** → 读写全上锁,storage 和 valid 位同步失效 +6. **最后发现晶格变化后会误命中** → 引入 `CacheSignature`,命中条件从布尔值升级为签名匹配 + +每一步的触发因素都是前一步的”还不够”,而不是一开始就规划好的。这大概就是所谓”工程化迭代”的样子——先抓住主要瓶颈,再补边界,再补验证,最后收敛正确性。 + +## 5. 每一步值在哪 + +### 5.1 第一版缓存(`27b8d8102`)为什么值钱 + +它证明了方向是对的。`collect_local_pw`、`collect_uniqgg`、`PW_Basis_K::collect_local_pw` 这些函数确实是可缓存的,而不是”看起来像热点,实际上不能复用”。没有这一步,后面所有锁、签名、测试和基准都无从谈起。 + +### 5.2 中间几轮工程打磨为什么必不可少 + +`617b03286`、`913a31504`、`8c622c2db` 这几提交看起来不像加速比那么”显眼”,但少了哪一步都危险: + +- 不清空 storage → valid 位失效了但旧指针还在,半死不活的状态最坑 +- 不处理拷贝语义 → 缓存所有权一锅粥 +- 不把判断放进锁 → 命中到不一致状态 +- CI 链接不过 → 优化再好也进不了主线 + +这些提交把”优化代码”变成了”可维护、可测试、可集成的优化代码”。 + +### 5.3 最后一个正确性补丁为什么最关键 + +`4faaadcb1` 堵住的是最危险的一类 bug——不是崩溃,不是编译不过,而是**静默地复用了错的缓存**。没有 signature 机制,性能数字再好看,结果也不可信。这一步是在给整条优化链”盖章”:它不只是快,而且应该是对的。 + +## 6. benchmark 数据怎么说 + +这次 micro-benchmark 的结果和分支上的演进逻辑是对得上的: + +### 6.1 `PW_Basis.collect_local_pw` + +| 版本 | 中位数耗时 | +|---|---| +| `develop` | 0.0882 s | +| `feat/cache-reuse` | 0.0003 s | +| **加速比** | **~255.5x** | + +本地 PW 数据不再重复构建,直接命中缓存,效果立竿见影。 + +### 6.2 `PW_Basis.collect_uniqgg` + +| 版本 | 中位数耗时 | +|---|---| +| `develop` | 0.7546 s | +| `feat/cache-reuse` | 0.0003 s | +| **加速比** | **~2284.4x** | + +不需要再每次排序、去重,稳定命中缓存。这也是整条分支收益最大的单点。 + +### 6.3 `PW_Basis_K.collect_local_pw` + +| 版本 | 中位数耗时 | +|---|---| +| `develop` | 0.1091 s | +| `feat/cache-reuse` | 0.0003 s | +| **加速比** | **~342.7x** | + +K 点路径上的 `gcar` / `gk2` 缓存确实成立。 + +### 6.4 `PW_Basis_K.collect_local_pw(1.0, 0.5, 0.2)` + +| 版本 | 中位数耗时 | +|---|---| +| `develop` | 0.1930 s | +| `feat/cache-reuse` | 0.0004 s | +| **加速比** | **~463.0x** | + +这项结果很有意思——`erf` 参数变了,`gk2` 确实得重建,但 **`gcar` 还在缓存里**。这正是缓存粒度做细之后的收益:不是”全有或全无”,而是该失效的失效、该复用的复用。 + +## 7. 总体评价 + +回过头看 `feat/cache-reuse` 这条分支,它不是一个”一次写完大改”的提交,而是一个典型的工程化迭代过程: + +1. 从热点识别出发,先建立缓存复用能力 +2. 再花功夫解决生命周期、失效和所有权问题 +3. 再搭 benchmark、baseline 和文档链路 +4. 然后去掉实验性噪音,让实现收敛 +5. 最后补上锁一致性和晶格签名校验,完成正确性闭环 + +最终结果体现为:`module_pw` 的重复调用路径从”每次都重建”变成了”首轮构建,后续命中”,性能收益显著,且状态变化时不会误命中旧缓存。 + +但这条分支的价值不只是”跑得更快”——说起来可能是这四件事: + +- **知道哪里可以缓存** +- **知道什么时候必须失效** +- **知道怎样证明它真的更快** +- **知道怎样保证它快得是对的** + +## 8. 如果要在答辩或汇报上说清楚这件事 + +可以试试这三句话: + +1. **这次优化的核心不是改 FFT 本身,而是让 `module_pw` 在参数没变的时候别重复算**——`gg`、`gdirect`、`gcar`、`ig2igg`、`gg_uniq`、`gk2`,该省的全省掉。 +2. **优化过程不是一次拍脑袋写完的,而是”先做复用、再补失效、再补锁、最后补状态签名”,一步步收敛的。** +3. **最终收益集中在重复调用路径,`PW_Basis.collect_uniqgg` 中位数提升约 2284 倍——说明缓存复用确实命中了真正的热点。** diff --git a/work_docs/module_pw_cache_reuse_perf_compare_2026-06-06.md b/work_docs/module_pw_cache_reuse_perf_compare_2026-06-06.md new file mode 100644 index 00000000000..2f022dbd6d8 --- /dev/null +++ b/work_docs/module_pw_cache_reuse_perf_compare_2026-06-06.md @@ -0,0 +1,225 @@ +# module_pw 缓存复用性能对比 + +日期:2026-06-06 +对比分支:`develop` vs `feat/cache-reuse` +对比范围:`source/source_basis/module_pw` + +## 1. 本次补充的计时点 + +为了定位缓存复用收益,本次在 `module_pw` 内部补了最小必要的 timer: + +- `PW_Basis::collect_local_pw` +- `PW_Basis::collect_local_pw_cache_hit` +- `PW_Basis::collect_local_pw_cache_build` +- `PW_Basis::collect_uniqgg` +- `PW_Basis::collect_uniqgg_cache_hit` +- `PW_Basis::collect_uniqgg_cache_build` +- `PW_Basis_K::collect_local_pw` +- `PW_Basis_K::collect_local_pw_cache_hit` +- `PW_Basis_K::collect_local_pw_build_gcar` +- `PW_Basis_K::collect_local_pw_build_gk2` + +说明: + +- `feat/cache-reuse` 上保留了命中/构建分支的细分 timer。 +- `develop` 基线 worktree 只补了等价的“构建路径” timer,用来做公平基准,不包含缓存实现本身。 + +## 2. 基准方法 + +### 2.1 构建方式 + +为避免 MPI 环境对 `timer.cpp` 的影响,最终性能数据使用串行构建: + +```bash +cmake -S /home/aunixt/abacus-develop -B /home/aunixt/abacus-develop/build-bench-serial \ + -DENABLE_MPI=OFF -DUSE_OPENMP=OFF -DUSE_ELPA=OFF -DBUILD_TESTING=OFF +cmake --build /home/aunixt/abacus-develop/build-bench-serial --target MODULE_PW_cache_bench +``` + +`develop` 基线在独立 worktree 中执行同样流程: + +```bash +git -C /home/aunixt/abacus-develop worktree add /home/aunixt/abacus-develop-develop develop +cmake -S /home/aunixt/abacus-develop-develop -B /home/aunixt/abacus-develop-develop/build-bench-serial \ + -DENABLE_MPI=OFF -DUSE_OPENMP=OFF -DUSE_ELPA=OFF -DBUILD_TESTING=OFF +cmake --build /home/aunixt/abacus-develop-develop/build-bench-serial --target MODULE_PW_cache_bench +``` + +### 2.2 基准程序 + +新增基准程序:`MODULE_PW_cache_bench` + +测试内容: + +- `PW_Basis.setuptransform` +- `PW_Basis.collect_local_pw` 首次调用 +- `PW_Basis.collect_local_pw` 重复 2000 次 +- `PW_Basis.collect_uniqgg` 首次调用 +- `PW_Basis.collect_uniqgg` 重复 2000 次 +- `PW_Basis_K.setuptransform` +- `PW_Basis_K.collect_local_pw` 首次调用 +- `PW_Basis_K.collect_local_pw` 重复 2000 次 +- `PW_Basis_K.collect_local_pw(1.0, 0.5, 0.2)` 重复 2000 次 + +统计口径: + +- 外层 wall time:基准程序直接测量 +- 内层 timer:`ModuleBase::timer` 累积结果 +- 每个分支各跑 3 次,结论使用中位数,避免单次抖动 + +## 3. 结果汇总 + +### 3.1 关键结论 + +1. `setuptransform` 基本无变化,说明优化没有破坏初始化主路径。 +2. `PW_Basis.collect_local_pw` 的重复调用从持续重建,变成几乎纯命中路径,中位数加速约 `255.5x`。 +3. `PW_Basis.collect_uniqgg` 的重复调用收益最大,中位数加速约 `2284.4x`。 +4. `PW_Basis_K.collect_local_pw` 的重复调用中位数加速约 `342.7x`。 +5. 即使传入新 `erf` 参数导致 `gk2` 需要重建,`PW_Basis_K` 仍然复用了 `gcar`,该路径中位数加速约 `463.0x`。 + +### 3.2 中位数对比表 + +单位:秒 + +| 场景 | develop | feat/cache-reuse | 提升 | +| --- | ---: | ---: | ---: | +| `PW_Basis.setuptransform.wall` | 0.001487681 | 0.001502684 | 0.99x | +| `PW_Basis.collect_local_pw.first.wall` | 0.000117131 | 0.000124510 | 0.94x | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.088178164 | 0.000345102 | 255.52x | +| `PW_Basis.collect_uniqgg.first.wall` | 0.000440547 | 0.000401324 | 1.10x | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.754634649 | 0.000330313 | 2284.40x | +| `PW_Basis_K.setuptransform.wall` | 0.000259432 | 0.000217700 | 1.19x | +| `PW_Basis_K.collect_local_pw.first.wall` | 0.000134854 | 0.000121473 | 1.11x | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.109138201 | 0.000318489 | 342.68x | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.193014060 | 0.000416850 | 463.04x | + +## 4. 原始样本 + +### 4.1 feat/cache-reuse + +#### Run 1 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001464156 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.000462526 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.000362283 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.000318489 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.000416850 | + +#### Run 2 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001733625 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.000345102 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.000330313 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.000373748 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.000529150 | + +#### Run 3 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001502684 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.000314556 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.000287104 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.000317902 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.000416157 | + +### 4.2 develop + +#### Run 1 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001473400 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.086802023 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.754634649 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.109138201 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.193014060 | + +#### Run 2 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001487681 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.088178164 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.745032054 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.104775062 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.176847280 | + +#### Run 3 + +| 指标 | 数值 | +| --- | ---: | +| `PW_Basis.setuptransform.wall` | 0.001664659 | +| `PW_Basis.collect_local_pw.repeat.wall` | 0.105299161 | +| `PW_Basis.collect_uniqgg.repeat.wall` | 0.965178448 | +| `PW_Basis_K.collect_local_pw.repeat.wall` | 0.134283688 | +| `PW_Basis_K.collect_local_pw.gk2_rebuild.wall` | 0.266846180 | + +## 5. 内部 timer 观察 + +### 5.1 feat/cache-reuse + +代表性现象: + +- `timer.PW_Basis.collect_local_pw_cache_hit.calls = 2000` +- `timer.PW_Basis.collect_uniqgg_cache_hit.calls = 2000` +- `timer.PW_Basis_K.collect_local_pw_cache_hit.calls = 3999` +- `timer.PW_Basis_K.collect_local_pw_build_gcar.calls = 1` +- `timer.PW_Basis_K.collect_local_pw_build_gk2.calls = 2` + +解释: + +- `PW_Basis` 的两条缓存路径在首轮构建后,后续 2000 次全部命中。 +- `PW_Basis_K` 在默认参数重复调用时,只有首轮需要构建。 +- 改变 `erf` 参数后,`gk2` 需要重新构建,但 `gcar` 仍然保持复用。 + +### 5.2 develop + +代表性现象: + +- `timer.PW_Basis.collect_local_pw_cache_build.calls = 2001` +- `timer.PW_Basis.collect_uniqgg_cache_build.calls = 2001` +- `timer.PW_Basis_K.collect_local_pw_build_gcar.calls = 4001` +- `timer.PW_Basis_K.collect_local_pw_build_gk2.calls = 4001` + +解释: + +- 基线分支每次调用都走完整构建,没有任何命中路径。 +- 这与重复调用 wall time 的数量级差异完全一致。 + +## 6. 验证记录 + +在 `feat/cache-reuse` 上新增并通过的回归测试: + +- `PWBasisTEST.CacheCollectionRecordsTimers` +- `PWBasisKTEST.CollectLocalPWRecordsTimers` + +测试命令: + +```bash +cd /home/aunixt/abacus-develop/build-tests-mpi/source/source_basis/module_pw/test_serial +./MODULE_PW_basis_pw_serial --gtest_filter=PWBasisTEST.CacheCollectionRecordsTimers +./MODULE_PW_basis_pw_k_serial --gtest_filter=PWBasisKTEST.CollectLocalPWRecordsTimers +``` + +结果:两项均通过。 + +## 7. 最终结论 + +`feat/cache-reuse` 在 `module_pw` 中的缓存复用优化是有效且收益非常显著的,主要收益集中在重复调用路径: + +- `PW_Basis.collect_local_pw` +- `PW_Basis.collect_uniqgg` +- `PW_Basis_K.collect_local_pw` + +其中最关键的是: + +- `develop` 重复调用仍然持续分配并重建数据。 +- `feat/cache-reuse` 已经将这部分开销压缩到首轮构建,后续主要变成 cache hit。 +- `PW_Basis_K` 在参数部分变化时还能保留 `gcar` 复用,说明缓存粒度设计是合理的。 + +如果后续需要,我建议直接把这份基准保留在仓库里,后面可以继续扩展成 CI 可重复的 micro-benchmark。