Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions source/source_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ add_library(
${objects}
)

add_executable(
MODULE_PW_cache_bench
test_serial/pw_cache_bench.cpp
)

target_link_libraries(
MODULE_PW_cache_bench
parameter
${math_libs}
planewave
device
base
Threads::Threads
)

if(USE_OPENMP)
target_link_libraries(MODULE_PW_cache_bench OpenMP::OpenMP_CXX)
endif()

if (USE_DSP)
target_link_libraries(planewave PRIVATE
${MTBLAS_FFT_DIR}/libmtblas/lib/libmtfft.a)
Expand Down
226 changes: 184 additions & 42 deletions source/source_basis/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "source_base/timer.h"
#include "source_base/global_function.h"

#include <vector>

namespace ModulePW
{
Expand All @@ -28,17 +29,13 @@ PW_Basis:: ~PW_Basis()
delete[] fftixy2ip;
delete[] nst_per;
delete[] npw_per;
delete[] gdirect;
delete[] gcar;
delete[] gg;
delete[] startz;
delete[] numz;
delete[] numg;
delete[] numr;
delete[] startg;
delete[] startr;
delete[] ig2igg;
delete[] gg_uniq;
this->clear_owned_cache();
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu")
{
Expand All @@ -48,6 +45,91 @@ PW_Basis:: ~PW_Basis()
#endif
}

void PW_Basis::clear_owned_cache()
{
std::lock_guard<std::mutex> guard(this->cache_mutex);
this->invalidate_cache_unlocked();
}

PW_Basis::CacheStats PW_Basis::get_cache_stats() const
{
std::lock_guard<std::mutex> guard(this->cache_mutex);
return this->get_cache_stats_unlocked();
}

PW_Basis::CacheStats PW_Basis::get_cache_stats_unlocked() const
{
CacheStats stats;
stats.local_pw_hits = this->local_pw_cache_hits.load();
stats.local_pw_misses = this->local_pw_cache_misses.load();
stats.uniqgg_hits = this->uniqgg_cache_hits.load();
stats.uniqgg_misses = this->uniqgg_cache_misses.load();
const bool has_local_pw_cache = this->local_pw_cache_valid.load()
&& this->npw > 0
&& this->gg != nullptr
&& this->gdirect != nullptr
&& this->gcar != nullptr;
const bool has_uniqgg_cache = this->uniqgg_cache_valid.load()
&& this->ngg > 0
&& this->ig2igg != nullptr
&& this->gg_uniq != nullptr;
if (has_local_pw_cache)
{
stats.cache_bytes += sizeof(double) * this->npw;
stats.cache_bytes += sizeof(ModuleBase::Vector3<double>) * this->npw * 2;
}
if (has_uniqgg_cache)
{
stats.cache_bytes += sizeof(int) * this->npw;
stats.cache_bytes += sizeof(double) * this->ngg;
}
return stats;
}

void PW_Basis::reset_cache_stats()
{
this->local_pw_cache_hits.store(0);
this->local_pw_cache_misses.store(0);
this->uniqgg_cache_hits.store(0);
this->uniqgg_cache_misses.store(0);
}

PW_Basis::CacheSignature PW_Basis::make_cache_signature() const
{
CacheSignature signature;
signature.lat0 = this->lat0;
signature.tpiba = this->tpiba;
signature.tpiba2 = this->tpiba2;
signature.nx = this->nx;
signature.ny = this->ny;
signature.nz = this->nz;
signature.fftnx = this->fftnx;
signature.fftny = this->fftny;
signature.fftnz = this->fftnz;
signature.npw = this->npw;
signature.G = this->G;
signature.GT = this->GT;
signature.GGT = this->GGT;
return signature;
}

bool PW_Basis::cache_signature_matches(const CacheSignature& signature) const
{
return signature.lat0 == this->lat0
&& signature.tpiba == this->tpiba
&& signature.tpiba2 == this->tpiba2
&& signature.nx == this->nx
&& signature.ny == this->ny
&& signature.nz == this->nz
&& signature.fftnx == this->fftnx
&& signature.fftny == this->fftny
&& signature.fftnz == this->fftnz
&& signature.npw == this->npw
&& std::memcmp(&signature.G, &this->G, sizeof(ModuleBase::Matrix3)) == 0
&& std::memcmp(&signature.GT, &this->GT, sizeof(ModuleBase::Matrix3)) == 0
&& std::memcmp(&signature.GGT, &this->GGT, sizeof(ModuleBase::Matrix3)) == 0;
}

///
/// distribute plane wave basis and real-space grids to different processors
/// set up maps for fft and create arrays for MPI_Alltoall
Expand Down Expand Up @@ -138,10 +220,33 @@ void PW_Basis::collect_local_pw()
{
return;
}
ModuleBase::timer::start(this->classname, "collect_local_pw");
std::lock_guard<std::mutex> guard(this->cache_mutex);
if (this->local_pw_cache_valid.load()
&& this->cache_signature_matches(this->local_pw_cache_signature))
{
ModuleBase::timer::start(this->classname, "collect_local_pw_cache_hit");
this->local_pw_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_local_pw_cache_hit");
ModuleBase::timer::end(this->classname, "collect_local_pw");
return;
}
ModuleBase::timer::start(this->classname, "collect_local_pw_cache_build");
this->local_pw_cache_misses.fetch_add(1);
this->ig_gge0 = -1;
delete[] this->gg; this->gg = new double[this->npw];
delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3<double>[this->npw];
delete[] this->gcar; this->gcar = new ModuleBase::Vector3<double>[this->npw];
this->gg_cache_storage.reset(new double[this->npw]);
this->gdirect_cache_storage.reset(new ModuleBase::Vector3<double>[this->npw]);
this->gcar_cache_storage.reset(new ModuleBase::Vector3<double>[this->npw]);
this->gg = this->gg_cache_storage.get();
this->gdirect = this->gdirect_cache_storage.get();
this->gcar = this->gcar_cache_storage.get();
// Unique-G data depends on gg, so rebuilding local G data invalidates it.
this->uniqgg_cache_valid.store(false);
this->ig2igg_cache_storage.reset();
this->gg_uniq_cache_storage.reset();
this->ig2igg = nullptr;
this->gg_uniq = nullptr;
this->ngg = 0;

ModuleBase::Vector3<double> f;
int gamma_num = 0;
Expand Down Expand Up @@ -182,6 +287,10 @@ void PW_Basis::collect_local_pw()
}
}
}
this->local_pw_cache_valid.store(true);
this->local_pw_cache_signature = this->make_cache_signature();
ModuleBase::timer::end(this->classname, "collect_local_pw_cache_build");
ModuleBase::timer::end(this->classname, "collect_local_pw");
return;
}

Expand All @@ -196,45 +305,74 @@ void PW_Basis::collect_uniqgg()
{
return;
}
ModuleBase::timer::start(this->classname, "collect_uniqgg");
std::lock_guard<std::mutex> guard(this->cache_mutex);
if (this->uniqgg_cache_valid.load()
&& this->cache_signature_matches(this->uniqgg_cache_signature))
{
ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_hit");
this->uniqgg_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_hit");
ModuleBase::timer::end(this->classname, "collect_uniqgg");
return;
}
ModuleBase::timer::start(this->classname, "collect_uniqgg_cache_build");
this->uniqgg_cache_misses.fetch_add(1);
this->ig_gge0 = -1;
delete[] this->ig2igg; this->ig2igg = new int [this->npw];
this->ig2igg_cache_storage.reset(new int[this->npw]);
this->ig2igg = this->ig2igg_cache_storage.get();

int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves
double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond
double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates
ModuleBase::Vector3<double> f;
for(int ig = 0 ; ig < this-> npw ; ++ig)
std::vector<int> sortindex(this->npw); // Reconstruct the plane-wave index mapping after sorting by energy.
std::vector<double> tmpgg(this->npw);
std::vector<double> tmpgg2(this->npw);
// Reuse gg when collect_local_pw has already built the same G^2 values.
if (this->local_pw_cache_valid.load() && this->gg != nullptr)
{
int isz = this->ig2isz[ig];
int iz = isz % this->nz;
int is = isz / this->nz;
int ixy = this->is2fftixy[is];
int ix = ixy / this->fftny;
int iy = ixy % this->fftny;
if (ix >= int(this->nx/2) + 1)
{
ix -= this->nx;
}
if (iy >= int(this->ny/2) + 1)
{
iy -= this->ny;
}
if (iz >= int(this->nz/2) + 1)
for(int ig = 0 ; ig < this-> npw ; ++ig)
{
iz -= this->nz;
tmpgg[ig] = this->gg[ig];
if(tmpgg[ig] < 1e-8)
{
this->ig_gge0 = ig;
}
}
f.x = ix;
f.y = iy;
f.z = iz;
tmpgg[ig] = f * (this->GGT * f);
if(tmpgg[ig] < 1e-8)
}
else
{
ModuleBase::Vector3<double> f;
for(int ig = 0 ; ig < this-> npw ; ++ig)
{
this->ig_gge0 = ig;
int isz = this->ig2isz[ig];
int iz = isz % this->nz;
int is = isz / this->nz;
int ixy = this->is2fftixy[is];
int ix = ixy / this->fftny;
int iy = ixy % this->fftny;
if (ix >= int(this->nx/2) + 1)
{
ix -= this->nx;
}
if (iy >= int(this->ny/2) + 1)
{
iy -= this->ny;
}
if (iz >= int(this->nz/2) + 1)
{
iz -= this->nz;
}
f.x = ix;
f.y = iy;
f.z = iz;
tmpgg[ig] = f * (this->GGT * f);
if(tmpgg[ig] < 1e-8)
{
this->ig_gge0 = ig;
}
}
}

ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw);
ModuleBase::heapsort(this->npw, tmpgg, sortindex);
ModuleBase::GlobalFunc::ZEROS(sortindex.data(), this->npw);
ModuleBase::heapsort(this->npw, tmpgg.data(), sortindex.data());


int igg = 0;
Expand All @@ -261,14 +399,16 @@ void PW_Basis::collect_uniqgg()
}
tmpgg2[igg] = avg_gg / double(avg_n);
this->ngg = igg + 1;
delete[] this->gg_uniq; this->gg_uniq = new double [this->ngg];
this->gg_uniq_cache_storage.reset(new double[this->ngg]);
this->gg_uniq = this->gg_uniq_cache_storage.get();
for(int igg = 0 ; igg < this->ngg ; ++igg)
{
gg_uniq[igg] = tmpgg2[igg];
}
delete[] sortindex;
delete[] tmpgg;
delete[] tmpgg2;
this->uniqgg_cache_valid.store(true);
this->uniqgg_cache_signature = this->make_cache_signature();
ModuleBase::timer::end(this->classname, "collect_uniqgg_cache_build");
ModuleBase::timer::end(this->classname, "collect_uniqgg");
}

void PW_Basis::getfftixy2is(int * fftixy2is) const
Expand All @@ -295,10 +435,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) const

void PW_Basis::set_device(std::string device_) {
this->device = std::move(device_);
this->invalidate_cache();
}

void PW_Basis::set_precision(std::string precision_) {
this->precision = std::move(precision_);
this->invalidate_cache();
}

}
Loading
Loading