Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions source/source_basis/module_pw/pw_distributeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ void PW_Basis::count_pw_st(
{
ModuleBase::GlobalFunc::ZEROS(st_length2D, this->fftnxy);
ModuleBase::GlobalFunc::ZEROS(st_bottom2D, this->fftnxy);
ModuleBase::Vector3<double> f;

// determine the scanning area along the x-direction; if gamma-only && xprime, only the positive axis is used.
int ix_end = int(this->nx / 2) + 1;
Expand Down Expand Up @@ -85,10 +84,18 @@ void PW_Basis::count_pw_st(
}
}

this->liy = this->riy = 0;
this->lix = this->rix = 0;
this->npwtot = 0;
this->nstot = 0;
int npwtot_local = 0;
int nstot_local = 0;
int lix_local = 0, rix_local = 0;
int liy_local = 0, riy_local = 0;

#ifdef _OPENMP
#pragma omp parallel for collapse(2) \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you compare the performance with collapse(1)? In this kind of loop nest, collapse(1) is often faster than collapse(2) when using the same level of parallelism.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides, could you compare the single-thread performance with and without OpenMP? I think collapse(2) might still be much slower even with one thread.

shared(st_length2D, st_bottom2D) \
reduction(+: npwtot_local, nstot_local) \
reduction(min: rix_local, riy_local) \
reduction(max: lix_local, liy_local)
#endif
for (int ix = ix_start; ix <= ix_end; ++ix)
{
for (int iy = iy_start; iy <= iy_end; ++iy)
Expand All @@ -100,44 +107,53 @@ void PW_Basis::count_pw_st(
// so that its index in st_length and st_bottom is 9 * 10 + 2 = 92.
int x = ix;
int y = iy;
if (x < 0) { x += this->nx;
}
if (y < 0) { y += this->ny;
}
if (x < 0) { x += this->nx; }
if (y < 0) { y += this->ny; }
int index = x * this->fftny + y;

int length = 0; // number of planewave on stick (x, y).

ModuleBase::Vector3<double> f;
f.x = ix;
f.y = iy;

for (int iz = iz_start; iz <= iz_end; ++iz)
{
f.x = ix;
f.y = iy;
f.z = iz;
double modulus = f * (this->GGT * f);
if (modulus <= this->ggecut || this->full_pw)
{
if (length == 0) { st_bottom2D[index] = iz; // length == 0 means this point is the bottom of stick (x, y).
}
++this->npwtot;
if (length == 0)
{
st_bottom2D[index] = iz; // length == 0 means this point is the bottom of stick (x, y).
}
++npwtot_local;
++length;
if(iy < this->riy) { this->riy = iy;
}
if(iy > this->liy) { this->liy = iy;
}
if(ix < this->rix) { this->rix = ix;
}
if(ix > this->lix) { this->lix = ix;
}
if(ix > lix_local) { lix_local = ix; }
if(ix < rix_local) { rix_local = ix; }
if(iy > liy_local) { liy_local = iy; }
if(iy < riy_local) { riy_local = iy; }
}
}
if (length > 0)
{
st_length2D[index] = length;
++this->nstot;
++nstot_local;
}
}
}
riy += this->ny;
rix += this->nx;

if (npwtot_local == 0) // no planewave
{
lix_local = rix_local = liy_local = riy_local = 0;
}

this->npwtot = npwtot_local;
this->nstot = nstot_local;
this->lix = lix_local;
this->rix = rix_local + this->nx;
this->liy = liy_local;
this->riy = riy_local + this->ny;
return;
}

Expand Down
Loading