diff --git a/source/source_cell/cal_atoms_info.h b/source/source_cell/cal_atoms_info.h index e778aa69524..7172a5108b9 100644 --- a/source/source_cell/cal_atoms_info.h +++ b/source/source_cell/cal_atoms_info.h @@ -71,7 +71,7 @@ class CalAtomsInfo // calculate the number of nbands_local para.sys.nbands_l = para.inp.nbands; - if (para.inp.ks_solver == "bpcg") // only bpcg support band parallel + if (para.inp.ks_solver == "bpcg" || para.inp.ks_solver == "lobpcg") { para.sys.nbands_l = para.inp.nbands / para.inp.bndpar; if (GlobalV::MY_BNDGROUP < para.inp.nbands % para.inp.bndpar) @@ -80,7 +80,9 @@ class CalAtomsInfo } } // temporary code - if (GlobalV::MY_BNDGROUP == 0 || para.inp.ks_solver == "bpcg") + if (GlobalV::MY_BNDGROUP == 0 + || para.inp.ks_solver == "bpcg" + || para.inp.ks_solver == "lobpcg") { para.sys.ks_run = true; } diff --git a/source/source_estate/elecstate_print.cpp b/source/source_estate/elecstate_print.cpp index 517c002dd6d..df2bb9b4821 100644 --- a/source/source_estate/elecstate_print.cpp +++ b/source/source_estate/elecstate_print.cpp @@ -58,6 +58,7 @@ void print_scf_iterinfo(const std::string& ks_solver, {"scalapack_gvx", "GV"}, {"cusolver", "CU"}, {"bpcg", "BP"}, + {"lobpcg", "LB"}, {"pexsi", "PE"}, {"cusolvermp", "CM"}, {"sdft", "CT"}}; // CT = Chebyshev Trace, for pure SDFT (nbands=0) where no H diagonalization is performed diff --git a/source/source_hsolver/CMakeLists.txt b/source/source_hsolver/CMakeLists.txt index b115d6d4cd2..2e30217a6e5 100644 --- a/source/source_hsolver/CMakeLists.txt +++ b/source/source_hsolver/CMakeLists.txt @@ -4,6 +4,7 @@ list(APPEND objects diago_david.cpp diago_dav_subspace.cpp diago_bpcg.cpp + diago_lobpcg.cpp para_linear_transform.cpp hsolver_pw.cpp hsolver_lcaopw.cpp diff --git a/source/source_hsolver/diago_lobpcg.cpp b/source/source_hsolver/diago_lobpcg.cpp new file mode 100644 index 00000000000..ea9dcb47cfe --- /dev/null +++ b/source/source_hsolver/diago_lobpcg.cpp @@ -0,0 +1,2799 @@
#include "source_hsolver/diago_lobpcg.h"

#include "diago_iter_assist.h"
#include "source_base/global_function.h"
#include "source_base/global_variable.h"
#include "source_base/kernels/math_kernel_op.h"
#include "source_base/parallel_comm.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <complex>
#include <iomanip>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// NOTE(review): the angle-bracket text in this section (system header names,
// `template <...>` lists, `static_cast<...>`, `data<...>()`, kernel template
// arguments) is inferred from how each name is used below. Confirm it against
// diago_lobpcg.h before relying on it.

namespace hsolver {

using LobpcgClock = std::chrono::steady_clock;

// ============================================================================
// Band-major explicit-loop helpers used by serial generalized fallback paths.
//
// Psi is stored band-major:  psi_data[ib * n_basis + ig],
// shape [n_band_l, n_basis].  ig >= n_dim must be zero-padded.
//
// Subspace matrices (C, V) are column-major for direct LAPACK use:
//   C[col * ld + row] = C[j * nb + i]   (nb = leading dimension).
// ============================================================================

/// C(i,j) = alpha * sum_ig conj( A(i,ig) ) * B(j,ig) + beta * C(i,j)
/// (standard inner product over the first `nvalid` entries of each row).
///
/// @param nb      number of rows in A and B; C is nb x nb, column-major
/// @param lda     row stride of A and B
/// @param nvalid  number of leading entries of each row that take part in the sum
///
/// When beta == 0 the previous content of C is not read, so C may be
/// uninitialised (same convention as rotate_loop below).
template <typename T>
static void inner_product_loop(int nb, int lda, int nvalid,
                               T alpha, const T* A, const T* B,
                               T beta, T* C)
{
    const bool overwrite = (beta == static_cast<T>(0.0));
    for (int j = 0; j < nb; ++j) {
        for (int i = 0; i < nb; ++i) {
            T sum = static_cast<T>(0.0);
            for (int ig = 0; ig < nvalid; ++ig) {
                sum += std::conj(A[i * lda + ig]) * B[j * lda + ig];
            }
            C[j * nb + i] = overwrite
                ? alpha * sum
                : alpha * sum + beta * C[j * nb + i];
        }
    }
}

/// newRow_i = alpha * sum_k V(k,i) * oldRow_k + beta * newRow_i
/// V col-major: V(k,i) = V[i * ldv + k]
///
/// Entries ig in [nvalid, lda) of every output row are set to zero.
/// When beta == 0 the previous content of C is not read.
template <typename T>
static void rotate_loop(int nb, int lda, int nvalid,
                        int ldv, T alpha, const T* V, const T* A,
                        T beta, T* C)
{
    for (int i = 0; i < nb; ++i) {
        for (int ig = 0; ig < nvalid; ++ig) {
            T sum = static_cast<T>(0.0);
            for (int k = 0; k < nb; ++k) {
                sum += V[i * ldv + k] * A[k * lda + ig];
            }
            C[i * lda + ig] = (beta == static_cast<T>(0.0))
                ? alpha * sum
                : alpha * sum + beta * C[i * lda + ig];
        }
        for (int ig = nvalid; ig < lda; ++ig) {
            C[i * lda + ig] = static_cast<T>(0.0);
        }
    }
}

// ============================================================================
// File-static helpers
// ============================================================================

/// For every c < r, overwrite mat[c * ld + r] with conj(mat[r * ld + c]) inside
/// the leading active_sub x active_sub block.
template <typename T>
static void mirror_lower(T* mat, int ld, int active_sub)
{
    for (int c = 0; c < active_sub; c++) {
        for (int r = c + 1; r < active_sub; r++) {
            mat[c * ld + r] = std::conj(mat[r * ld + c]);
        }
    }
}

/// Drop the imaginary part of the diagonal of the leading active_sub block.
template <typename T>
static void clean_hermitian_diag(T* mat, int ld, int active_sub)
{
    using Real = typename GetTypeReal<T>::type;
    for (int i = 0; i < active_sub; i++) {
        const int idx = i * ld + i;
        mat[idx] = T(std::real(mat[idx]), static_cast<Real>(0));
    }
}

/// Make the leading active_sub block exactly Hermitian: real diagonal, and each
/// off-diagonal pair replaced by the average of one entry and the conjugate of
/// its transpose partner.
template <typename T>
static void hermitize(T* mat, int ld, int active_sub)
{
    clean_hermitian_diag(mat, ld, active_sub);
    for (int jc = 0; jc < active_sub; ++jc) {
        for (int ir = jc + 1; ir < active_sub; ++ir) {
            const T avg = static_cast<T>(0.5)
                          * (mat[jc * ld + ir]
                             + std::conj(mat[ir * ld + jc]));
            mat[jc * ld + ir] = avg;
            mat[ir * ld + jc] = std::conj(avg);
        }
    }
}

/// @return true when all n real values are finite.
template <typename Real>
static bool finite_real_block(const Real* data, int n)
{
    for (int i = 0; i < n; ++i) {
        if (!std::isfinite(data[i])) {
            return false;
        }
    }
    return true;
}

/// @return true when the real and imaginary parts of all n values are finite.
template <typename T>
static bool finite_scalar_block(const T* data, int n)
{
    for (int i = 0; i < n; ++i) {
        if (!std::isfinite(std::real(data[i])) || !std::isfinite(std::imag(data[i]))) {
            return false;
        }
    }
    return true;
}

/// One-line summary of an n x n column-major matrix for the diagnostic log:
/// number of non-finite entries, max |a_ij|, max |a_ij - conj(a_ji)| and the
/// range of the real part of the finite diagonal entries.
template <typename T>
static std::string hermitian_matrix_diagnostics(const char* name, const T* mat, int ld, int n)
{
    using Real = typename GetTypeReal<T>::type;
    int nonfinite = 0;
    Real max_abs = static_cast<Real>(0.0);
    Real max_antiherm = static_cast<Real>(0.0);
    Real min_diag = std::numeric_limits<Real>::max();
    Real max_diag = -std::numeric_limits<Real>::max();

    for (int j = 0; j < n; ++j) {
        // Non-finite diagonal entries are counted by the full sweep below
        // (at i == j), so this branch only tracks the diagonal range.
        const T diag = mat[j * ld + j];
        if (std::isfinite(std::real(diag)) && std::isfinite(std::imag(diag))) {
            min_diag = std::min(min_diag, static_cast<Real>(std::real(diag)));
            max_diag = std::max(max_diag, static_cast<Real>(std::real(diag)));
        }
        for (int i = 0; i < n; ++i) {
            const T a = mat[j * ld + i];
            if (!std::isfinite(std::real(a)) || !std::isfinite(std::imag(a))) {
                ++nonfinite;
                continue;
            }
            max_abs = std::max(max_abs, static_cast<Real>(std::abs(a)));
            const T b = mat[i * ld + j];
            if (std::isfinite(std::real(b)) && std::isfinite(std::imag(b))) {
                max_antiherm = std::max(max_antiherm, static_cast<Real>(std::abs(a - std::conj(b))));
            }
        }
    }

    std::ostringstream oss;
    oss << name
        << " nonfinite=" << nonfinite
        << " max_abs=" << std::setprecision(12) << max_abs
        << " max_antiherm=" << max_antiherm
        << " diag_min=" << (min_diag == std::numeric_limits<Real>::max() ? static_cast<Real>(0.0) : min_diag)
        << " diag_max=" << (max_diag == -std::numeric_limits<Real>::max() ? static_cast<Real>(0.0) : max_diag);
    return oss.str();
}

/// One-line summary of how far an n x n overlap matrix is from the identity:
/// non-finite count, max |S - I|, max off-diagonal magnitude and diagonal range.
template <typename T>
static std::string s_overlap_diagnostics(const T* ssub, int ld, int n)
{
    using Real = typename GetTypeReal<T>::type;
    Real max_dev = static_cast<Real>(0.0);
    Real max_offdiag = static_cast<Real>(0.0);
    Real min_diag = std::numeric_limits<Real>::max();
    Real max_diag = -std::numeric_limits<Real>::max();
    int nonfinite = 0;

    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < n; ++i) {
            const T value = ssub[j * ld + i];
            if (!std::isfinite(std::real(value)) || !std::isfinite(std::imag(value))) {
                ++nonfinite;
                continue;
            }
            const Real abs_value = static_cast<Real>(std::abs(value));
            if (i == j) {
                const Real diag = static_cast<Real>(std::real(value));
                min_diag = std::min(min_diag, diag);
                max_diag = std::max(max_diag, diag);
                max_dev = std::max(max_dev, static_cast<Real>(std::abs(diag - static_cast<Real>(1.0))));
            } else {
                max_offdiag = std::max(max_offdiag, abs_value);
                max_dev = std::max(max_dev, abs_value);
            }
        }
    }

    std::ostringstream oss;
    oss << "S_orth nonfinite=" << nonfinite
        << " max_abs(S-I)=" << std::setprecision(12) << max_dev
        << " max_offdiag=" << max_offdiag
        << " diag_min=" << (min_diag == std::numeric_limits<Real>::max() ? static_cast<Real>(0.0) : min_diag)
        << " diag_max=" << (max_diag == -std::numeric_limits<Real>::max() ? static_cast<Real>(0.0) : max_diag);
    return oss.str();
}

/// Result of check_subspace_spd(): spectrum bounds of the overlap matrix, the
/// acceptance floor that was applied, and an error text when `ok` is false.
template <typename Real>
struct SubspaceSpdCheck
{
    bool ok = false;
    Real min_eval = static_cast<Real>(0.0);
    Real max_eval = static_cast<Real>(0.0);
    Real cond = std::numeric_limits<Real>::infinity();
    Real floor = static_cast<Real>(0.0);
    std::string error;
};

/// Diagonalise a copy of the leading n x n block of `ssub` and decide whether it
/// is safely positive definite.
///
/// The matrix is rejected when the smallest eigenvalue is not above
/// `floor = max(scale / cond_limit, 100 * eps * scale)` with
/// `scale = max(1, |max_eval|)`, or when max_eval / min_eval exceeds cond_limit
/// (1e10 for double, 1e5 otherwise). `ssub` itself is not modified.
template <typename T>
static SubspaceSpdCheck<typename GetTypeReal<T>::type>
check_subspace_spd(const T* ssub, int ld, int n)
{
    using Real = typename GetTypeReal<T>::type;
    SubspaceSpdCheck<Real> result;
    if (n <= 0) {
        result.error = "empty matrix";
        return result;
    }

    // Size computed in size_t so that n * n cannot overflow int.
    std::vector<T> smat(static_cast<std::size_t>(n) * static_cast<std::size_t>(n),
                        static_cast<T>(0.0));
    std::vector<Real> eval(n, static_cast<Real>(0.0));
    for (int jc = 0; jc < n; ++jc) {
        for (int ir = 0; ir < n; ++ir) {
            smat[jc * n + ir] = ssub[jc * ld + ir];
        }
    }

    try {
        // NOTE(review): device template argument assumed to be the CPU device,
        // since the inputs are host std::vector buffers - confirm.
        ct::kernels::lapack_heevd<T, ct::DEVICE_CPU>()(n, smat.data(), n, eval.data());
    } catch (const std::exception& e) {
        result.error = e.what();
        return result;
    }

    if (!finite_real_block(eval.data(), n)) {
        result.error = "non-finite overlap eigenvalues";
        return result;
    }

    const auto extremes = std::minmax_element(eval.begin(), eval.end());
    result.min_eval = *extremes.first;
    result.max_eval = *extremes.second;

    const Real eps = std::numeric_limits<Real>::epsilon();
    const Real cond_limit = std::is_same<Real, double>::value
        ? static_cast<Real>(1.0e10)
        : static_cast<Real>(1.0e5);
    const Real scale = std::max(static_cast<Real>(1.0), std::abs(result.max_eval));
    result.floor = std::max(scale / cond_limit,
                            static_cast<Real>(100.0) * eps * scale);

    if (!std::isfinite(result.min_eval) || !std::isfinite(result.max_eval)
        || result.max_eval <= static_cast<Real>(0.0)
        || result.min_eval <= result.floor) {
        result.error = "ill-conditioned overlap";
        return result;
    }

    result.cond = result.max_eval / result.min_eval;
    if (!std::isfinite(result.cond) || result.cond > cond_limit) {
        result.error = "overlap condition exceeds limit";
        return result;
    }

    result.ok = true;
    return result;
}

/// Format a SubspaceSpdCheck for the diagnostic log.
template <typename Real>
static std::string subspace_spd_diagnostics(const SubspaceSpdCheck<Real>& check)
{
    std::ostringstream oss;
    oss << "S_sub spectrum min=" << std::setprecision(12) << check.min_eval
        << " max=" << check.max_eval
        << " cond=" << check.cond
        << " floor=" << check.floor;
    if (!check.error.empty()) {
        oss << " error=" << check.error;
    }
    return oss.str();
}

/// Symmetric diagonal scaling of the pencil (H, S) so that diag(S) becomes 1:
/// X(i,j) *= inv_norm[i] * inv_norm[j] with inv_norm[i] = 1 / sqrt(Re S(i,i)).
///
/// @param[out] inv_norm  the n scaling factors (needed to undo the scaling on
///                       the eigenvectors afterwards)
/// @param[out] error     description of the offending diagonal entry on failure
/// @return false (H and S untouched) when a diagonal entry of S is non-finite
///         or not above the smallest positive normal value of Real.
template <typename T>
static bool scale_subspace_by_overlap_diag(T* hsub,
                                           T* ssub,
                                           int ld,
                                           int n,
                                           std::vector<typename GetTypeReal<T>::type>& inv_norm,
                                           std::string& error)
{
    using Real = typename GetTypeReal<T>::type;
    inv_norm.assign(n, static_cast<Real>(0.0));
    const Real diag_floor = std::numeric_limits<Real>::min();

    for (int i = 0; i < n; ++i) {
        const T diag = ssub[i * ld + i];
        const Real diag_real = static_cast<Real>(std::real(diag));
        if (!std::isfinite(diag_real) || diag_real <= diag_floor) {
            std::ostringstream oss;
            oss << "invalid overlap diagonal at column " << i
                << ": " << std::setprecision(12) << diag_real;
            error = oss.str();
            return false;
        }
        inv_norm[i] = static_cast<Real>(1.0) / std::sqrt(diag_real);
    }

    for (int jc = 0; jc < n; ++jc) {
        for (int ir = 0; ir < n; ++ir) {
            const Real scale = inv_norm[ir] * inv_norm[jc];
            hsub[jc * ld + ir] *= scale;
            ssub[jc * ld + ir] *= scale;
        }
    }
    hermitize(hsub, ld, n);
    hermitize(ssub, ld, n);
    return true;
}

/// @return true when the first `nvalid` entries of each of the `nvec` rows
///         (row stride `lda`) are finite.
template <typename T>
static bool finite_vector_block(const T* data, int nvec, int lda, int nvalid)
{
    for (int ib = 0; ib < nvec; ++ib) {
        for (int ig = 0; ig < nvalid; ++ig) {
            const T value = data[ib * lda + ig];
            if (!std::isfinite(std::real(value)) || !std::isfinite(std::imag(value))) {
                return false;
            }
        }
    }
    return true;
}

/// One-line summary of a band-major vector block for the diagnostic log:
/// non-finite entry/band counts, position of the first non-finite entry, the
/// 2-norm range over the fully finite bands and the largest finite magnitude.
/// Norms are computed from the local data only.
template <typename T>
static std::string vector_block_diagnostics(const char* name,
                                            const T* data,
                                            int nvec,
                                            int lda,
                                            int nvalid)
{
    using Real = typename GetTypeReal<T>::type;
    int nonfinite = 0;
    int nonfinite_bands = 0;
    int first_band = -1;
    int first_ig = -1;
    Real min_norm = std::numeric_limits<Real>::max();
    Real max_norm = static_cast<Real>(0.0);
    Real max_abs = static_cast<Real>(0.0);

    for (int ib = 0; ib < nvec; ++ib) {
        bool band_bad = false;
        Real norm2 = static_cast<Real>(0.0);
        for (int ig = 0; ig < nvalid; ++ig) {
            const T value = data[ib * lda + ig];
            if (!std::isfinite(std::real(value)) || !std::isfinite(std::imag(value))) {
                ++nonfinite;
                band_bad = true;
                if (first_band < 0) {
                    first_band = ib;
                    first_ig = ig;
                }
                continue;
            }
            const Real abs_value = static_cast<Real>(std::abs(value));
            max_abs = std::max(max_abs, abs_value);
            norm2 += std::norm(value);
        }
        if (band_bad) {
            ++nonfinite_bands;
        } else {
            const Real norm = std::sqrt(norm2);
            min_norm = std::min(min_norm, norm);
            max_norm = std::max(max_norm, norm);
        }
    }

    std::ostringstream oss;
    oss << name
        << " nonfinite=" << nonfinite
        << " nonfinite_bands=" << nonfinite_bands
        << " first=(" << first_band << "," << first_ig << ")"
        << " finite_norm_min="
        << (min_norm == std::numeric_limits<Real>::max() ? static_cast<Real>(0.0) : min_norm)
        << " finite_norm_max=" << max_norm
        << " max_abs=" << max_abs;
    return oss.str();
}

/// Write a " LOBPCG_DIAG <context>" record followed by two or three detail
/// lines to GlobalV::ofs_running and flush it. Nothing is written when the
/// stream is not in a good state.
static void lobpcg_diag_log(const std::string& context,
                            const std::string& line1,
                            const std::string& line2,
                            const std::string& line3 = std::string())
{
    std::ostringstream oss;
    oss << " LOBPCG_DIAG " << context << '\n'
        << "   " << line1 << '\n'
        << "   " << line2 << '\n';
    if (!line3.empty()) {
        oss << "   " << line3 << '\n';
    }
    if (GlobalV::ofs_running.good()) {
        GlobalV::ofs_running << oss.str();
        GlobalV::ofs_running.flush();
    }
}

/// Member front-end of lobpcg_diag_log(): appends " [<diag_context>]" to the
/// context when this->diag_context is not empty.
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::diag_log(const std::string& context,
                                      const std::string& line1,
                                      const std::string& line2,
                                      const std::string& line3) const
{
    const std::string full_context = this->diag_context.empty()
        ? context
        : context + " [" + this->diag_context + "]";
    lobpcg_diag_log(full_context, line1, line2, line3);
}

/// Append the rows of `block` to `basis` after two passes of Gram-Schmidt
/// against the vectors already in `basis` (standard inner product), applying
/// the same linear combination to `hblock` -> `hbasis`.
///
/// A row whose remaining squared norm is non-finite or <= thresh^2 is dropped.
/// Appended rows are normalised and occupy `lda` entries each.
///
/// @return true when at least one row was appended.
template <typename T>
static bool append_orthonormal_block(
    const int nvec, const int lda, const int nvalid,
    const typename GetTypeReal<T>::type thresh,
    const T* block, const T* hblock,
    std::vector<T>& basis, std::vector<T>& hbasis)
{
    using Real = typename GetTypeReal<T>::type;
    if (lda <= 0) {
        // Without a positive stride the basis size cannot be converted to a
        // vector count (it would divide by zero below).
        return false;
    }
    bool appended = false;
    const Real thresh2 = thresh * thresh;

    // Scratch rows are allocated once and reset for every candidate.
    std::vector<T> q(lda);
    std::vector<T> hq(lda);

    for (int ib = 0; ib < nvec; ++ib) {
        std::fill(q.begin(), q.end(), static_cast<T>(0.0));
        std::fill(hq.begin(), hq.end(), static_cast<T>(0.0));
        for (int ig = 0; ig < nvalid; ++ig) {
            q[ig] = block[ib * lda + ig];
            hq[ig] = hblock[ib * lda + ig];
        }

        const int nold = static_cast<int>(basis.size() / lda);
        for (int pass = 0; pass < 2; ++pass) {
            for (int jq = 0; jq < nold; ++jq) {
                T dot = static_cast<T>(0.0);
                for (int ig = 0; ig < nvalid; ++ig) {
                    dot += std::conj(basis[jq * lda + ig]) * q[ig];
                }
#ifdef __MPI
                Parallel_Reduce::reduce_pool(&dot, 1);
#endif
                for (int ig = 0; ig < nvalid; ++ig) {
                    q[ig] -= dot * basis[jq * lda + ig];
                    hq[ig] -= dot * hbasis[jq * lda + ig];
                }
            }
        }

        Real norm2 = static_cast<Real>(0.0);
        for (int ig = 0; ig < nvalid; ++ig) {
            norm2 += std::norm(q[ig]);
        }
#ifdef __MPI
        Parallel_Reduce::reduce_pool(&norm2, 1);
#endif
        if (!std::isfinite(norm2) || norm2 <= thresh2) {
            continue;
        }

        const Real inv_norm = static_cast<Real>(1.0) / std::sqrt(norm2);
        for (int ig = 0; ig < lda; ++ig) {
            basis.push_back(q[ig] * inv_norm);
            hbasis.push_back(hq[ig] * inv_norm);
        }
        appended = true;
    }

    return appended;
}

/// Append the rows of `block` to `basis` after normalising each one (no
/// orthogonalisation); `hblock` rows get the same scale factor.
///
/// A row whose squared norm is non-finite or <= thresh^2 is dropped. Entries
/// ig in [nvalid, lda) of every appended row are zero.
///
/// @return true when at least one row was appended.
template <typename T>
static bool append_normalized_block(
    const int nvec, const int lda, const int nvalid,
    const typename GetTypeReal<T>::type thresh,
    const T* block, const T* hblock,
    std::vector<T>& basis, std::vector<T>& hbasis)
{
    using Real = typename GetTypeReal<T>::type;
    bool appended = false;
    const Real thresh2 = thresh * thresh;

    for (int ib = 0; ib < nvec; ++ib) {
        const T* src = block + ib * lda;
        const T* hsrc = hblock + ib * lda;
        Real norm2 = static_cast<Real>(0.0);
        for (int ig = 0; ig < nvalid; ++ig) {
            norm2 += std::norm(src[ig]);
        }
#ifdef __MPI
        Parallel_Reduce::reduce_pool(&norm2, 1);
#endif
        if (!std::isfinite(norm2) || norm2 <= thresh2) {
            continue;
        }

        const Real inv_norm = static_cast<Real>(1.0) / std::sqrt(norm2);
        for (int ig = 0; ig < nvalid; ++ig) {
            basis.push_back(src[ig] * inv_norm);
            hbasis.push_back(hsrc[ig] * inv_norm);
        }
        for (int ig = nvalid; ig < lda; ++ig) {
            basis.push_back(static_cast<T>(0.0));
            hbasis.push_back(static_cast<T>(0.0));
        }
        appended = true;
    }

    return appended;
}

/// S-inner-product variant of append_orthonormal_block(): projections use
/// <basis_j | S q> and the norm is Re <q | S q>; `sblock` -> `sbasis` is carried
/// along with the same linear combination.
///
/// NOTE(review): the rejection test compares the squared S-norm with `thresh`,
/// whereas the two sibling helpers compare with thresh^2 - confirm that this
/// difference is intended.
///
/// @return true when at least one row was appended.
template <typename T>
static bool append_s_orthonormal_block(
    const int nvec, const int lda, const int nvalid,
    const typename GetTypeReal<T>::type thresh,
    const T* block, const T* hblock, const T* sblock,
    std::vector<T>& basis, std::vector<T>& hbasis, std::vector<T>& sbasis)
{
    using Real = typename GetTypeReal<T>::type;
    if (lda <= 0) {
        // Same guard as append_orthonormal_block(): avoids dividing by zero.
        return false;
    }
    bool appended = false;

    // Scratch rows are allocated once and reset for every candidate.
    std::vector<T> q(lda);
    std::vector<T> hq(lda);
    std::vector<T> sq(lda);

    for (int ib = 0; ib < nvec; ++ib) {
        std::fill(q.begin(), q.end(), static_cast<T>(0.0));
        std::fill(hq.begin(), hq.end(), static_cast<T>(0.0));
        std::fill(sq.begin(), sq.end(), static_cast<T>(0.0));
        for (int ig = 0; ig < nvalid; ++ig) {
            q[ig] = block[ib * lda + ig];
            hq[ig] = hblock[ib * lda + ig];
            sq[ig] = sblock[ib * lda + ig];
        }

        const int nold = static_cast<int>(basis.size() / lda);
        for (int pass = 0; pass < 2; ++pass) {
            for (int jq = 0; jq < nold; ++jq) {
                T dot = static_cast<T>(0.0);
                for (int ig = 0; ig < nvalid; ++ig) {
                    dot += std::conj(basis[jq * lda + ig]) * sq[ig];
                }
#ifdef __MPI
                Parallel_Reduce::reduce_pool(&dot, 1);
#endif
                for (int ig = 0; ig < nvalid; ++ig) {
                    q[ig] -= dot * basis[jq * lda + ig];
                    hq[ig] -= dot * hbasis[jq * lda + ig];
                    sq[ig] -= dot * sbasis[jq * lda + ig];
                }
            }
        }

        Real norm2 = static_cast<Real>(0.0);
        for (int ig = 0; ig < nvalid; ++ig) {
            norm2 += std::real(std::conj(q[ig]) * sq[ig]);
        }
#ifdef __MPI
        Parallel_Reduce::reduce_pool(&norm2, 1);
#endif
        if (!std::isfinite(norm2) || norm2 <= thresh) {
            continue;
        }

        const Real inv_norm = static_cast<Real>(1.0) / std::sqrt(norm2);
        for (int ig = 0; ig < lda; ++ig) {
            basis.push_back(q[ig] * inv_norm);
            hbasis.push_back(hq[ig] * inv_norm);
            sbasis.push_back(sq[ig] * inv_norm);
        }
        appended = true;
    }

    return appended;
}

// ============================================================================
// Constructor / Destructor
// ============================================================================

/// Record the element/device type tags and keep the caller's preconditioner
/// pointer (not copied here; calc_prec() transfers it later).
template <typename T, typename Device>
DiagoLobpcg<T, Device>::DiagoLobpcg(const Real* precondition)
{
    this->r_type = ct::DataTypeToEnum<Real>::value;
    this->t_type = ct::DataTypeToEnum<T>::value;
    this->dev_type = ct::DeviceTypeToEnum<ct_Device>::value;
    this->h_prec_ptr = precondition;
    this->one = &one_;
    this->zero = &zero_;
    this->neg_one = &neg_one_;
}

template <typename T, typename Device>
DiagoLobpcg<T, Device>::~DiagoLobpcg() {}

// ============================================================================
// init_iter
// ============================================================================

/// Store the problem dimensions and (re)allocate every work tensor.
///
/// @param nband    total number of bands; the subspace size is 3 * nband
/// @param nband_l  number of bands held locally
/// @param nbasis   row stride of the band-major wavefunction blocks
/// @param ndim     number of leading entries of each row that are used
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::init_iter(int nband, int nband_l,
                                       int nbasis, int ndim)
{
    this->n_band = nband;
    this->n_band_l = nband_l;
    this->n_basis = nbasis;
    this->n_dim = ndim;
    this->nsub = 3 * n_band;
    this->has_pdir = false;

    // Eigenvalue and error storage.
    this->eigen = ct::Tensor(r_type, dev_type, {this->n_band});
    this->sub_eigen = ct::Tensor(r_type, dev_type, {this->nsub});
    this->err_st = ct::Tensor(r_type, dev_type, {this->n_band_l});

    // Subspace matrices.
    this->hsub = ct::Tensor(t_type, dev_type, {this->nsub, this->nsub});
    this->ssub = ct::Tensor(t_type, dev_type, {this->nsub, this->nsub});
    this->tmp_hsub = ct::Tensor(t_type, dev_type, {this->n_band, this->n_band});
    this->tmp_ssub = ct::Tensor(t_type, dev_type, {this->n_band, this->n_band});

    // Band-major vector blocks, each [n_band_l, n_basis].
    this->hpsi = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->spsi = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->grad = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->hgrad = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->sgrad = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->pdir = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->hpdir = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->spdir = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});

    this->work = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->hwork = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->swork = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->pwork = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->hpwork = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});
    this->spwork = ct::Tensor(t_type, dev_type, {this->n_band_l, this->n_basis});

    // Preconditioner: an owned tensor plus a CPU view over the caller's array.
    this->prec = ct::Tensor(r_type, dev_type, {this->n_basis});
    this->h_prec = ct::TensorMap(
        (void*)this->h_prec_ptr, this->r_type,
        ct::DeviceType::CpuDevice, {this->n_basis});

#ifdef __MPI
    this->pmmcn.set_dimension(BP_WORLD, POOL_WORLD,
                              n_band_l, n_basis,
                              n_band_l, n_basis,
                              n_dim, n_band);
    this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis,
                                  BP_WORLD, false);
#else
    this->pmmcn.set_dimension(n_band_l, n_basis,
                              n_band_l, n_basis,
                              n_dim, n_band);
    this->plintrans.set_dimension(n_dim, nband_l, n_band_l, n_basis,
                                  false);
#endif
}

/// @return plintrans.start_colB[MY_BNDGROUP] when built with MPI and
///         plintrans.nproc_col > 1, otherwise 0.
template <typename T, typename Device>
int DiagoLobpcg<T, Device>::local_band_start() const
{
#ifdef __MPI
    if (this->plintrans.nproc_col > 1) {
        return this->plintrans.start_colB[GlobalV::MY_BNDGROUP];
    }
#endif
    return 0;
}

// ============================================================================
// calc_prec
// ============================================================================

/// Copy the n_basis preconditioner values from the host view (h_prec) into the
/// owned tensor (prec).
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::calc_prec()
{
    syncmem_var_h2d_op()(this->prec.template data<Real>(),
                         this->h_prec.template data<Real>(),
                         this->n_basis);
}

// ============================================================================
// calc_hpsi_with_block / calc_spsi_with_block
// ============================================================================

/// Apply hpsi_func to all n_band_l local bands and verify the result.
///
/// @throws std::runtime_error (after logging diagnostics) when the input or the
///         output contains a non-finite value within the first n_dim entries of
///         any band.
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::calc_hpsi_with_block(
    const HPsiFunc& hpsi_func, T* psi_in, ct::Tensor& hpsi_out)
{
    hpsi_func(psi_in, hpsi_out.data<T>(), this->n_basis, this->n_band_l);
    if (!finite_vector_block(psi_in, this->n_band_l, this->n_basis, this->n_dim)
        || !finite_vector_block(hpsi_out.data<T>(), this->n_band_l, this->n_basis, this->n_dim)) {
        this->diag_log("calc_hpsi_with_block non-finite",
                       vector_block_diagnostics("psi_in",
                                                psi_in,
                                                this->n_band_l,
                                                this->n_basis,
                                                this->n_dim),
                       vector_block_diagnostics("hpsi_out",
                                                hpsi_out.data<T>(),
                                                this->n_band_l,
                                                this->n_basis,
                                                this->n_dim));
        throw std::runtime_error("LOBPCG hPsi produced non-finite values");
    }
}

/// Apply spsi_func to all n_band_l local bands and verify the result.
///
/// @throws std::runtime_error (after logging diagnostics) when the input or the
///         output contains a non-finite value within the first n_dim entries of
///         any band.
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::calc_spsi_with_block(
    const SPsiFunc& spsi_func, const T* psi_in, ct::Tensor& spsi_out)
{
    spsi_func(psi_in, spsi_out.data<T>(), this->n_basis, this->n_band_l);
    if (!finite_vector_block(psi_in, this->n_band_l, this->n_basis, this->n_dim)
        || !finite_vector_block(spsi_out.data<T>(), this->n_band_l, this->n_basis, this->n_dim)) {
        this->diag_log("calc_spsi_with_block non-finite",
                       vector_block_diagnostics("psi_in",
                                                psi_in,
                                                this->n_band_l,
                                                this->n_basis,
                                                this->n_dim),
                       vector_block_diagnostics("spsi_out",
                                                spsi_out.data<T>(),
                                                this->n_band_l,
                                                this->n_basis,
                                                this->n_dim));
        throw std::runtime_error("LOBPCG sPsi produced non-finite values");
    }
}

/// S-orthonormalise this->psi band by band (two Gram-Schmidt passes against the
/// bands already accepted), replacing any band that is non-finite or has an
/// S-norm^2 <= 100 * eps by a unit vector (trying successive positions) until
/// one survives. this->hpsi is recomputed at the end; this->spsi is recomputed
/// only when at least one band had to be replaced or retried.
///
/// The loop runs over n_band rows of psi/spsi.
/// NOTE(review): this presumably requires n_band_l == n_band (serial band
/// layout) - confirm at the call site.
/// NOTE(review): with MPI, the `continue` decisions taken before the
/// reduce_pool() calls are based on rank-local data (finite_vec), so ranks
/// could disagree on which collective comes next - confirm that the inputs
/// make this impossible.
///
/// @throws std::runtime_error when a band cannot be repaired or the final
///         S-norm check fails.
template <typename T, typename Device>
void DiagoLobpcg<T, Device>::repair_initial_subspace_s(
    const HPsiFunc& hpsi_func, const SPsiFunc& spsi_func)
{
    const int nb = this->n_band;
    const int lda = this->n_basis;
    const int nvalid = this->n_dim;
    const Real eps = static_cast<Real>(100)
                     * std::numeric_limits<Real>::epsilon();

    T* psi_d = this->psi.data<T>();
    T* spsi_d = this->spsi.data<T>();
    bool repaired = false;

    for (int ib = 0; ib < nb; ++ib) {
        bool ready = false;
        // attempt == -1 tries the band as supplied; attempt >= 0 replaces it by
        // the unit vector at position (ib + attempt) % nvalid.
        for (int attempt = -1; attempt < nvalid && !ready; ++attempt) {
            if (attempt >= 0) {
                repaired = true;
                for (int ig = 0; ig < lda; ++ig) {
                    psi_d[ib * lda + ig] = static_cast<T>(0.0);
                    spsi_d[ib * lda + ig] = static_cast<T>(0.0);
                }
                psi_d[ib * lda + ((ib + attempt) % nvalid)] = static_cast<T>(1.0);
                spsi_func(psi_d + ib * lda, spsi_d + ib * lda, lda, 1);
            }

            bool finite_vec = true;
            for (int ig = 0; ig < nvalid; ++ig) {
                const T qv = psi_d[ib * lda + ig];
                const T sqv = spsi_d[ib * lda + ig];
                if (!std::isfinite(std::real(qv)) || !std::isfinite(std::imag(qv))
                    || !std::isfinite(std::real(sqv)) || !std::isfinite(std::imag(sqv))) {
                    finite_vec = false;
                    break;
                }
            }
            if (!finite_vec) {
                repaired = true;
                continue;
            }

            bool finite_projection = true;
            for (int pass = 0; pass < 2; ++pass) {
                for (int jb = 0; jb < ib; ++jb) {
                    T dot = static_cast<T>(0.0);
                    for (int ig = 0; ig < nvalid; ++ig) {
                        dot += std::conj(psi_d[jb * lda + ig])
                               * spsi_d[ib * lda + ig];
                    }
#ifdef __MPI
                    Parallel_Reduce::reduce_pool(&dot, 1);
#endif
                    if (!std::isfinite(std::real(dot)) || !std::isfinite(std::imag(dot))) {
                        finite_projection = false;
                        break;
                    }
                    for (int ig = 0; ig < nvalid; ++ig) {
                        psi_d[ib * lda + ig] -= dot * psi_d[jb * lda + ig];
                        spsi_d[ib * lda + ig] -= dot * spsi_d[jb * lda + ig];
                    }
                }
                if (!finite_projection) {
                    break;
                }
            }
            if (!finite_projection) {
                repaired = true;
                continue;
            }

            Real norm2 = static_cast<Real>(0.0);
            for (int ig = 0; ig < nvalid; ++ig) {
                norm2 += std::real(std::conj(psi_d[ib * lda + ig])
                                   * spsi_d[ib * lda + ig]);
            }
#ifdef __MPI
            Parallel_Reduce::reduce_pool(&norm2, 1);
#endif
            if (!std::isfinite(norm2) || norm2 <= eps) {
                repaired = true;
                continue;
            }

            const Real inv_norm = static_cast<Real>(1.0) / std::sqrt(norm2);
            for (int ig = 0; ig < nvalid; ++ig) {
                psi_d[ib * lda + ig] *= inv_norm;
                spsi_d[ib * lda + ig] *= inv_norm;
            }
            for (int ig = nvalid; ig < lda; ++ig) {
                psi_d[ib * lda + ig] = static_cast<T>(0.0);
                spsi_d[ib * lda + ig] = static_cast<T>(0.0);
            }
            ready = true;
        }

        if (!ready) {
            throw std::runtime_error("LOBPCG failed to repair rank-deficient initial subspace");
        }
    }

    // H*psi always has to be refreshed; S*psi was updated in lock-step with psi
    // above and is only rebuilt from scratch when a band was replaced/retried.
    this->calc_hpsi_with_block(hpsi_func, this->psi.data<T>(), this->hpsi);
    if (repaired) {
        this->calc_spsi_with_block(spsi_func, this->psi.data<T>(), this->spsi);
    }

    for (int ib = 0; ib < nb; ++ib) {
        Real norm2 = static_cast<Real>(0.0);
        for (int ig = 0; ig < nvalid; ++ig) {
            norm2 += std::real(std::conj(psi_d[ib * lda + ig])
                               * spsi_d[ib * lda + ig]);
        }
#ifdef __MPI
        Parallel_Reduce::reduce_pool(&norm2, 1);
#endif
        if (!std::isfinite(norm2) || !(norm2 > eps)) {
            throw std::runtime_error("LOBPCG repaired initial subspace has invalid S-norm at band "
                                     + std::to_string(ib)
                                     + ", norm2=" + std::to_string(norm2));
        }
    }
}

// ============================================================================
// rayleigh_ritz  (NC, S=I)
//
// psi_in is assumed orthonormal.  H_sub = <psi|H|psi>, heevd, rotate.
+// ============================================================================ + +template +void DiagoLobpcg::rayleigh_ritz( + ct::Tensor& psi_inout, ct::Tensor& hpsi_inout, ct::Tensor& eigen_out) +{ + const int nb = this->n_band; + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + const int local_sz = this->n_band_l * nbs; + + this->pmmcn.multiply(1.0, psi_inout.data(), hpsi_inout.data(), + 0.0, this->tmp_hsub.data()); + mirror_lower(this->tmp_hsub.data(), nb, nb); + clean_hermitian_diag(this->tmp_hsub.data(), nb, nb); + + // Force exact Hermitian symmetrization. + { + T* hsub = this->tmp_hsub.data(); + for (int jj = 0; jj < nb; ++jj) { + hsub[jj * nb + jj] = T(std::real(hsub[jj * nb + jj]), 0.0); + for (int ii = jj + 1; ii < nb; ++ii) { + T a = hsub[jj * nb + ii]; + T b = std::conj(hsub[ii * nb + jj]); + T avg = static_cast(0.5) * (a + b); + hsub[jj * nb + ii] = avg; + hsub[ii * nb + jj] = std::conj(avg); + } + } + } + + try { + ct::kernels::lapack_heevd()( + nb, this->tmp_hsub.data(), nb, eigen_out.data()); + } catch (const std::exception& e) { + this->diag_log("rayleigh_ritz heevd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", this->tmp_hsub.data(), nb, nb), + "S_sub unavailable for standard Rayleigh-Ritz"); + throw; + } + + this->rotate_wf(this->tmp_hsub, psi_inout, this->work); + syncmem_complex_op()(psi_inout.data(), this->work.data(), local_sz); + + this->rotate_wf(this->tmp_hsub, hpsi_inout, this->work); + syncmem_complex_op()(hpsi_inout.data(), this->work.data(), local_sz); +} + +template +void DiagoLobpcg::generalized_rayleigh_ritz( + ct::Tensor& psi_inout, ct::Tensor& hpsi_inout, + ct::Tensor& spsi_inout, ct::Tensor& eigen_out) +{ + const int nb = this->n_band; + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + const int local_sz = this->n_band_l * nbs; + + for (int ii = 0; ii < nb * nb; ++ii) { + this->tmp_hsub.data()[ii] = static_cast(0.0); + this->tmp_ssub.data()[ii] = static_cast(0.0); 
+ } + + inner_product_loop(nb, nbs, nvalid, this->one_, + psi_inout.data(), hpsi_inout.data(), + this->zero_, this->tmp_hsub.data()); + inner_product_loop(nb, nbs, nvalid, this->one_, + psi_inout.data(), spsi_inout.data(), + this->zero_, this->tmp_ssub.data()); +#ifdef __MPI + Parallel_Reduce::reduce_pool(this->tmp_hsub.data(), nb * nb); + Parallel_Reduce::reduce_pool(this->tmp_ssub.data(), nb * nb); +#endif + hermitize(this->tmp_hsub.data(), nb, nb); + hermitize(this->tmp_ssub.data(), nb, nb); + + bool rr_ok = false; + std::string rr_error; + try { + ct::kernels::lapack_hegvd()( + nb, nb, + this->tmp_hsub.data(), + this->tmp_ssub.data(), + eigen_out.data(), + this->hsub.data()); + rr_ok = finite_real_block(eigen_out.data(), nb) + && finite_scalar_block(this->hsub.data(), nb * nb); + if (!rr_ok) { + this->diag_log("generalized_rayleigh_ritz hegvd returned non-finite eigen data", + hermitian_matrix_diagnostics("H_sub", this->tmp_hsub.data(), nb, nb), + hermitian_matrix_diagnostics("S_sub", this->tmp_ssub.data(), nb, nb), + vector_block_diagnostics("eigvec", + this->hsub.data(), + nb, + nb, + nb)); + } + } catch (const std::exception& e) { + rr_error = e.what(); + rr_ok = false; + } + + if (!rr_ok) { + this->diag_log("generalized_rayleigh_ritz hegvd failed" + + (rr_error.empty() ? 
std::string() : (": " + rr_error)), + hermitian_matrix_diagnostics("H_sub", this->tmp_hsub.data(), nb, nb), + hermitian_matrix_diagnostics("S_sub", this->tmp_ssub.data(), nb, nb), + s_overlap_diagnostics(this->tmp_ssub.data(), nb, nb)); + throw std::runtime_error("LOBPCG generalized Rayleigh-Ritz failed in hegvd"); + } + + if (std::isfinite(std::real(this->hsub.data()[0]))) { + bool large_eigvec = false; + const T* eigvec = this->hsub.data(); + for (int ii = 0; ii < nb * nb; ++ii) { + if (std::isfinite(std::real(eigvec[ii])) && std::isfinite(std::imag(eigvec[ii])) + && std::abs(eigvec[ii]) > static_cast(1.0e100)) { + large_eigvec = true; + break; + } + } + if (large_eigvec) { + this->diag_log("generalized_rayleigh_ritz huge eigenvectors before rotation", + hermitian_matrix_diagnostics("H_sub", this->tmp_hsub.data(), nb, nb), + hermitian_matrix_diagnostics("S_sub", this->tmp_ssub.data(), nb, nb), + vector_block_diagnostics("eigvec", this->hsub.data(), nb, nb, nb)); + } + } + + rotate_loop(nb, nbs, nvalid, nb, this->one_, + this->hsub.data(), psi_inout.data(), + this->zero_, this->work.data()); + syncmem_complex_op()(psi_inout.data(), this->work.data(), local_sz); + + rotate_loop(nb, nbs, nvalid, nb, this->one_, + this->hsub.data(), hpsi_inout.data(), + this->zero_, this->work.data()); + syncmem_complex_op()(hpsi_inout.data(), this->work.data(), local_sz); + + rotate_loop(nb, nbs, nvalid, nb, this->one_, + this->hsub.data(), spsi_inout.data(), + this->zero_, this->work.data()); + syncmem_complex_op()(spsi_inout.data(), this->work.data(), local_sz); + + if (!finite_vector_block(psi_inout.data(), nb, nbs, nvalid) + || !finite_vector_block(hpsi_inout.data(), nb, nbs, nvalid) + || !finite_vector_block(spsi_inout.data(), nb, nbs, nvalid)) { + this->diag_log("generalized_rayleigh_ritz rotation produced non-finite vectors", + vector_block_diagnostics("psi", psi_inout.data(), nb, nbs, nvalid), + vector_block_diagnostics("hpsi", hpsi_inout.data(), nb, nbs, nvalid), + 
vector_block_diagnostics("spsi", spsi_inout.data(), nb, nbs, nvalid)); + throw std::runtime_error("LOBPCG generalized Rayleigh-Ritz rotation produced non-finite vectors"); + } +} + +template +void DiagoLobpcg::generalized_rayleigh_ritz_parallel( + ct::Tensor& psi_inout, ct::Tensor& hpsi_inout, + ct::Tensor& spsi_inout, ct::Tensor& eigen_out) +{ + const int nb = this->n_band; + const int nbs = this->n_basis; + const int local_sz = this->n_band_l * nbs; + + T* hsub_d = this->hsub.data(); + T* ssub_d = this->ssub.data(); + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + setmem_complex_op()(ssub_d, static_cast(0.0), this->nsub * this->nsub); + + this->pmmcn.multiply(this->one_, + psi_inout.data(), + hpsi_inout.data(), + this->zero_, + this->tmp_hsub.data()); + this->pmmcn.multiply(this->one_, + psi_inout.data(), + spsi_inout.data(), + this->zero_, + this->tmp_ssub.data()); + for (int jc = 0; jc < nb; ++jc) { + std::copy(this->tmp_hsub.data() + jc * nb, + this->tmp_hsub.data() + jc * nb + nb, + hsub_d + jc * this->nsub); + std::copy(this->tmp_ssub.data() + jc * nb, + this->tmp_ssub.data() + jc * nb + nb, + ssub_d + jc * this->nsub); + } + hermitize(hsub_d, this->nsub, nb); + hermitize(ssub_d, this->nsub, nb); + + std::vector inv_subspace_norm; + std::string scale_error; + if (!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, nb, + inv_subspace_norm, scale_error)) { + this->diag_log("generalized_rayleigh_ritz_parallel failed to normalize S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, nb), + s_overlap_diagnostics(ssub_d, this->nsub, nb), + scale_error); + throw std::runtime_error("LOBPCG generalized parallel initial overlap has invalid diagonal"); + } + + const auto spd_check = check_subspace_spd(ssub_d, this->nsub, nb); + if (!spd_check.ok) { + this->diag_log("generalized_rayleigh_ritz_parallel rejected ill-conditioned S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, 
nb), + s_overlap_diagnostics(ssub_d, this->nsub, nb), + subspace_spd_diagnostics(spd_check)); + throw std::runtime_error("LOBPCG generalized parallel initial overlap is ill-conditioned before hegvd"); + } + + try { + ct::kernels::lapack_hegvd()( + nb, this->nsub, hsub_d, ssub_d, eigen_out.data(), hsub_d); + } catch (const std::exception& e) { + this->diag_log("generalized_rayleigh_ritz_parallel hegvd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, nb), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, nb), + subspace_spd_diagnostics(spd_check)); + throw; + } + + if (!finite_real_block(eigen_out.data(), nb) + || !finite_scalar_block(hsub_d, nb * this->nsub)) { + this->diag_log("generalized_rayleigh_ritz_parallel hegvd produced non-finite values", + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, nb), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, nb), + subspace_spd_diagnostics(spd_check)); + throw std::runtime_error("LOBPCG generalized parallel initial diagonalization produced non-finite values"); + } + + for (int jc = 0; jc < nb; ++jc) { + for (int ir = 0; ir < nb; ++ir) { + hsub_d[jc * this->nsub + ir] *= inv_subspace_norm[ir]; + } + } + + setmem_complex_op()(this->tmp_hsub.data(), static_cast(0.0), nb * nb); + for (int jc = 0; jc < nb; ++jc) { + std::copy(hsub_d + jc * this->nsub, + hsub_d + jc * this->nsub + nb, + this->tmp_hsub.data() + jc * nb); + } + + this->plintrans.act(this->one_, + psi_inout.data(), + this->tmp_hsub.data(), + this->zero_, + this->work.data()); + syncmem_complex_op()(psi_inout.data(), this->work.data(), local_sz); + + this->plintrans.act(this->one_, + hpsi_inout.data(), + this->tmp_hsub.data(), + this->zero_, + this->work.data()); + syncmem_complex_op()(hpsi_inout.data(), this->work.data(), local_sz); + + this->plintrans.act(this->one_, + spsi_inout.data(), + this->tmp_hsub.data(), + this->zero_, + this->work.data()); + syncmem_complex_op()(spsi_inout.data(), 
this->work.data(), local_sz); + + if (!finite_vector_block(psi_inout.data(), this->n_band_l, nbs, this->n_dim) + || !finite_vector_block(hpsi_inout.data(), this->n_band_l, nbs, this->n_dim) + || !finite_vector_block(spsi_inout.data(), this->n_band_l, nbs, this->n_dim)) { + this->diag_log("generalized_rayleigh_ritz_parallel rotation produced non-finite vectors", + vector_block_diagnostics("psi", psi_inout.data(), this->n_band_l, nbs, this->n_dim), + vector_block_diagnostics("hpsi", hpsi_inout.data(), this->n_band_l, nbs, this->n_dim), + vector_block_diagnostics("spsi", spsi_inout.data(), this->n_band_l, nbs, this->n_dim)); + throw std::runtime_error("LOBPCG generalized parallel initial rotation produced non-finite vectors"); + } +} + +// ============================================================================ +// compute_residual — NC: R = HX - lambda*X, grad = R ./ prec +// ============================================================================ + +template +void DiagoLobpcg::compute_residual( + const ct::Tensor& psi_in, const ct::Tensor& hpsi_in, + const ct::Tensor& eigen_in, const ct::Tensor& prec_in, + ct::Tensor& grad_out, ct::Tensor& err_out) +{ + const Real* _prec = prec_in.data(); + const Real* _eigen = eigen_in.data(); + const T* _psi = psi_in.data(); + const T* _hpsi = hpsi_in.data(); + T* _grad = grad_out.data(); + Real* _err = err_out.data(); + const int band_start = this->local_band_start(); + + for (int ib = 0; ib < this->n_band_l; ib++) { + const int ioff = ib * this->n_basis; + const Real lambda = _eigen[band_start + ib]; + Real err_j = 0.0; + for (int ig = 0; ig < this->n_dim; ig++) { + const int idx = ioff + ig; + const T r = _hpsi[idx] - lambda * _psi[idx]; + _grad[idx] = r / std::max(_prec[ig], + static_cast(1e-8)); + err_j += std::norm(r); + } + for (int ig = this->n_dim; ig < this->n_basis; ig++) + _grad[ioff + ig] = static_cast(0.0); + _err[ib] = err_j; + } +#ifdef __MPI + Parallel_Reduce::reduce_pool(_err, this->n_band_l); +#endif + 
for (int ib = 0; ib < this->n_band_l; ib++) + _err[ib] = std::sqrt(_err[ib]); +} + +template +void DiagoLobpcg::compute_residual_s( + const ct::Tensor& psi_in, const ct::Tensor& hpsi_in, + const ct::Tensor& spsi_in, const ct::Tensor& eigen_in, + const ct::Tensor& prec_in, ct::Tensor& grad_out, ct::Tensor& err_out) +{ + const Real* _prec = prec_in.data(); + const Real* _eigen = eigen_in.data(); + const T* _hpsi = hpsi_in.data(); + const T* _spsi = spsi_in.data(); + T* _grad = grad_out.data(); + Real* _err = err_out.data(); + const int band_start = this->local_band_start(); + + for (int ib = 0; ib < this->n_band_l; ib++) { + const int ioff = ib * this->n_basis; + const Real lambda = _eigen[band_start + ib]; + Real err_j = 0.0; + for (int ig = 0; ig < this->n_dim; ig++) { + const int idx = ioff + ig; + const T r = _hpsi[idx] - lambda * _spsi[idx]; + const Real denom = std::max(_prec[ig], static_cast(1e-8)); + _grad[idx] = r / denom; + if (!std::isfinite(std::real(_grad[idx])) || !std::isfinite(std::imag(_grad[idx]))) { + std::ostringstream oss; + oss << "ib=" << ib + << " ig=" << ig + << " lambda=" << lambda + << " prec=" << _prec[ig] + << " denom=" << denom + << " hpsi=(" << std::real(_hpsi[idx]) << "," << std::imag(_hpsi[idx]) << ")" + << " spsi=(" << std::real(_spsi[idx]) << "," << std::imag(_spsi[idx]) << ")" + << " residual=(" << std::real(r) << "," << std::imag(r) << ")" + << " grad=(" << std::real(_grad[idx]) << "," << std::imag(_grad[idx]) << ")"; + this->diag_log("compute_residual_s non-finite grad", + oss.str(), + vector_block_diagnostics("hpsi", _hpsi, this->n_band_l, this->n_basis, this->n_dim), + vector_block_diagnostics("spsi", _spsi, this->n_band_l, this->n_basis, this->n_dim)); + throw std::runtime_error("LOBPCG generalized residual produced non-finite gradient"); + } + err_j += std::norm(r); + } + for (int ig = this->n_dim; ig < this->n_basis; ig++) + _grad[ioff + ig] = static_cast(0.0); + _err[ib] = err_j; + } +#ifdef __MPI + 
Parallel_Reduce::reduce_pool(_err, this->n_band_l); +#endif + for (int ib = 0; ib < this->n_band_l; ib++) + _err[ib] = std::sqrt(_err[ib]); +} + +// ============================================================================ +// orth_projection — grad -= psi * [S=I] +// inner(i,j) = (col-major) +// grad_j -= sum_i psi_i * inner(i,j) +// ============================================================================ + +template +void DiagoLobpcg::orth_projection( + const ct::Tensor& psi_in, ct::Tensor& hsub_work, ct::Tensor& grad_out) +{ + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + + this->pmmcn.multiply(1.0, psi_in.data(), grad_out.data(), + 0.0, hsub_work.data()); + this->plintrans.act(-1.0, psi_in.data(), hsub_work.data(), + 1.0, grad_out.data()); + T* grad = grad_out.data(); + for (int jb = 0; jb < this->n_band_l; ++jb) { + for (int ig = nvalid; ig < nbs; ++ig) { + grad[jb * nbs + ig] = static_cast(0.0); + } + } +} + +template +void DiagoLobpcg::orth_projection_s( + const ct::Tensor& psi_in, const ct::Tensor& spsi_in, + ct::Tensor& hsub_work, ct::Tensor& sgrad_out, ct::Tensor& grad_out) +{ + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + + this->pmmcn.multiply(this->one_, + psi_in.data(), + sgrad_out.data(), + this->zero_, + hsub_work.data()); + this->plintrans.act(this->neg_one_, + psi_in.data(), + hsub_work.data(), + this->one_, + grad_out.data()); + this->plintrans.act(this->neg_one_, + spsi_in.data(), + hsub_work.data(), + this->one_, + sgrad_out.data()); + + T* grad = grad_out.data(); + T* sgrad = sgrad_out.data(); + for (int jb = 0; jb < this->n_band_l; jb++) { + for (int ig = nvalid; ig < nbs; ig++) { + grad[jb * nbs + ig] = static_cast(0.0); + sgrad[jb * nbs + ig] = static_cast(0.0); + } + } +} + +template +void DiagoLobpcg::orth_projection_s_with_h( + const ct::Tensor& psi_in, const ct::Tensor& hpsi_in, + const ct::Tensor& spsi_in, ct::Tensor& hsub_work, + ct::Tensor& hpdir_out, ct::Tensor& spdir_out, 
ct::Tensor& pdir_out) +{ + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + + this->pmmcn.multiply(this->one_, + psi_in.data(), + spdir_out.data(), + this->zero_, + hsub_work.data()); + this->plintrans.act(this->neg_one_, + psi_in.data(), + hsub_work.data(), + this->one_, + pdir_out.data()); + this->plintrans.act(this->neg_one_, + spsi_in.data(), + hsub_work.data(), + this->one_, + spdir_out.data()); + this->plintrans.act(this->neg_one_, + hpsi_in.data(), + hsub_work.data(), + this->one_, + hpdir_out.data()); + + T* pdir = pdir_out.data(); + T* spdir = spdir_out.data(); + T* hpdir = hpdir_out.data(); + for (int jb = 0; jb < this->n_band_l; jb++) { + for (int ig = nvalid; ig < nbs; ig++) { + pdir[jb * nbs + ig] = static_cast(0.0); + spdir[jb * nbs + ig] = static_cast(0.0); + hpdir[jb * nbs + ig] = static_cast(0.0); + } + } +} + +// ============================================================================ +// rotate_wf +// ============================================================================ + +template +void DiagoLobpcg::rotate_wf( + const ct::Tensor& hsub_in, ct::Tensor& psi_out, ct::Tensor& workspace_in) +{ + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + this->plintrans.act(1.0, psi_out.data(), hsub_in.data(), + 0.0, workspace_in.data()); + T* workspace = workspace_in.data(); + for (int ib = 0; ib < this->n_band_l; ++ib) { + for (int ig = nvalid; ig < nbs; ++ig) { + workspace[ib * nbs + ig] = static_cast(0.0); + } + } + syncmem_complex_op()(psi_out.data(), workspace_in.data(), + this->n_band_l * nbs); +} + +// ============================================================================ +// orth_cholesky — S=I Cholesky orthonormalization +// ============================================================================ + +template +void DiagoLobpcg::orth_cholesky( + ct::Tensor& workspace_in, ct::Tensor& psi_out, + ct::Tensor& hpsi_out, ct::Tensor& hsub_out) +{ + const int nb = this->n_band; + const int nbs = 
this->n_basis; + const int nvalid = this->n_dim; + + try { + this->pmmcn.multiply(1.0, psi_out.data(), psi_out.data(), + 0.0, hsub_out.data()); + ct::kernels::set_matrix()('L', hsub_out.data(), nb); + ct::kernels::lapack_potrf()('U', nb, hsub_out.data(), nb); + ct::kernels::lapack_trtri()('U', 'N', nb, hsub_out.data(), nb); + this->rotate_wf(hsub_out, psi_out, workspace_in); + this->rotate_wf(hsub_out, hpsi_out, workspace_in); + return; + } catch (const std::exception&) { + // Fall back to modified Gram-Schmidt when Cholesky sees a near-dependent block. + } + + T* psi_d = psi_out.data(); + T* hpsi_d = hpsi_out.data(); + const Real eps = static_cast(100) + * std::numeric_limits::epsilon(); + + for (int ib = 0; ib < nb; ++ib) { + for (int jb = 0; jb < ib; ++jb) { + T dot = static_cast(0.0); + for (int ig = 0; ig < nvalid; ++ig) { + dot += std::conj(psi_d[jb * nbs + ig]) + * psi_d[ib * nbs + ig]; + } +#ifdef __MPI + Parallel_Reduce::reduce_pool(&dot, 1); +#endif + for (int ig = 0; ig < nvalid; ++ig) { + psi_d[ib * nbs + ig] -= dot * psi_d[jb * nbs + ig]; + hpsi_d[ib * nbs + ig] -= dot * hpsi_d[jb * nbs + ig]; + } + } + + Real norm2 = static_cast(0.0); + for (int ig = 0; ig < nvalid; ++ig) { + norm2 += std::norm(psi_d[ib * nbs + ig]); + } +#ifdef __MPI + Parallel_Reduce::reduce_pool(&norm2, 1); +#endif + if (!(norm2 > eps)) { + throw std::runtime_error("orth_cholesky failed: dependent vectors"); + } + const Real inv_norm = static_cast(1.0) / std::sqrt(norm2); + for (int ig = 0; ig < nvalid; ++ig) { + psi_d[ib * nbs + ig] *= inv_norm; + hpsi_d[ib * nbs + ig] *= inv_norm; + } + for (int ig = nvalid; ig < nbs; ++ig) { + psi_d[ib * nbs + ig] = static_cast(0.0); + hpsi_d[ib * nbs + ig] = static_cast(0.0); + } + } +} + +template +void DiagoLobpcg::orth_cholesky_s( + ct::Tensor&, ct::Tensor& psi_out, ct::Tensor& hpsi_out, + ct::Tensor& spsi_out, ct::Tensor&) +{ + const int nb = this->n_band; + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + + T* psi_d = 
psi_out.data(); + T* hpsi_d = hpsi_out.data(); + T* spsi_d = spsi_out.data(); + const Real eps = static_cast(100) + * std::numeric_limits::epsilon(); + + for (int ib = 0; ib < nb; ++ib) { + for (int jb = 0; jb < ib; ++jb) { + T dot = static_cast(0.0); + for (int ig = 0; ig < nvalid; ++ig) { + dot += std::conj(psi_d[jb * nbs + ig]) + * spsi_d[ib * nbs + ig]; + } +#ifdef __MPI + Parallel_Reduce::reduce_pool(&dot, 1); +#endif + for (int ig = 0; ig < nvalid; ++ig) { + psi_d[ib * nbs + ig] -= dot * psi_d[jb * nbs + ig]; + hpsi_d[ib * nbs + ig] -= dot * hpsi_d[jb * nbs + ig]; + spsi_d[ib * nbs + ig] -= dot * spsi_d[jb * nbs + ig]; + } + } + + Real norm2 = static_cast(0.0); + for (int ig = 0; ig < nvalid; ++ig) { + norm2 += std::real(std::conj(psi_d[ib * nbs + ig]) + * spsi_d[ib * nbs + ig]); + } +#ifdef __MPI + Parallel_Reduce::reduce_pool(&norm2, 1); +#endif + if (!(norm2 > eps)) { + throw std::runtime_error("orth_cholesky_s failed: dependent vectors at band " + + std::to_string(ib) + + ", norm2=" + std::to_string(norm2) + + ", nvalid=" + std::to_string(nvalid) + + ", lda=" + std::to_string(nbs)); + } + const Real inv_norm = static_cast(1.0) / std::sqrt(norm2); + for (int ig = 0; ig < nvalid; ++ig) { + psi_d[ib * nbs + ig] *= inv_norm; + hpsi_d[ib * nbs + ig] *= inv_norm; + spsi_d[ib * nbs + ig] *= inv_norm; + } + for (int ig = nvalid; ig < nbs; ++ig) { + psi_d[ib * nbs + ig] = static_cast(0.0); + hpsi_d[ib * nbs + ig] = static_cast(0.0); + spsi_d[ib * nbs + ig] = static_cast(0.0); + } + } +} + +// ============================================================================ +// test_error +// ============================================================================ + +template +void DiagoLobpcg::validate_ethr_band(const std::vector& ethr_band) const +{ + if (ethr_band.size() != static_cast(this->n_band_l)) { + std::ostringstream oss; + oss << "LOBPCG local ethr_band size mismatch: size=" << ethr_band.size() + << ", required local bands=" << this->n_band_l + << ", 
global bands=" << this->n_band; + if (!this->diag_context.empty()) { + oss << ", context={" << this->diag_context << "}"; + } + throw std::invalid_argument(oss.str()); + } +} + +template +bool DiagoLobpcg::test_error( + const ct::Tensor& err_in, const std::vector& ethr_band) +{ + Real* _err_st = err_in.data(); + bool not_conv = false; + std::vector tmp_cpu; + if (err_in.device_type() == ct::DeviceType::GpuDevice) { + tmp_cpu.resize(this->n_band_l); + _err_st = tmp_cpu.data(); + syncmem_var_d2h_op()(_err_st, err_in.data(), this->n_band_l); + } + for (int ii = 0; ii < this->n_band_l; ii++) + if (_err_st[ii] > ethr_band[ii]) not_conv = true; +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, ¬_conv, 1, MPI_C_BOOL, MPI_LOR, BP_WORLD); +#endif + return not_conv; +} + +template +void DiagoLobpcg::report_not_converged( + const char* problem_type, + const int used_iter, + const int max_iter, + const std::vector& ethr_band) const +{ + const Real* err = this->err_st.data(); + int notconv = 0; + Real max_residual = static_cast(0.0); + for (int ib = 0; ib < this->n_band_l; ++ib) { + max_residual = std::max(max_residual, err[ib]); + if (err[ib] > ethr_band[ib]) { + ++notconv; + } + } +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, ¬conv, 1, MPI_INT, MPI_SUM, BP_WORLD); + const MPI_Datatype real_type = std::is_same::value ? 
MPI_DOUBLE : MPI_FLOAT; + MPI_Allreduce(MPI_IN_PLACE, &max_residual, 1, real_type, MPI_MAX, BP_WORLD); +#endif + if (notconv > 0) { + std::ostringstream msg; + msg << "DiagoLobpcg::diag(" << problem_type << ") reached max_iter=" + << max_iter + << " after " << used_iter + << " iterations; notconv=" << notconv + << ", max_residual=" << max_residual; + if (!this->diag_context.empty()) { + msg << ", context={" << this->diag_context << "}"; + } + std::cout << "\n " << msg.str() << std::endl; + if (this->notconv_max >= 0 && notconv > this->notconv_max) { + throw std::runtime_error(msg.str()); + } + } +} + +template +bool DiagoLobpcg::profile_enabled() const +{ + if (std::getenv("ABACUS_LOBPCG_PROFILE") != nullptr) { + return true; + } + return static_cast(this->n_basis) * this->n_band_l >= 200000; +} + +template +void DiagoLobpcg::profile_log(const char* problem_type, + const char* stage, + const int iter, + const double seconds) const +{ + if (!this->profile_enabled()) { + return; + } + std::ostringstream oss; + oss << "LOBPCG_PROFILE " << problem_type + << " iter=" << iter + << " stage=" << stage + << " seconds=" << std::fixed << std::setprecision(6) << seconds; + if (!this->diag_context.empty()) { + oss << " context={" << this->diag_context << "}"; + } + std::cout << "\n " << oss.str() << std::endl; + if (GlobalV::ofs_running.good()) { + GlobalV::ofs_running << " " << oss.str() << std::endl; + GlobalV::ofs_running.flush(); + } +} + +// ============================================================================ +// lobpcg_update — generalized R-R on subspace W = [X, Z, P] +// +// H_sub = , S_sub = +// H_sub C = S_sub C Lambda (hegvd) +// X_new = V_XX*X + V_ZX*Z + V_PX*P +// P_new = V_ZX*Z + V_PX*P (soft restart) +// ============================================================================ + +template +void DiagoLobpcg::lobpcg_update( + ct::Tensor& psi, ct::Tensor& hpsi, + ct::Tensor& grad, ct::Tensor& hgrad, + ct::Tensor& pdir, ct::Tensor& hpdir, + ct::Tensor& 
eigen) +{ + const int n = this->n_band; + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + const int local_sz = this->n_band_l * nbs; + const Real eps = static_cast(100) + * std::numeric_limits::epsilon(); + + if (this->n_band_l != this->n_band) { + int block_count = this->has_pdir ? 3 : 2; + int m = block_count * n; + const T* basis_blocks[3] = {psi.data(), grad.data(), pdir.data()}; + const T* hbasis_blocks[3] = {hpsi.data(), hgrad.data(), hpdir.data()}; + + T* hsub_d = this->hsub.data(); + T* ssub_d = this->ssub.data(); + + auto store_block = [=](const T* src, T* dst, const int iblock, const int jblock) { + for (int jc = 0; jc < n; ++jc) { + std::copy(src + jc * n, + src + jc * n + n, + dst + (jblock * n + jc) * this->nsub + iblock * n); + } + }; + auto store_hermitian_block = [=](const T* src, T* dst, const int iblock, const int jblock) { + store_block(src, dst, iblock, jblock); + if (iblock == jblock) { + return; + } + for (int jc = 0; jc < n; ++jc) { + for (int ir = 0; ir < n; ++ir) { + dst[(iblock * n + jc) * this->nsub + jblock * n + ir] + = std::conj(src[ir * n + jc]); + } + } + }; + + auto build_subspace = [&](const int active_blocks) { + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + setmem_complex_op()(ssub_d, static_cast(0.0), this->nsub * this->nsub); + for (int jb = 0; jb < active_blocks; ++jb) { + for (int ib = 0; ib <= jb; ++ib) { + this->pmmcn.multiply(this->one_, + basis_blocks[ib], + hbasis_blocks[jb], + this->zero_, + this->tmp_hsub.data()); + store_hermitian_block(this->tmp_hsub.data(), hsub_d, ib, jb); + + this->pmmcn.multiply(this->one_, + basis_blocks[ib], + basis_blocks[jb], + this->zero_, + this->tmp_ssub.data()); + store_hermitian_block(this->tmp_ssub.data(), ssub_d, ib, jb); + } + } + hermitize(hsub_d, this->nsub, active_blocks * n); + hermitize(ssub_d, this->nsub, active_blocks * n); + }; + + std::vector inv_subspace_norm; + std::string scale_error; + build_subspace(block_count); + if 
(!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + if (block_count == 3) { + block_count = 2; + m = block_count * n; + build_subspace(block_count); + scale_error.clear(); + } + if (!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + this->diag_log("lobpcg_update failed to normalize S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + scale_error); + throw std::runtime_error("LOBPCG subspace overlap has invalid diagonal before hegvd"); + } + } + auto spd_check = check_subspace_spd(ssub_d, this->nsub, m); + if (!spd_check.ok && block_count == 3) { + if (this->profile_enabled()) { + this->diag_log("lobpcg_update dropping P due to ill-conditioned S_sub", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + subspace_spd_diagnostics(spd_check)); + } + block_count = 2; + m = block_count * n; + build_subspace(block_count); + scale_error.clear(); + if (!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + this->diag_log("lobpcg_update failed to normalize restarted S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + scale_error); + throw std::runtime_error("LOBPCG restarted subspace overlap has invalid diagonal before hegvd"); + } + spd_check = check_subspace_spd(ssub_d, this->nsub, m); + } + + if (!spd_check.ok) { + this->diag_log("lobpcg_update rejected ill-conditioned S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + subspace_spd_diagnostics(spd_check)); + throw std::runtime_error("LOBPCG subspace overlap is ill-conditioned before hegvd"); + } + + try { + ct::kernels::lapack_hegvd()( + m, this->nsub, hsub_d, ssub_d, + 
this->sub_eigen.data(), hsub_d); + } catch (const std::exception& e) { + this->diag_log("lobpcg_update hegvd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m)); + throw; + } + + if (!finite_real_block(this->sub_eigen.data(), m) + || !finite_scalar_block(hsub_d, m * this->nsub)) { + this->diag_log("lobpcg_update hegvd produced non-finite values", + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + subspace_spd_diagnostics(spd_check)); + throw std::runtime_error("LOBPCG subspace diagonalization produced non-finite values"); + } + for (int jc = 0; jc < n; ++jc) { + for (int ir = 0; ir < m; ++ir) { + hsub_d[jc * this->nsub + ir] *= inv_subspace_norm[ir]; + } + } + + const Real* sub = this->sub_eigen.data(); + Real* eig = eigen.data(); + for (int ib = 0; ib < n; ++ib) { + eig[ib] = sub[ib]; + } + + T* x_new = this->work.data(); + T* hx_new = this->hwork.data(); + T* p_new = this->pwork.data(); + T* hp_new = this->hpwork.data(); + setmem_complex_op()(x_new, static_cast(0.0), local_sz); + setmem_complex_op()(hx_new, static_cast(0.0), local_sz); + setmem_complex_op()(p_new, static_cast(0.0), local_sz); + setmem_complex_op()(hp_new, static_cast(0.0), local_sz); + + auto copy_coeff_block = [=](const int block, const int first_col, const int ncol, T* coeff) { + setmem_complex_op()(coeff, static_cast(0.0), n * n); + for (int jc = 0; jc < ncol; ++jc) { + std::copy(hsub_d + (first_col + jc) * this->nsub + block * n, + hsub_d + (first_col + jc) * this->nsub + block * n + n, + coeff + (first_col + jc) * n); + } + }; + + for (int ib = 0; ib < block_count; ++ib) { + copy_coeff_block(ib, 0, n, this->tmp_hsub.data()); + this->plintrans.act(this->one_, basis_blocks[ib], this->tmp_hsub.data(), + ib == 0 ? 
this->zero_ : this->one_, x_new); + this->plintrans.act(this->one_, hbasis_blocks[ib], this->tmp_hsub.data(), + ib == 0 ? this->zero_ : this->one_, hx_new); + } + + const int tail_cols = m - n; + if (tail_cols > 0) { + for (int ib = 1; ib < block_count; ++ib) { + copy_coeff_block(ib, 0, n, this->tmp_hsub.data()); + this->plintrans.act(this->one_, basis_blocks[ib], this->tmp_hsub.data(), + ib == 1 ? this->zero_ : this->one_, p_new); + this->plintrans.act(this->one_, hbasis_blocks[ib], this->tmp_hsub.data(), + ib == 1 ? this->zero_ : this->one_, hp_new); + } + } + + syncmem_complex_op()(psi.data(), x_new, local_sz); + syncmem_complex_op()(hpsi.data(), hx_new, local_sz); + syncmem_complex_op()(pdir.data(), p_new, local_sz); + syncmem_complex_op()(hpdir.data(), hp_new, local_sz); + + int update_invalid_mask = + (!finite_vector_block(psi.data(), this->n_band_l, nbs, this->n_dim) ? 1 : 0) + | (!finite_vector_block(hpsi.data(), this->n_band_l, nbs, this->n_dim) ? 2 : 0) + | (!finite_real_block(eigen.data(), n) ? 4 : 0); +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, &update_invalid_mask, 1, MPI_INT, MPI_BOR, BP_WORLD); +#endif + if (update_invalid_mask != 0) { + throw std::runtime_error("LOBPCG band-parallel update produced non-finite values"); + } + + this->has_pdir = true; + return; + } + + std::vector basis; + std::vector hbasis; + basis.reserve((this->has_pdir ? 3 : 2) * local_sz); + hbasis.reserve((this->has_pdir ? 
3 : 2) * local_sz); + + append_orthonormal_block(n, nbs, nvalid, eps, + psi.data(), hpsi.data(), + basis, hbasis); + append_orthonormal_block(n, nbs, nvalid, eps, + grad.data(), hgrad.data(), + basis, hbasis); + if (this->has_pdir) { + append_orthonormal_block(n, nbs, nvalid, eps, + pdir.data(), hpdir.data(), + basis, hbasis); + } + + const int m = static_cast(basis.size() / nbs); + if (m < n) { + throw std::runtime_error("LOBPCG standard subspace lost rank"); + } + + T* hsub_d = this->hsub.data(); + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + + ModuleBase::gemm_op()('C', 'N', + m, m, nvalid, + this->one, + basis.data(), nbs, + hbasis.data(), nbs, + this->zero, + hsub_d, this->nsub); +#ifdef __MPI + for (int jc = 0; jc < m; ++jc) { + Parallel_Reduce::reduce_pool(hsub_d + jc * this->nsub, m); + } +#endif + hermitize(hsub_d, this->nsub, m); + + try { + ct::kernels::lapack_heevd()( + m, hsub_d, this->nsub, this->sub_eigen.data()); + } catch (const std::exception& e) { + this->diag_log("lobpcg_update heevd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + "S_sub unavailable after explicit orthonormalization"); + throw; + } + + if (!finite_real_block(this->sub_eigen.data(), m) + || !finite_scalar_block(hsub_d, m * this->nsub)) { + throw std::runtime_error("LOBPCG subspace diagonalization produced non-finite values"); + } + + const Real* sub = this->sub_eigen.data(); + Real* eig = eigen.data(); + for (int ib = 0; ib < n; ++ib) { + eig[ib] = sub[ib]; + } + + T* x_new = this->work.data(); + T* hx_new = this->hwork.data(); + T* p_new = this->pwork.data(); + T* hp_new = this->hpwork.data(); + setmem_complex_op()(x_new, static_cast(0.0), local_sz); + setmem_complex_op()(hx_new, static_cast(0.0), local_sz); + setmem_complex_op()(p_new, static_cast(0.0), local_sz); + setmem_complex_op()(hp_new, static_cast(0.0), local_sz); + + ModuleBase::gemm_op()('N', 'N', + nvalid, n, m, + this->one, + 
basis.data(), nbs, + hsub_d, this->nsub, + this->zero, + x_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, m, + this->one, + hbasis.data(), nbs, + hsub_d, this->nsub, + this->zero, + hx_new, nbs); + + const int tail_cols = m - n; + if (tail_cols > 0) { + ModuleBase::gemm_op()('N', 'N', + nvalid, n, tail_cols, + this->one, + basis.data() + n * nbs, nbs, + hsub_d + n, this->nsub, + this->zero, + p_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, tail_cols, + this->one, + hbasis.data() + n * nbs, nbs, + hsub_d + n, this->nsub, + this->zero, + hp_new, nbs); + } + + syncmem_complex_op()(psi.data(), x_new, local_sz); + syncmem_complex_op()(hpsi.data(), hx_new, local_sz); + syncmem_complex_op()(pdir.data(), p_new, local_sz); + syncmem_complex_op()(hpdir.data(), hp_new, local_sz); + + if (!finite_vector_block(psi.data(), this->n_band_l, nbs, this->n_dim) + || !finite_vector_block(hpsi.data(), this->n_band_l, nbs, this->n_dim) + || !finite_real_block(eigen.data(), n)) { + throw std::runtime_error("LOBPCG standard update produced non-finite values"); + } + + this->has_pdir = true; +} + +// ============================================================================ +// lobpcg_update_s — generalized R-R on S-orthonormalized W = [X, Z, P] +// ============================================================================ + +template +void DiagoLobpcg::lobpcg_update_s_parallel( + ct::Tensor& psi, ct::Tensor& hpsi, ct::Tensor& spsi, + ct::Tensor& grad, ct::Tensor& hgrad, ct::Tensor& sgrad, + ct::Tensor& pdir, ct::Tensor& hpdir, ct::Tensor& spdir, + ct::Tensor& eigen) +{ + const int n = this->n_band; + const int nbs = this->n_basis; + const int local_sz = this->n_band_l * nbs; + int block_count = this->has_pdir ? 
3 : 2; + int m = block_count * n; + + const T* basis_blocks[3] = {psi.data(), grad.data(), pdir.data()}; + const T* hbasis_blocks[3] = {hpsi.data(), hgrad.data(), hpdir.data()}; + const T* sbasis_blocks[3] = {spsi.data(), sgrad.data(), spdir.data()}; + + T* hsub_d = this->hsub.data(); + T* ssub_d = this->ssub.data(); + + auto store_block = [=](const T* src, T* dst, const int iblock, const int jblock) { + for (int jc = 0; jc < n; ++jc) { + std::copy(src + jc * n, + src + jc * n + n, + dst + (jblock * n + jc) * this->nsub + iblock * n); + } + }; + auto store_hermitian_block = [=](const T* src, T* dst, const int iblock, const int jblock) { + store_block(src, dst, iblock, jblock); + if (iblock == jblock) { + return; + } + for (int jc = 0; jc < n; ++jc) { + for (int ir = 0; ir < n; ++ir) { + dst[(iblock * n + jc) * this->nsub + jblock * n + ir] + = std::conj(src[ir * n + jc]); + } + } + }; + + auto build_subspace = [&](const int active_blocks) { + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + setmem_complex_op()(ssub_d, static_cast(0.0), this->nsub * this->nsub); + for (int jb = 0; jb < active_blocks; ++jb) { + for (int ib = 0; ib <= jb; ++ib) { + this->pmmcn.multiply(this->one_, + basis_blocks[ib], + hbasis_blocks[jb], + this->zero_, + this->tmp_hsub.data()); + store_hermitian_block(this->tmp_hsub.data(), hsub_d, ib, jb); + + this->pmmcn.multiply(this->one_, + basis_blocks[ib], + sbasis_blocks[jb], + this->zero_, + this->tmp_ssub.data()); + store_hermitian_block(this->tmp_ssub.data(), ssub_d, ib, jb); + } + } + hermitize(hsub_d, this->nsub, active_blocks * n); + hermitize(ssub_d, this->nsub, active_blocks * n); + }; + + std::vector inv_subspace_norm; + std::string scale_error; + build_subspace(block_count); + if (!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + if (block_count == 3) { + block_count = 2; + m = block_count * n; + build_subspace(block_count); + scale_error.clear(); + } + if 
(!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + this->diag_log("lobpcg_update_s_parallel failed to normalize S_sub before hegvd", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + scale_error); + throw std::runtime_error("LOBPCG generalized parallel subspace overlap has invalid diagonal"); + } + } + + std::vector h_scaled(m * m, static_cast(0.0)); + std::vector s_scaled(m * m, static_cast(0.0)); + for (int jc = 0; jc < m; ++jc) { + for (int ir = 0; ir < m; ++ir) { + h_scaled[jc * m + ir] = hsub_d[jc * this->nsub + ir]; + s_scaled[jc * m + ir] = ssub_d[jc * this->nsub + ir]; + } + } + + std::vector s_evec = s_scaled; + std::vector s_eval(m, static_cast(0.0)); + try { + ct::kernels::lapack_heevd()(m, s_evec.data(), m, s_eval.data()); + } catch (const std::exception& e) { + this->diag_log("lobpcg_update_s_parallel S_sub heevd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + scale_error); + throw; + } + if (!finite_real_block(s_eval.data(), m) || !finite_scalar_block(s_evec.data(), m * m)) { + this->diag_log("lobpcg_update_s_parallel S_sub heevd produced non-finite values", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + "rank compression unavailable"); + throw std::runtime_error("LOBPCG generalized parallel overlap diagonalization produced non-finite values"); + } + + const Real s_max = s_eval.empty() ? 
static_cast(0.0) + : *std::max_element(s_eval.begin(), s_eval.end()); + const Real eps = std::numeric_limits::epsilon(); + const Real rank_floor = std::max(std::abs(s_max) * static_cast(1.0e-10), + static_cast(100.0) * eps * std::max(static_cast(1.0), + std::abs(s_max))); + int first_kept = 0; + while (first_kept < m && s_eval[first_kept] <= rank_floor) { + ++first_kept; + } + int rank = m - first_kept; + if (rank < n && block_count == 3) { + block_count = 2; + m = block_count * n; + build_subspace(block_count); + scale_error.clear(); + if (!scale_subspace_by_overlap_diag(hsub_d, ssub_d, this->nsub, m, + inv_subspace_norm, scale_error)) { + this->diag_log("lobpcg_update_s_parallel failed to normalize restarted S_sub before compression", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + scale_error); + throw std::runtime_error("LOBPCG generalized restarted parallel subspace overlap has invalid diagonal"); + } + h_scaled.assign(m * m, static_cast(0.0)); + s_scaled.assign(m * m, static_cast(0.0)); + for (int jc = 0; jc < m; ++jc) { + for (int ir = 0; ir < m; ++ir) { + h_scaled[jc * m + ir] = hsub_d[jc * this->nsub + ir]; + s_scaled[jc * m + ir] = ssub_d[jc * this->nsub + ir]; + } + } + s_evec = s_scaled; + s_eval.assign(m, static_cast(0.0)); + ct::kernels::lapack_heevd()(m, s_evec.data(), m, s_eval.data()); + const Real restart_s_max = s_eval.empty() ? 
static_cast(0.0) + : *std::max_element(s_eval.begin(), s_eval.end()); + const Real restart_rank_floor = std::max(std::abs(restart_s_max) * static_cast(1.0e-10), + static_cast(100.0) * eps * std::max(static_cast(1.0), + std::abs(restart_s_max))); + first_kept = 0; + while (first_kept < m && s_eval[first_kept] <= restart_rank_floor) { + ++first_kept; + } + rank = m - first_kept; + } + if (rank < n) { + this->diag_log("lobpcg_update_s_parallel compressed subspace rank is too small", + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m), + "rank=" + std::to_string(rank) + " required=" + std::to_string(n)); + throw std::runtime_error("LOBPCG generalized parallel compressed subspace lost rank"); + } + + std::vector q(m * rank, static_cast(0.0)); + for (int jc = 0; jc < rank; ++jc) { + const int src_col = first_kept + jc; + const Real inv_sqrt = static_cast(1.0) / std::sqrt(s_eval[src_col]); + for (int ir = 0; ir < m; ++ir) { + q[jc * m + ir] = s_evec[src_col * m + ir] * inv_sqrt; + } + } + + std::vector hq(m * rank, static_cast(0.0)); + ModuleBase::gemm_op()('N', 'N', + m, rank, m, + this->one, + h_scaled.data(), m, + q.data(), m, + this->zero, + hq.data(), m); + std::vector h_comp(rank * rank, static_cast(0.0)); + ModuleBase::gemm_op()('C', 'N', + rank, rank, m, + this->one, + q.data(), m, + hq.data(), m, + this->zero, + h_comp.data(), rank); + hermitize(h_comp.data(), rank, rank); + + try { + ct::kernels::lapack_heevd()( + rank, h_comp.data(), rank, this->sub_eigen.data()); + } catch (const std::exception& e) { + this->diag_log("lobpcg_update_s_parallel compressed heevd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + "rank=" + std::to_string(rank)); + throw; + } + if (!finite_real_block(this->sub_eigen.data(), rank) + || !finite_scalar_block(h_comp.data(), rank * rank)) { + 
this->diag_log("lobpcg_update_s_parallel compressed heevd produced non-finite values", + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + "rank=" + std::to_string(rank)); + throw std::runtime_error("LOBPCG generalized parallel compressed diagonalization produced non-finite values"); + } + + std::vector coeff_scaled(m * n, static_cast(0.0)); + ModuleBase::gemm_op()('N', 'N', + m, n, rank, + this->one, + q.data(), m, + h_comp.data(), rank, + this->zero, + coeff_scaled.data(), m); + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + for (int jc = 0; jc < n; ++jc) { + for (int ir = 0; ir < m; ++ir) { + hsub_d[jc * this->nsub + ir] = coeff_scaled[jc * m + ir] * inv_subspace_norm[ir]; + } + } + + const Real* sub = this->sub_eigen.data(); + Real* eig = eigen.data(); + for (int ib = 0; ib < n; ++ib) { + eig[ib] = sub[ib]; + } + + T* x_new = this->work.data(); + T* hx_new = this->hwork.data(); + T* sx_new = this->swork.data(); + T* p_new = this->pwork.data(); + T* hp_new = this->hpwork.data(); + T* sp_new = this->spwork.data(); + setmem_complex_op()(x_new, static_cast(0.0), local_sz); + setmem_complex_op()(hx_new, static_cast(0.0), local_sz); + setmem_complex_op()(sx_new, static_cast(0.0), local_sz); + setmem_complex_op()(p_new, static_cast(0.0), local_sz); + setmem_complex_op()(hp_new, static_cast(0.0), local_sz); + setmem_complex_op()(sp_new, static_cast(0.0), local_sz); + + auto copy_coeff_block = [=](const int block, T* coeff) { + setmem_complex_op()(coeff, static_cast(0.0), n * n); + for (int jc = 0; jc < n; ++jc) { + std::copy(hsub_d + jc * this->nsub + block * n, + hsub_d + jc * this->nsub + block * n + n, + coeff + jc * n); + } + }; + + for (int ib = 0; ib < block_count; ++ib) { + copy_coeff_block(ib, this->tmp_hsub.data()); + this->plintrans.act(this->one_, basis_blocks[ib], this->tmp_hsub.data(), + ib == 0 ? 
this->zero_ : this->one_, x_new); + this->plintrans.act(this->one_, hbasis_blocks[ib], this->tmp_hsub.data(), + ib == 0 ? this->zero_ : this->one_, hx_new); + this->plintrans.act(this->one_, sbasis_blocks[ib], this->tmp_hsub.data(), + ib == 0 ? this->zero_ : this->one_, sx_new); + } + + if (m > n) { + for (int ib = 1; ib < block_count; ++ib) { + copy_coeff_block(ib, this->tmp_hsub.data()); + this->plintrans.act(this->one_, basis_blocks[ib], this->tmp_hsub.data(), + ib == 1 ? this->zero_ : this->one_, p_new); + this->plintrans.act(this->one_, hbasis_blocks[ib], this->tmp_hsub.data(), + ib == 1 ? this->zero_ : this->one_, hp_new); + this->plintrans.act(this->one_, sbasis_blocks[ib], this->tmp_hsub.data(), + ib == 1 ? this->zero_ : this->one_, sp_new); + } + } + + syncmem_complex_op()(psi.data(), x_new, local_sz); + syncmem_complex_op()(hpsi.data(), hx_new, local_sz); + syncmem_complex_op()(spsi.data(), sx_new, local_sz); + syncmem_complex_op()(pdir.data(), p_new, local_sz); + syncmem_complex_op()(hpdir.data(), hp_new, local_sz); + syncmem_complex_op()(spdir.data(), sp_new, local_sz); + + bool psi_invalid = !finite_vector_block(psi.data(), this->n_band_l, nbs, this->n_dim); + bool hpsi_invalid = !finite_vector_block(hpsi.data(), this->n_band_l, nbs, this->n_dim); + bool spsi_invalid = !finite_vector_block(spsi.data(), this->n_band_l, nbs, this->n_dim); + bool eigen_invalid = !finite_real_block(eigen.data(), n); + int update_invalid_mask = (psi_invalid ? 1 : 0) + | (hpsi_invalid ? 2 : 0) + | (spsi_invalid ? 4 : 0) + | (eigen_invalid ? 
8 : 0); +#ifdef __MPI + MPI_Allreduce(MPI_IN_PLACE, &update_invalid_mask, 1, MPI_INT, MPI_BOR, BP_WORLD); +#endif + if (update_invalid_mask != 0) { + this->diag_log("lobpcg_update_s_parallel produced non-finite values", + vector_block_diagnostics("psi", psi.data(), this->n_band_l, nbs, this->n_dim), + vector_block_diagnostics("hpsi", hpsi.data(), this->n_band_l, nbs, this->n_dim), + vector_block_diagnostics("spsi", spsi.data(), this->n_band_l, nbs, this->n_dim) + + ", eigen_invalid=" + std::to_string((update_invalid_mask & 8) != 0) + + ", invalid_mask=" + std::to_string(update_invalid_mask)); + throw std::runtime_error("LOBPCG generalized parallel update produced non-finite values"); + } + + this->has_pdir = true; +} + +template +void DiagoLobpcg::lobpcg_update_s( + ct::Tensor& psi, ct::Tensor& hpsi, ct::Tensor& spsi, + ct::Tensor& grad, ct::Tensor& hgrad, ct::Tensor& sgrad, + ct::Tensor& pdir, ct::Tensor& hpdir, ct::Tensor& spdir, + ct::Tensor& eigen) +{ + const int n = this->n_band; + const int nbs = this->n_basis; + const int nvalid = this->n_dim; + const int local_sz = this->n_band_l * nbs; + const Real eps = static_cast(100) + * std::numeric_limits::epsilon(); + + std::vector basis; + std::vector hbasis; + std::vector sbasis; + basis.reserve((this->has_pdir ? 3 : 2) * local_sz); + hbasis.reserve((this->has_pdir ? 3 : 2) * local_sz); + sbasis.reserve((this->has_pdir ? 
3 : 2) * local_sz); + + append_s_orthonormal_block(n, nbs, nvalid, eps, + psi.data(), hpsi.data(), spsi.data(), + basis, hbasis, sbasis); + append_s_orthonormal_block(n, nbs, nvalid, eps, + grad.data(), hgrad.data(), sgrad.data(), + basis, hbasis, sbasis); + if (this->has_pdir) { + append_s_orthonormal_block(n, nbs, nvalid, eps, + pdir.data(), hpdir.data(), spdir.data(), + basis, hbasis, sbasis); + } + + const int m = static_cast(basis.size() / nbs); + if (m < n) { + throw std::runtime_error("LOBPCG generalized subspace lost rank"); + } + + T* hsub_d = this->hsub.data(); + T* ssub_d = this->ssub.data(); + setmem_complex_op()(hsub_d, static_cast(0.0), this->nsub * this->nsub); + setmem_complex_op()(ssub_d, static_cast(0.0), this->nsub * this->nsub); + + ModuleBase::gemm_op()('C', 'N', + m, m, nvalid, + this->one, + basis.data(), nbs, + hbasis.data(), nbs, + this->zero, + hsub_d, this->nsub); + ModuleBase::gemm_op()('C', 'N', + m, m, nvalid, + this->one, + basis.data(), nbs, + sbasis.data(), nbs, + this->zero, + ssub_d, this->nsub); +#ifdef __MPI + for (int jc = 0; jc < m; ++jc) { + Parallel_Reduce::reduce_pool(hsub_d + jc * this->nsub, m); + Parallel_Reduce::reduce_pool(ssub_d + jc * this->nsub, m); + } +#endif + hermitize(hsub_d, this->nsub, m); + hermitize(ssub_d, this->nsub, m); + + try { + ct::kernels::lapack_hegvd()( + m, this->nsub, hsub_d, ssub_d, + this->sub_eigen.data(), hsub_d); + } catch (const std::exception& e) { + this->diag_log("lobpcg_update_s hegvd failed: " + std::string(e.what()), + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m)); + throw; + } + + if (!finite_real_block(this->sub_eigen.data(), m) + || !finite_scalar_block(hsub_d, m * this->nsub)) { + this->diag_log("lobpcg_update_s hegvd produced non-finite values", + hermitian_matrix_diagnostics("H_sub", hsub_d, this->nsub, m), + hermitian_matrix_diagnostics("S_sub", 
ssub_d, this->nsub, m), + s_overlap_diagnostics(ssub_d, this->nsub, m)); + throw std::runtime_error("LOBPCG generalized subspace diagonalization produced non-finite values"); + } + + const Real* sub = this->sub_eigen.data(); + Real* eig = eigen.data(); + for (int ib = 0; ib < n; ++ib) { + eig[ib] = sub[ib]; + } + + T* x_new = this->work.data(); + T* hx_new = this->hwork.data(); + T* sx_new = this->swork.data(); + T* p_new = this->pwork.data(); + T* hp_new = this->hpwork.data(); + T* sp_new = this->spwork.data(); + setmem_complex_op()(x_new, static_cast(0.0), local_sz); + setmem_complex_op()(hx_new, static_cast(0.0), local_sz); + setmem_complex_op()(sx_new, static_cast(0.0), local_sz); + setmem_complex_op()(p_new, static_cast(0.0), local_sz); + setmem_complex_op()(hp_new, static_cast(0.0), local_sz); + setmem_complex_op()(sp_new, static_cast(0.0), local_sz); + + ModuleBase::gemm_op()('N', 'N', + nvalid, n, m, + this->one, + basis.data(), nbs, + hsub_d, this->nsub, + this->zero, + x_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, m, + this->one, + hbasis.data(), nbs, + hsub_d, this->nsub, + this->zero, + hx_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, m, + this->one, + sbasis.data(), nbs, + hsub_d, this->nsub, + this->zero, + sx_new, nbs); + + const int tail_cols = m - n; + if (tail_cols > 0) { + ModuleBase::gemm_op()('N', 'N', + nvalid, n, tail_cols, + this->one, + basis.data() + n * nbs, nbs, + hsub_d + n, this->nsub, + this->zero, + p_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, tail_cols, + this->one, + hbasis.data() + n * nbs, nbs, + hsub_d + n, this->nsub, + this->zero, + hp_new, nbs); + ModuleBase::gemm_op()('N', 'N', + nvalid, n, tail_cols, + this->one, + sbasis.data() + n * nbs, nbs, + hsub_d + n, this->nsub, + this->zero, + sp_new, nbs); + } + + syncmem_complex_op()(psi.data(), x_new, local_sz); + syncmem_complex_op()(hpsi.data(), hx_new, local_sz); + syncmem_complex_op()(spsi.data(), sx_new, local_sz); + 
syncmem_complex_op()(pdir.data(), p_new, local_sz); + syncmem_complex_op()(hpdir.data(), hp_new, local_sz); + syncmem_complex_op()(spdir.data(), sp_new, local_sz); + + if (!finite_vector_block(psi.data(), this->n_band_l, nbs, this->n_dim) + || !finite_vector_block(hpsi.data(), this->n_band_l, nbs, this->n_dim) + || !finite_vector_block(spsi.data(), this->n_band_l, nbs, this->n_dim) + || !finite_real_block(eigen.data(), n)) { + throw std::runtime_error("LOBPCG generalized update produced non-finite values"); + } + + this->has_pdir = true; +} + +// ============================================================================ +// diag — main LOBPCG loop (NC, S=I) +// ============================================================================ + +template +void DiagoLobpcg::diag( + const HPsiFunc& hpsi_func, T* psi_in, + Real* eigenvalue_in, const std::vector& ethr_band) +{ + this->validate_ethr_band(ethr_band); + + this->has_pdir = false; + const int scf_iter = DiagoIterAssist::SCF_ITER; + + this->psi = ct::TensorMap(psi_in, t_type, dev_type, + {this->n_band_l, this->n_basis}); + + auto t0 = LobpcgClock::now(); + this->calc_prec(); + this->profile_log("S=I", "initial_calc_prec", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, psi_in, this->hpsi); + this->profile_log("S=I", "initial_hpsi", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + // Re-orthonormalize before initial R-R so H_sub is well-conditioned + t0 = LobpcgClock::now(); + this->orth_cholesky(this->work, this->psi, this->hpsi, this->tmp_hsub); + this->profile_log("S=I", "initial_orth", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + t0 = LobpcgClock::now(); + this->rayleigh_ritz(this->psi, this->hpsi, this->eigen); + this->profile_log("S=I", "initial_rr", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + + setmem_complex_op()(this->pdir.data(), static_cast(0.0), + this->n_basis * 
this->n_band_l); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + + const int default_max_iter = (scf_iter > 1) ? this->nline : (this->nline * 20); + const int max_iter = (this->max_iter > 0) ? this->max_iter : default_max_iter; + int used_iter = 0; + + for (int ntry = 0; ntry < max_iter; ++ntry) { + used_iter = ntry + 1; + t0 = LobpcgClock::now(); + this->compute_residual(this->psi, this->hpsi, this->eigen, + this->prec, this->grad, this->err_st); + this->profile_log("S=I", "residual", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + if (!this->test_error(this->err_st, ethr_band)) + break; + + const int psi_sz = this->n_basis * this->n_band_l; + const int eig_sz = this->n_band; + + t0 = LobpcgClock::now(); + this->orth_projection(this->psi, this->tmp_hsub, this->grad); + this->profile_log("S=I", "grad_projection", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, this->grad.data(), this->hgrad); + this->profile_log("S=I", "grad_hpsi", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + + // Backup stable state in case lobpcg_update corrupts psi/hpsi + std::vector psi_bak(psi_sz), hpsi_bak(psi_sz); + std::vector eigen_bak(eig_sz); + std::copy(this->psi.data(), this->psi.data() + psi_sz, psi_bak.data()); + std::copy(this->hpsi.data(), this->hpsi.data() + psi_sz, hpsi_bak.data()); + std::copy(this->eigen.data(), this->eigen.data() + eig_sz, eigen_bak.data()); + + try { + t0 = LobpcgClock::now(); + this->lobpcg_update(this->psi, this->hpsi, + this->grad, this->hgrad, + this->pdir, this->hpdir, + this->eigen); + this->profile_log("S=I", "lobpcg_update", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } catch (const std::exception& e1) { + std::copy(psi_bak.data(), psi_bak.data() + psi_sz, this->psi.data()); + std::copy(hpsi_bak.data(), hpsi_bak.data() + psi_sz, 
this->hpsi.data()); + std::copy(eigen_bak.data(), eigen_bak.data() + eig_sz, this->eigen.data()); + + setmem_complex_op()(this->pdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), psi_sz); + this->has_pdir = false; + + try { + t0 = LobpcgClock::now(); + this->lobpcg_update(this->psi, this->hpsi, + this->grad, this->hgrad, + this->pdir, this->hpdir, + this->eigen); + this->profile_log("S=I", "lobpcg_update_retry", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } catch (const std::exception& e2) { + std::copy(psi_bak.data(), psi_bak.data() + psi_sz, this->psi.data()); + std::copy(hpsi_bak.data(), hpsi_bak.data() + psi_sz, this->hpsi.data()); + std::copy(eigen_bak.data(), eigen_bak.data() + eig_sz, this->eigen.data()); + + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, this->psi.data(), this->hpsi); + this->orth_cholesky(this->work, this->psi, this->hpsi, this->tmp_hsub); + this->rayleigh_ritz(this->psi, this->hpsi, this->eigen); + this->profile_log("S=I", "fallback_rr", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } + } + + const bool has_next_iteration = (ntry + 1) < max_iter; + const bool restart_next = has_next_iteration && scf_iter == 1 && ((ntry + 1) % this->nline == 0); + if (restart_next) { + setmem_complex_op()(this->pdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + this->has_pdir = false; + } + } + + t0 = LobpcgClock::now(); + this->compute_residual(this->psi, this->hpsi, this->eigen, + this->prec, this->grad, this->err_st); + this->profile_log("S=I", "final_residual", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + this->report_not_converged("S=I", used_iter, max_iter, ethr_band); + DiagoIterAssist::avg_iter += static_cast(used_iter); + + syncmem_var_d2h_op()(eigenvalue_in, + this->eigen.data() + 
this->local_band_start(), + this->n_band_l); +} + +template +void DiagoLobpcg::diag( + const HPsiFunc& hpsi_func, const SPsiFunc& spsi_func, T* psi_in, + Real* eigenvalue_in, const std::vector& ethr_band) +{ + this->validate_ethr_band(ethr_band); + + this->has_pdir = false; + this->psi = ct::TensorMap(psi_in, t_type, dev_type, + {this->n_band_l, this->n_basis}); + + this->calc_spsi_with_block(spsi_func, psi_in, this->spsi); + { + const T* spsi_d = this->spsi.data(); + Real max_diff = static_cast(0.0); + Real max_ref = static_cast(0.0); + for (int ib = 0; ib < this->n_band_l; ++ib) { + const int ioff = ib * this->n_basis; + for (int ig = 0; ig < this->n_dim; ++ig) { + const int idx = ioff + ig; + max_diff = std::max(max_diff, + static_cast(std::abs(spsi_d[idx] - psi_in[idx]))); + max_ref = std::max(max_ref, + static_cast(std::abs(psi_in[idx]))); + } + } +#ifdef __MPI + if (this->n_band_l != this->n_band) { + const MPI_Datatype real_type = std::is_same::value ? MPI_DOUBLE : MPI_FLOAT; + MPI_Allreduce(MPI_IN_PLACE, &max_diff, 1, real_type, MPI_MAX, BP_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &max_ref, 1, real_type, MPI_MAX, BP_WORLD); + } +#endif + const Real tol = static_cast(1.0e-12) + * std::max(static_cast(1.0), max_ref); + if (max_diff <= tol) { + this->diag(hpsi_func, psi_in, eigenvalue_in, ethr_band); + return; + } + } + + const int scf_iter = DiagoIterAssist::SCF_ITER; + + auto t0 = LobpcgClock::now(); + this->calc_prec(); + this->profile_log("S!=I", "initial_calc_prec", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + if (this->n_band_l != this->n_band) { + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, this->psi.data(), this->hpsi); + this->profile_log("S!=I", "initial_hpsi", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + t0 = LobpcgClock::now(); + this->generalized_rayleigh_ritz_parallel(this->psi, this->hpsi, this->spsi, this->eigen); + this->profile_log("S!=I", "initial_rr_parallel", 0, + 
std::chrono::duration(LobpcgClock::now() - t0).count()); + } else { + t0 = LobpcgClock::now(); + this->repair_initial_subspace_s(hpsi_func, spsi_func); + this->profile_log("S!=I", "initial_repair", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + t0 = LobpcgClock::now(); + this->generalized_rayleigh_ritz(this->psi, this->hpsi, this->spsi, this->eigen); + this->profile_log("S!=I", "initial_rr", 0, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } + + setmem_complex_op()(this->pdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + setmem_complex_op()(this->spdir.data(), static_cast(0.0), + this->n_basis * this->n_band_l); + + const int default_max_iter = (scf_iter > 1) ? this->nline : (this->nline * 20); + const int max_iter = (this->max_iter > 0) ? this->max_iter : default_max_iter; + int used_iter = 0; + std::vector effective_ethr_band = ethr_band; + if (this->notconv_max < 0) { + // SCF can refine the density across outer iterations; avoid chasing a tiny diagonalization threshold. 
+ constexpr double scf_generalized_residual_floor = 1.0e-5; + for (double& ethr : effective_ethr_band) { + ethr = std::max(ethr, scf_generalized_residual_floor); + } + } + + for (int ntry = 0; ntry < max_iter; ++ntry) { + used_iter = ntry + 1; + t0 = LobpcgClock::now(); + this->compute_residual_s(this->psi, this->hpsi, this->spsi, this->eigen, + this->prec, this->grad, this->err_st); + this->profile_log("S!=I", "residual", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + if (!this->test_error(this->err_st, effective_ethr_band)) + break; + + const int psi_sz = this->n_basis * this->n_band_l; + const int eig_sz = this->n_band; + + t0 = LobpcgClock::now(); + this->calc_spsi_with_block(spsi_func, this->grad.data(), this->sgrad); + this->profile_log("S!=I", "grad_spsi", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + t0 = LobpcgClock::now(); + this->orth_projection_s(this->psi, this->spsi, this->tmp_hsub, + this->sgrad, this->grad); + this->profile_log("S!=I", "grad_projection", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, this->grad.data(), this->hgrad); + this->profile_log("S!=I", "grad_hpsi", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + + std::vector psi_bak(psi_sz), hpsi_bak(psi_sz), spsi_bak(psi_sz); + std::vector eigen_bak(eig_sz); + std::copy(this->psi.data(), this->psi.data() + psi_sz, psi_bak.data()); + std::copy(this->hpsi.data(), this->hpsi.data() + psi_sz, hpsi_bak.data()); + std::copy(this->spsi.data(), this->spsi.data() + psi_sz, spsi_bak.data()); + std::copy(this->eigen.data(), this->eigen.data() + eig_sz, eigen_bak.data()); + + try { + t0 = LobpcgClock::now(); + if (this->n_band_l != this->n_band) { + this->lobpcg_update_s_parallel(this->psi, this->hpsi, this->spsi, + this->grad, this->hgrad, this->sgrad, + this->pdir, this->hpdir, this->spdir, + this->eigen); + } else { + 
this->lobpcg_update_s(this->psi, this->hpsi, this->spsi, + this->grad, this->hgrad, this->sgrad, + this->pdir, this->hpdir, this->spdir, + this->eigen); + } + this->profile_log("S!=I", "lobpcg_update", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } catch (const std::exception& e1) { + this->diag_log("lobpcg_update_s failed: " + std::string(e1.what()), + "retry without previous search direction", + "iteration=" + std::to_string(used_iter)); + std::copy(psi_bak.data(), psi_bak.data() + psi_sz, this->psi.data()); + std::copy(hpsi_bak.data(), hpsi_bak.data() + psi_sz, this->hpsi.data()); + std::copy(spsi_bak.data(), spsi_bak.data() + psi_sz, this->spsi.data()); + std::copy(eigen_bak.data(), eigen_bak.data() + eig_sz, this->eigen.data()); + + setmem_complex_op()(this->pdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->spdir.data(), static_cast(0.0), psi_sz); + this->has_pdir = false; + + try { + t0 = LobpcgClock::now(); + if (this->n_band_l != this->n_band) { + this->lobpcg_update_s_parallel(this->psi, this->hpsi, this->spsi, + this->grad, this->hgrad, this->sgrad, + this->pdir, this->hpdir, this->spdir, + this->eigen); + } else { + this->lobpcg_update_s(this->psi, this->hpsi, this->spsi, + this->grad, this->hgrad, this->sgrad, + this->pdir, this->hpdir, this->spdir, + this->eigen); + } + this->profile_log("S!=I", "lobpcg_update_retry", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } catch (const std::exception& e2) { + this->diag_log("lobpcg_update_s retry failed: " + std::string(e2.what()), + "fallback to Rayleigh-Ritz repair", + "iteration=" + std::to_string(used_iter)); + std::copy(psi_bak.data(), psi_bak.data() + psi_sz, this->psi.data()); + std::copy(hpsi_bak.data(), hpsi_bak.data() + psi_sz, this->hpsi.data()); + std::copy(spsi_bak.data(), spsi_bak.data() + psi_sz, this->spsi.data()); + std::copy(eigen_bak.data(), 
eigen_bak.data() + eig_sz, this->eigen.data()); + + t0 = LobpcgClock::now(); + this->calc_hpsi_with_block(hpsi_func, this->psi.data(), this->hpsi); + this->calc_spsi_with_block(spsi_func, this->psi.data(), this->spsi); + if (this->n_band_l != this->n_band) { + this->generalized_rayleigh_ritz_parallel(this->psi, this->hpsi, this->spsi, this->eigen); + } else { + this->generalized_rayleigh_ritz(this->psi, this->hpsi, this->spsi, this->eigen); + } + this->profile_log("S!=I", "fallback_rr", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } + } + + const bool has_next_iteration = (ntry + 1) < max_iter; + const bool restart_next = has_next_iteration && scf_iter == 1 && ((ntry + 1) % this->nline == 0); + if (has_next_iteration && !restart_next) { + try { + t0 = LobpcgClock::now(); + if (this->n_band_l != this->n_band) { + this->orth_projection_s_with_h(this->psi, this->hpsi, this->spsi, + this->tmp_hsub, this->hpdir, + this->spdir, this->pdir); + } else { + this->calc_spsi_with_block(spsi_func, this->pdir.data(), this->spdir); + this->orth_projection_s(this->psi, this->spsi, this->tmp_hsub, + this->spdir, this->pdir); + this->calc_hpsi_with_block(hpsi_func, this->pdir.data(), this->hpdir); + } + this->profile_log("S!=I", "p_projection", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + } catch (const std::exception&) { + std::copy(psi_bak.data(), psi_bak.data() + psi_sz, this->psi.data()); + std::copy(hpsi_bak.data(), hpsi_bak.data() + psi_sz, this->hpsi.data()); + std::copy(spsi_bak.data(), spsi_bak.data() + psi_sz, this->spsi.data()); + std::copy(eigen_bak.data(), eigen_bak.data() + eig_sz, this->eigen.data()); + setmem_complex_op()(this->pdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->spdir.data(), static_cast(0.0), psi_sz); + this->has_pdir = false; + } + } + if (restart_next) { + setmem_complex_op()(this->pdir.data(), 
static_cast(0.0), psi_sz); + setmem_complex_op()(this->hpdir.data(), static_cast(0.0), psi_sz); + setmem_complex_op()(this->spdir.data(), static_cast(0.0), psi_sz); + this->has_pdir = false; + } + } + + t0 = LobpcgClock::now(); + this->compute_residual_s(this->psi, this->hpsi, this->spsi, this->eigen, + this->prec, this->grad, this->err_st); + this->profile_log("S!=I", "final_residual", used_iter, + std::chrono::duration(LobpcgClock::now() - t0).count()); + this->report_not_converged("S!=I", used_iter, max_iter, effective_ethr_band); + DiagoIterAssist::avg_iter += static_cast(used_iter); + + syncmem_var_d2h_op()(eigenvalue_in, + this->eigen.data() + this->local_band_start(), + this->n_band_l); +} + +template class DiagoLobpcg, base_device::DEVICE_CPU>; + +} // namespace hsolver diff --git a/source/source_hsolver/diago_lobpcg.h b/source/source_hsolver/diago_lobpcg.h new file mode 100644 index 00000000000..773ccb507ab --- /dev/null +++ b/source/source_hsolver/diago_lobpcg.h @@ -0,0 +1,318 @@ +#ifndef DIAGO_LOBPCG_H_ +#define DIAGO_LOBPCG_H_ + +#include +#include +#include +#include +#include + +#include "source_base/kernels/math_kernel_op.h" +#include "source_base/module_device/memory_op.h" +#include "source_base/module_device/types.h" +#include "source_base/para_gemm.h" +#include "source_hamilt/hamilt.h" +#include "source_hsolver/kernels/hegvd_op.h" +#include "source_hsolver/para_linear_transform.h" + +#include +#include +#include + +namespace hsolver { + +/** + * @class DiagoLobpcg + * @brief Locally Optimal Block Preconditioned Conjugate Gradient eigensolver. + * + * LOBPCG maintains a block Rayleigh-Ritz subspace: + * - First iteration: W = [X, Z] (2-block, no valid P yet) + * - Subsequent: W = [X, Z, P] (3-block) + * where X = current eigenvectors, Z = preconditioned residual, P = search directions. + * + * @note Currently supports CPU only. + * GPU support is planned for subsequent phases. + * + * @tparam T Complex floating-point type. 
+ * @tparam Device Must be base_device::DEVICE_CPU (GPU not yet supported). + */ +template , typename Device = base_device::DEVICE_CPU> +class DiagoLobpcg +{ + private: + static_assert(std::is_same::value, + "DiagoLobpcg currently supports CPU only."); + + using Real = typename GetTypeReal::type; + + public: + /// @brief H * psi -> hpsi. + using HPsiFunc = std::function; + + /// @brief S * psi -> spsi. + using SPsiFunc = std::function; + + /// Constructor — stores host preconditioner pointer. + explicit DiagoLobpcg(const Real* precondition); + + ~DiagoLobpcg(); + + /// Allocate workspace and bind h_prec TensorMap (after n_basis is known). + void init_iter(const int nband, const int nband_l, const int nbasis, const int ndim); + + /// Set max inner iterations per SCF step (default 4). + void set_nline(const int n) { this->nline = n; } + + /// Set hard iteration limit from pw_diag_nmax. Non-positive keeps legacy default. + void set_max_iter(const int n) { this->max_iter = n; } + + /// Set allowed unconverged bands after max_iter. Negative means report only. + void set_notconv_max(const int n) { this->notconv_max = n; } + + void set_diag_context(const std::string& context) { this->diag_context = context; } + + /// Generalized diagonalization. The standard problem is covered by S = I. 
+ void diag(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band); + + private: + // ---- dimensions ---- + int n_band = 0; ///< total bands (global) + int n_band_l = 0; ///< local bands + int n_basis = 0; ///< basis functions (lda of psi) + int n_dim = 0; ///< valid dimension (= current_ngk) + int nline = 4; ///< max inner iterations per SCF step + int max_iter = 0; ///< hard iteration limit; <=0 uses nline-based default + int notconv_max = -1; ///< allowed unconverged bands; <0 reports only + int nsub = 0; ///< physical leading dim of hsub (= 3*n_band) + bool has_pdir = false; ///< true when P block holds valid directions + std::string diag_context; + + // ---- parallel ops ---- + ModuleBase::PGemmCN pmmcn; + PLinearTransform plintrans; + + // ---- type traits ---- + ct::DataType r_type = ct::DataType::DT_INVALID; + ct::DataType t_type = ct::DataType::DT_INVALID; + ct::DeviceType dev_type = ct::DeviceType::UnKnown; + + // ---- preconditioner ---- + ct::Tensor prec = {}; ///< device copy [n_basis] + const Real* h_prec_ptr = nullptr; ///< host pointer (saved in ctor) + ct::Tensor h_prec = {}; ///< host TensorMap [n_basis] (bound in init_iter) + + // ---- eigenvalues & convergence ---- + ct::Tensor eigen = {}; ///< output eigenvalues [n_band] + ct::Tensor sub_eigen = {}; ///< subspace eigenvalues [nsub] = [3*n_band] + ct::Tensor err_st = {}; ///< residual norm per band [n_band_l] + + // ---- core blocks ---- + // Layout for {n_band_l, n_basis} tensors: + // data[ib * n_basis + ig] — band-major contiguous. + // BLAS view: n_basis rows × n_band_l cols, column-major. 
+ ct::Tensor psi = {}; ///< X (TensorMap → psi_in, no ownership) + ct::Tensor hpsi = {}; ///< HX + ct::Tensor spsi = {}; ///< SX + ct::Tensor grad = {}; ///< Z = T(R) + ct::Tensor hgrad = {}; ///< HZ + ct::Tensor sgrad = {}; ///< SZ + ct::Tensor pdir = {}; ///< P + ct::Tensor hpdir = {}; ///< HP + ct::Tensor spdir = {}; ///< SP + + // ---- subspace matrices ---- + ct::Tensor hsub = {}; ///< H_sub [nsub × nsub] + ct::Tensor ssub = {}; ///< S_sub [nsub × nsub] + + // ---- workspace ---- + ct::Tensor work = {}; ///< [n_band_l, n_basis] + ct::Tensor hwork = {}; ///< [n_band_l, n_basis] + ct::Tensor swork = {}; ///< [n_band_l, n_basis] + ct::Tensor pwork = {}; ///< [n_band_l, n_basis] (P update) + ct::Tensor hpwork = {}; ///< [n_band_l, n_basis] (HP update) + ct::Tensor spwork = {}; ///< [n_band_l, n_basis] (SP update) + ct::Tensor tmp_hsub = {}; ///< scratch [n_band, n_band] + ct::Tensor tmp_ssub = {}; ///< scratch [n_band, n_band] + + // ---- GEMM constants (following BPCG pattern) ---- + Device* ctx = {}; + const T one_ = static_cast(1.0); + const T zero_ = static_cast(0.0); + const T neg_one_ = static_cast(-1.0); + const T* one = nullptr; + const T* zero = nullptr; + const T* neg_one = nullptr; + + // ---- helpers ---- + + void calc_prec(); + + void calc_hpsi_with_block(const HPsiFunc& hpsi_func, + T* psi_in, + ct::Tensor& hpsi_out); + + void calc_spsi_with_block(const SPsiFunc& spsi_func, + const T* psi_in, + ct::Tensor& spsi_out); + + void repair_initial_subspace_s(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func); + + /// Standard R-R: H_sub = psi^H * hpsi → heevd → rotate. + /// multiply(hpsi, psi) = psi^H * hpsi. + void rayleigh_ritz(ct::Tensor& psi_inout, + ct::Tensor& hpsi_inout, + ct::Tensor& eigen_out); + + /// Generalized R-R. + void generalized_rayleigh_ritz(ct::Tensor& psi_inout, + ct::Tensor& hpsi_inout, + ct::Tensor& spsi_inout, + ct::Tensor& eigen_out); + + /// Distributed generalized R-R for band-parallel S != I. 
+ void generalized_rayleigh_ritz_parallel(ct::Tensor& psi_inout, + ct::Tensor& hpsi_inout, + ct::Tensor& spsi_inout, + ct::Tensor& eigen_out); + + /// NC residual: R = HX - X*Lambda, Z = R ./ prec. + /// CPU-only: direct loops. + void compute_residual(const ct::Tensor& psi_in, + const ct::Tensor& hpsi_in, + const ct::Tensor& eigen_in, + const ct::Tensor& prec_in, + ct::Tensor& grad_out, + ct::Tensor& err_out); + + /// Generalized residual. + void compute_residual_s(const ct::Tensor& psi_in, + const ct::Tensor& hpsi_in, + const ct::Tensor& spsi_in, + const ct::Tensor& eigen_in, + const ct::Tensor& prec_in, + ct::Tensor& grad_out, + ct::Tensor& err_out); + + /// grad -= psi * (psi^H * grad). multiply(grad, psi) = psi^H * grad. + void orth_projection(const ct::Tensor& psi_in, + ct::Tensor& hsub_work, + ct::Tensor& grad_out); + + /// S-orthogonalize. + void orth_projection_s(const ct::Tensor& psi_in, + const ct::Tensor& spsi_in, + ct::Tensor& hsub_work, + ct::Tensor& sgrad_out, + ct::Tensor& grad_out); + + void orth_projection_s_with_h(const ct::Tensor& psi_in, + const ct::Tensor& hpsi_in, + const ct::Tensor& spsi_in, + ct::Tensor& hsub_work, + ct::Tensor& hpdir_out, + ct::Tensor& spdir_out, + ct::Tensor& pdir_out); + + /// Core subspace update (2-block first, then 3-block). + /// Orthonormalizes W = [X, Z, P] before Rayleigh-Ritz for stability. + void lobpcg_update(ct::Tensor& psi, + ct::Tensor& hpsi, + ct::Tensor& grad, + ct::Tensor& hgrad, + ct::Tensor& pdir, + ct::Tensor& hpdir, + ct::Tensor& eigen); + + /// Generalized subspace update for S != I. + void lobpcg_update_s(ct::Tensor& psi, + ct::Tensor& hpsi, + ct::Tensor& spsi, + ct::Tensor& grad, + ct::Tensor& hgrad, + ct::Tensor& sgrad, + ct::Tensor& pdir, + ct::Tensor& hpdir, + ct::Tensor& spdir, + ct::Tensor& eigen); + + /// Distributed generalized subspace update for band-parallel S != I. 
+ void lobpcg_update_s_parallel(ct::Tensor& psi, + ct::Tensor& hpsi, + ct::Tensor& spsi, + ct::Tensor& grad, + ct::Tensor& hgrad, + ct::Tensor& sgrad, + ct::Tensor& pdir, + ct::Tensor& hpdir, + ct::Tensor& spdir, + ct::Tensor& eigen); + + /// psi = psi * U (via plintrans). + void rotate_wf(const ct::Tensor& hsub_in, + ct::Tensor& psi_out, + ct::Tensor& workspace_in); + + /// S=I Cholesky: psi^H*psi → potrf(U) → trtri → psi *= U^{-1}, hpsi *= U^{-1}. + void orth_cholesky(ct::Tensor& workspace_in, + ct::Tensor& psi_out, + ct::Tensor& hpsi_out, + ct::Tensor& hsub_out); + + /// S-Cholesky orthonormalization. + void orth_cholesky_s(ct::Tensor& workspace_in, + ct::Tensor& psi_out, + ct::Tensor& hpsi_out, + ct::Tensor& spsi_out, + ct::Tensor& hsub_out); + + bool test_error(const ct::Tensor& err_in, const std::vector& ethr_band); + + void validate_ethr_band(const std::vector& ethr_band) const; + + void diag_log(const std::string& context, + const std::string& line1, + const std::string& line2, + const std::string& line3 = std::string()) const; + + void report_not_converged(const char* problem_type, + const int used_iter, + const int max_iter, + const std::vector& ethr_band) const; + + bool profile_enabled() const; + void profile_log(const char* problem_type, + const char* stage, + const int iter, + const double seconds) const; + + int local_band_start() const; + + // ---- memory-op aliases ---- + using ct_Device = typename ct::PsiToContainer::type; + using setmem_var_op = ct::kernels::set_memory; + using resmem_var_op = ct::kernels::resize_memory; + using delmem_var_op = ct::kernels::delete_memory; + using syncmem_var_h2d_op = ct::kernels::synchronize_memory; + using syncmem_var_d2h_op = ct::kernels::synchronize_memory; + + using setmem_complex_op = ct::kernels::set_memory; + using delmem_complex_op = ct::kernels::delete_memory; + using resmem_complex_op = ct::kernels::resize_memory; + using syncmem_complex_op = ct::kernels::synchronize_memory; + using 
syncmem_complex_h2d_op = ct::kernels::synchronize_memory; + using syncmem_complex_d2h_op = ct::kernels::synchronize_memory; + + /// Internal standard-problem path retained as an implementation detail. + void diag(const HPsiFunc& hpsi_func, + T* psi_in, + Real* eigenvalue_in, + const std::vector& ethr_band); +}; + +} // namespace hsolver +#endif // DIAGO_LOBPCG_H_ diff --git a/source/source_hsolver/hsolver_pw.cpp b/source/source_hsolver/hsolver_pw.cpp index b88bc3b90dd..d3ad4771693 100644 --- a/source/source_hsolver/hsolver_pw.cpp +++ b/source/source_hsolver/hsolver_pw.cpp @@ -11,6 +11,7 @@ #include "source_hsolver/diago_cg.h" #include "source_hsolver/diago_dav_subspace.h" #include "source_hsolver/diago_david.h" +#include "source_hsolver/diago_lobpcg.h" #include "source_hsolver/diago_iter_assist.h" #include "source_io/module_parameter/parameter.h" #include "source_psi/psi.h" @@ -18,11 +19,64 @@ #include +#include +#include #include namespace hsolver { +template +typename std::enable_if>::value + && std::is_same::value, + void>::type +run_lobpcg_pw(const HPsiFunc& hpsi_func, + const SPsiFunc& spsi_func, + psi::Psi& psi, + const std::vector::type>& pre_condition, + typename GetTypeReal::type* eigenvalue, + const std::vector& ethr_band, + const int diag_iter_max, + const int notconv_max, + const int nk_nums) +{ + const int nband_l = psi.get_nbands(); + const int nbasis = psi.get_nbasis(); + const int ndim = psi.get_current_ngk(); + DiagoLobpcg lobpcg(pre_condition.data()); + lobpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim); + lobpcg.set_max_iter(diag_iter_max); + lobpcg.set_notconv_max(notconv_max); + std::ostringstream context; + context << "k=" << psi.get_current_k() + 1 << "/" << nk_nums + << ", npw=" << ndim + << ", npwx=" << nbasis + << ", nbands=" << PARAM.inp.nbands + << ", nbands_local=" << nband_l + << ", max_iter=" << diag_iter_max + << ", use_uspp=" << (PARAM.globalv.use_uspp ? 
1 : 0); + lobpcg.set_diag_context(context.str()); + lobpcg.diag(hpsi_func, spsi_func, psi.get_pointer(), eigenvalue, ethr_band); +} + +template +typename std::enable_if>::value + || !std::is_same::value, + void>::type +run_lobpcg_pw(const HPsiFunc&, + const SPsiFunc&, + psi::Psi&, + const std::vector::type>&, + typename GetTypeReal::type*, + const std::vector&, + const int, + const int, + const int) +{ + ModuleBase::WARNING_QUIT("HSolverPW", + "LOBPCG is currently implemented only for CPU complex PW calculations."); +} + template void HSolverPW::cal_smooth_ethr(const double& wk, const double* wg, @@ -83,7 +137,7 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, this->nproc_in_pool = nproc_in_pool_in; // report if the specified diagonalization method is not supported - const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; + const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg", "lobpcg"}; if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) { ModuleBase::WARNING_QUIT("HSolverPW::solve", "This type of eigensolver is not supported!"); @@ -273,8 +327,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, hpsi_info info(&psi_wrapper, bands_range, hpsi_out); hm->ops->hPsi(info); }; - auto spsi_func = [hm](const T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { - hm->sPsi(psi_in, spsi_out, ld_psi, ld_psi, nvec); + auto spsi_func = [hm, cur_nbasis](const T* psi_in, T* spsi_out, const int ld_psi, const int nvec) { + hm->sPsi(psi_in, spsi_out, ld_psi, cur_nbasis, nvec); }; if (this->method == "cg") @@ -323,6 +377,18 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, bpcg.init_iter(PARAM.inp.nbands, nband_l, nbasis, ndim); bpcg.diag(hpsi_func, psi.get_pointer(), eigenvalue, this->ethr_band); } + else if (this->method == "lobpcg") + { + run_lobpcg_pw(hpsi_func, + spsi_func, + psi, + pre_condition, + eigenvalue, + this->ethr_band, + this->diag_iter_max, + ("nscf" == 
this->calculation_type) ? 0 : -1, + nk_nums); + } else if (this->method == "dav_subspace") { bool scf = this->calculation_type == "nscf" ? false : true; diff --git a/source/source_hsolver/para_linear_transform.cpp b/source/source_hsolver/para_linear_transform.cpp index 1ddcdb78591..6da4ce74135 100644 --- a/source/source_hsolver/para_linear_transform.cpp +++ b/source/source_hsolver/para_linear_transform.cpp @@ -84,7 +84,6 @@ void PLinearTransform::act(const T alpha, const T* A, const T* U, con #ifdef __MPI if (nproc_col > 1) { - syncmem_dev_op()(B_tmp_, B, ncolB * LDA); std::vector requests(nproc_col); // Send for (int ip = 0; ip < nproc_col; ++ip) @@ -180,4 +179,4 @@ template struct PLinearTransform; template struct PLinearTransform, base_device::DEVICE_GPU>; template struct PLinearTransform, base_device::DEVICE_GPU>; #endif -} // namespace hsolver \ No newline at end of file +} // namespace hsolver diff --git a/source/source_hsolver/test/CMakeLists.txt b/source/source_hsolver/test/CMakeLists.txt index 1b1529adb4a..45436b0a26d 100644 --- a/source/source_hsolver/test/CMakeLists.txt +++ b/source/source_hsolver/test/CMakeLists.txt @@ -11,7 +11,15 @@ if (ENABLE_MPI) AddTest( TARGET MODULE_HSOLVER_bpcg LIBS parameter ${math_libs} base psi device container - SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp + SOURCES diago_bpcg_test.cpp ../diago_bpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp + ../../source_basis/module_pw/test/test_tool.cpp + ../../source_hamilt/operator.cpp + ../../source_pw/module_pwdft/op_pw.cpp + ) + AddTest( + TARGET MODULE_HSOLVER_lobpcg + LIBS parameter ${math_libs} base psi device container + SOURCES diago_lobpcg_test.cpp ../diago_lobpcg.cpp ../para_linear_transform.cpp ../diago_iter_assist.cpp ../../source_basis/module_pw/test/test_tool.cpp ../../source_hamilt/operator.cpp ../../source_pw/module_pwdft/op_pw.cpp @@ -76,14 +84,14 @@ if (ENABLE_MPI) AddTest( TARGET MODULE_HSOLVER_pw 
LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp + SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_lobpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ../../source_estate/elecstate_tools.cpp ../../source_estate/occupy.cpp ../../source_base/module_fft/fft_bundle.cpp ../../source_base/module_fft/fft_cpu.cpp ) AddTest( TARGET MODULE_HSOLVER_sdft LIBS parameter ${math_libs} psi device base container - SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp + SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_lobpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp ../para_linear_transform.cpp ../../source_estate/elecstate_tools.cpp ../../source_estate/occupy.cpp ../../source_base/module_fft/fft_bundle.cpp ../../source_base/module_fft/fft_cpu.cpp ) @@ -137,6 +145,7 @@ install(FILES KPoints-Si64-Solution.dat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_cg_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_david_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) +install(FILES diago_lobpcg_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES diago_lcao_parallel_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) install(FILES PEXSI-H-GammaOnly-Si2.dat DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) @@ -185,6 +194,10 @@ if (ENABLE_MPI) COMMAND ${BASH} diago_david_parallel_test.sh WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) + add_test(NAME MODULE_HSOLVER_lobpcg_parallel + COMMAND ${BASH} 
${CMAKE_CURRENT_SOURCE_DIR}/diago_lobpcg_parallel_test.sh + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + ) if(ENABLE_LCAO) add_test(NAME MODULE_HSOLVER_LCAO_parallel COMMAND ${BASH} diago_lcao_parallel_test.sh @@ -197,4 +210,4 @@ if (ENABLE_MPI) ) endif() endif() -endif() \ No newline at end of file +endif() diff --git a/source/source_hsolver/test/diago_lobpcg_parallel_test.sh b/source/source_hsolver/test/diago_lobpcg_parallel_test.sh new file mode 100644 index 00000000000..2b74d0a3bbe --- /dev/null +++ b/source/source_hsolver/test/diago_lobpcg_parallel_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +np=`cat /proc/cpuinfo | grep "cpu cores" | uniq | awk '{print $NF}'` +echo "nprocs in this machine is $np" + +if [[ 2 -gt $np ]]; then + echo "skip LOBPCG band-parallel UT: fewer than 2 cpu cores" + exit 0 +fi + +echo "TEST DIAGO LOBPCG in band parallel, nprocs=2" +ABACUS_LOBPCG_TEST_BNDPAR=1 OMP_NUM_THREADS=1 mpirun -np 2 ./MODULE_HSOLVER_lobpcg \ + --gtest_filter=DiagoLobpcgTest.GeneralizedBandParallelRankCompressedSubspace:DiagoLobpcgTest.BandParallelReusesProjectedSearchDirectionProducts +e1=$? +if [[ e1 -ne 0 ]]; then + echo -e "\e[1;33m [ FAILED ] \e[0m execute LOBPCG band-parallel UT with 2 cores error." 
+ exit 1 +fi diff --git a/source/source_hsolver/test/diago_lobpcg_test.cpp b/source/source_hsolver/test/diago_lobpcg_test.cpp new file mode 100644 index 00000000000..ebc2067e001 --- /dev/null +++ b/source/source_hsolver/test/diago_lobpcg_test.cpp @@ -0,0 +1,962 @@ +#include "../diago_iter_assist.h" +#include "../diago_lobpcg.h" +#include "source_base/global_variable.h" +#include "source_base/module_external/lapack_connector.h" +#include "source_base/parallel_comm.h" +#include "source_basis/module_pw/test/test_tool.h" + +#ifdef __MPI +#include "mpi.h" +#endif + +#include +#include +#include +#include +#include +#include +#include + +/************************************************ + * unit test of DiagoLobpcg (NC, CPU-only) + * + * Validates eigenvalues, orthonormality, and + * residual against LAPACK zheev for random + * well-conditioned Hermitian matrices. + ***********************************************/ + +using TestT = std::complex; +using TestDevice = base_device::DEVICE_CPU; +using TestReal = double; + +/// Reference eigenvalues via LAPACK zheev (eigenvalues only). +static int lapackEigen(int npw, std::vector& hm, TestReal* e) +{ + int lwork = std::max(1, 2 * npw - 1); + std::vector work(lwork); + std::vector rwork(std::max(1, 3 * npw - 2)); + int info = 0; + char jobz = 'N', uplo = 'U'; + zheev_(&jobz, &uplo, &npw, hm.data(), &npw, e, + work.data(), &lwork, rwork.data(), &info); + return info; +} + +/// Reference generalized eigenvalues via LAPACK zhegvd (eigenvalues only). 
+static int lapackGeneralizedEigen(int npw, + std::vector& hm, + std::vector& sm, + TestReal* e) +{ + int info = 0; + int itype = 1; + char jobz = 'N', uplo = 'U'; + int lwork = -1, lrwork = -1, liwork = -1; + TestT work_query = {0.0, 0.0}; + TestReal rwork_query = 0.0; + int iwork_query = 0; + zhegvd_(&itype, &jobz, &uplo, &npw, hm.data(), &npw, sm.data(), &npw, e, + &work_query, &lwork, &rwork_query, &lrwork, &iwork_query, &liwork, &info); + if (info != 0) + return info; + + lwork = std::max(1, static_cast(std::real(work_query))); + lrwork = std::max(1, static_cast(rwork_query)); + liwork = std::max(1, iwork_query); + std::vector work(lwork); + std::vector rwork(lrwork); + std::vector iwork(liwork); + zhegvd_(&itype, &jobz, &uplo, &npw, hm.data(), &npw, sm.data(), &npw, e, + work.data(), &lwork, rwork.data(), &lrwork, iwork.data(), &liwork, &info); + return info; +} + +class DiagoLobpcgTest : public ::testing::Test +{ + protected: + struct HegvdMetrics + { + int info = 0; + TestReal max_rel_residual = 0.0; + TestReal max_s_orth_error = 0.0; + TestReal max_abs_coeff = 0.0; + }; + + static int idx(int row, int col, int ld) { return col * ld + row; } + + static HegvdMetrics solve_generalized_eigenvectors( + int dim, + std::vector hmat, + std::vector smat) + { + const auto h_orig = hmat; + const auto s_orig = smat; + std::vector eval(dim, 0.0); + + HegvdMetrics metrics; + int itype = 1; + char jobz = 'V', uplo = 'U'; + int lwork = -1, lrwork = -1, liwork = -1; + TestT work_query = {0.0, 0.0}; + TestReal rwork_query = 0.0; + int iwork_query = 0; + zhegvd_(&itype, &jobz, &uplo, &dim, hmat.data(), &dim, smat.data(), &dim, + eval.data(), &work_query, &lwork, &rwork_query, &lrwork, + &iwork_query, &liwork, &metrics.info); + if (metrics.info != 0) + return metrics; + + lwork = std::max(1, static_cast(std::real(work_query))); + lrwork = std::max(1, static_cast(rwork_query)); + liwork = std::max(1, iwork_query); + std::vector work(lwork); + std::vector rwork(lrwork); + 
std::vector iwork(liwork); + zhegvd_(&itype, &jobz, &uplo, &dim, hmat.data(), &dim, smat.data(), &dim, + eval.data(), work.data(), &lwork, rwork.data(), &lrwork, + iwork.data(), &liwork, &metrics.info); + if (metrics.info != 0) + return metrics; + + std::vector av(dim), bv(dim); + for (int iv = 0; iv < dim; ++iv) + { + const TestT* vec = hmat.data() + iv * dim; + for (int i = 0; i < dim; ++i) + { + metrics.max_abs_coeff = std::max( + metrics.max_abs_coeff, + static_cast(std::abs(vec[i]))); + + TestT ah = {0.0, 0.0}; + TestT bs = {0.0, 0.0}; + for (int j = 0; j < dim; ++j) + { + ah += h_orig[idx(i, j, dim)] * vec[j]; + bs += s_orig[idx(i, j, dim)] * vec[j]; + } + av[i] = ah; + bv[i] = bs; + } + + TestReal res2 = 0.0; + TestReal av2 = 0.0; + TestReal bv2 = 0.0; + for (int i = 0; i < dim; ++i) + { + res2 += std::norm(av[i] - eval[iv] * bv[i]); + av2 += std::norm(av[i]); + bv2 += std::norm(bv[i]); + } + const TestReal denom = std::max( + static_cast(1.0), + std::sqrt(av2) + std::abs(eval[iv]) * std::sqrt(bv2)); + metrics.max_rel_residual = std::max(metrics.max_rel_residual, + std::sqrt(res2) / denom); + + for (int jv = 0; jv < dim; ++jv) + { + const TestT* vec_j = hmat.data() + jv * dim; + TestT dot = {0.0, 0.0}; + for (int i = 0; i < dim; ++i) + { + TestT bvi = {0.0, 0.0}; + for (int j = 0; j < dim; ++j) + bvi += s_orig[idx(i, j, dim)] * vec_j[j]; + dot += std::conj(vec[i]) * bvi; + } + const TestT target = (iv == jv) ? 
TestT(1.0, 0.0) : TestT(0.0, 0.0); + metrics.max_s_orth_error = std::max( + metrics.max_s_orth_error, + static_cast(std::abs(dot - target))); + } + } + return metrics; + } + + static void build_nearly_dependent_overlap_problem( + int dim, + TestReal delta, + std::vector& hmat, + std::vector& smat, + std::vector& prec, + std::vector& e_ref) + { + build_well_conditioned(dim, hmat, prec, e_ref); + smat.assign(dim * dim, {0.0, 0.0}); + for (int i = 0; i < dim; ++i) + smat[idx(i, i, dim)] = {1.0, 0.0}; + smat[idx(0, 1, dim)] = {1.0 - delta, 0.0}; + smat[idx(1, 0, dim)] = {1.0 - delta, 0.0}; + + auto hcopy = hmat; + auto scopy = smat; + ASSERT_EQ(lapackGeneralizedEigen(dim, hcopy, scopy, e_ref.data()), 0); + } + + void run_and_validate(int npw, int nband, + const std::vector& hmat, + const std::vector& prec, + const std::vector& e_ref, + double eig_tol, double orth_tol, double res_tol) + { + const int ld = npw; + + // ---- random orthonormal initial guess ---- + std::vector psi(nband * npw, {0.0, 0.0}); + { + std::mt19937 gen(123); + std::uniform_real_distribution dist(-1.0, 1.0); + for (int ib = 0; ib < nband; ib++) + for (int ig = 0; ig < npw; ig++) + psi[ib * npw + ig] = {dist(gen), dist(gen)}; + + for (int ib = 0; ib < nband; ib++) + { + for (int jb = 0; jb < ib; jb++) + { + TestT dot = {0.0, 0.0}; + for (int ig = 0; ig < npw; ig++) + dot += std::conj(psi[jb * npw + ig]) * psi[ib * npw + ig]; + for (int ig = 0; ig < npw; ig++) + psi[ib * npw + ig] -= dot * psi[jb * npw + ig]; + } + TestReal norm = 0.0; + for (int ig = 0; ig < npw; ig++) + norm += std::norm(psi[ib * npw + ig]); + norm = std::sqrt(norm); + for (int ig = 0; ig < npw; ig++) + psi[ib * npw + ig] /= norm; + } + } + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_psi, int nvec) { + for (int iv = 0; iv < nvec; iv++) + for (int i = 0; i < npw; i++) + { + TestT sum = {0.0, 0.0}; + for (int j = 0; j < npw; j++) + sum += hmat[idx(i, j, ld)] * psi_in[iv * ld_psi + j]; + hpsi_out[iv * ld_psi + i] = 
sum; + } + }; + auto spsi_func = [&](const TestT* psi_in, TestT* spsi_out, + int ld_psi, int nvec) { + for (int iv = 0; iv < nvec; iv++) + { + for (int i = 0; i < npw; i++) + spsi_out[iv * ld_psi + i] = psi_in[iv * ld_psi + i]; + for (int i = npw; i < ld_psi; i++) + spsi_out[iv * ld_psi + i] = {0.0, 0.0}; + } + }; + + // ---- run LOBPCG ---- + std::vector eigens(nband, 0.0); + std::vector ethr(nband, 1e-6); + + // SCF_ITER = 1 triggers periodic R-R restart that clears P, which + // naturally limits subspace ill-conditioning. Use moderate nline. + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + hsolver::DiagoIterAssist::SCF_ITER = 1; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, npw, npw); + lobpcg.set_nline(4); + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + + // ---- validate eigenvalues ---- + for (int ib = 0; ib < nband; ib++) + ASSERT_NEAR(eigens[ib], e_ref[ib], eig_tol) + << "eigenvalue[" << ib << "] mismatch: " + << eigens[ib] << " vs ref " << e_ref[ib]; + + // ---- validate orthonormality ---- + for (int i = 0; i < nband; i++) + for (int j = 0; j < nband; j++) + { + TestT dot = {0.0, 0.0}; + for (int ig = 0; ig < npw; ig++) + dot += std::conj(psi[i * npw + ig]) * psi[j * npw + ig]; + if (i == j) + { + EXPECT_NEAR(std::real(dot), 1.0, orth_tol) + << "psi^H psi diag[" << i << "] = " << std::real(dot); + EXPECT_NEAR(std::imag(dot), 0.0, orth_tol) + << "psi^H psi diag[" << i << "] imag = " << std::imag(dot); + } + else + EXPECT_NEAR(std::abs(dot), 0.0, orth_tol) + << "psi^H psi (" << i << "," << j << ") = " << std::abs(dot); + } + + // ---- validate residual: ||H*psi_i - eig_i * psi_i|| ---- + for (int ib = 0; ib < nband; ib++) + { + TestReal res2 = 0.0; + for (int i = 0; i < npw; i++) + { + TestT hxi = {0.0, 0.0}; + for (int j = 0; j < npw; j++) + hxi += hmat[idx(i, j, ld)] * psi[ib * npw + j]; + const auto r = hxi - eigens[ib] * psi[ib * npw + 
i]; + res2 += std::norm(r); + } + EXPECT_LT(std::sqrt(res2), res_tol) + << "residual[" << ib << "] = " << std::sqrt(res2); + } + } + + /// Build a strongly diagonally-dominant random Hermitian matrix. + static void build_well_conditioned(int npw, std::vector& hmat, + std::vector& prec, + std::vector& e_ref) + { + const int ld = npw; + hmat.assign(npw * npw, {0.0, 0.0}); + + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0, 1.0); + + for (int i = 0; i < npw; i++) + { + for (int j = i; j < npw; j++) + { + TestReal re = dist(gen) * 0.5; + TestReal im = (i != j) ? dist(gen) * 0.5 : 0.0; + hmat[idx(i, j, ld)] = {re, im}; + hmat[idx(j, i, ld)] = {re, -im}; + } + hmat[idx(i, i, ld)] += TestT( + static_cast(2.0 * (i + 1) * (i + 1)), 0.0); + } + + // Reference + auto hcopy = hmat; + e_ref.resize(npw); + ASSERT_EQ(lapackEigen(npw, hcopy, e_ref.data()), 0); + + // Preconditioner + prec.resize(npw); + for (int i = 0; i < npw; i++) + prec[i] = std::max(static_cast(1.0), + std::real(hmat[idx(i, i, ld)])); + } + + static void matvec(const std::vector& mat, + const TestT* psi_in, + TestT* out, + int npw, + int ld_psi, + int nvec) + { + const int ld = npw; + for (int iv = 0; iv < nvec; iv++) + { + for (int i = 0; i < npw; i++) + { + TestT sum = {0.0, 0.0}; + for (int j = 0; j < npw; j++) + sum += mat[idx(i, j, ld)] * psi_in[iv * ld_psi + j]; + out[iv * ld_psi + i] = sum; + } + for (int i = npw; i < ld_psi; i++) + out[iv * ld_psi + i] = {123.0, -456.0}; + } + } + + static void build_generalized_problem(int npw, + TestReal overlap_diag, + TestReal overlap_scale, + int seed, + std::vector& hmat, + std::vector& smat, + std::vector& prec, + std::vector& e_ref) + { + build_well_conditioned(npw, hmat, prec, e_ref); + smat.assign(npw * npw, {0.0, 0.0}); + + std::mt19937 gen(seed); + std::uniform_real_distribution dist(-1.0, 1.0); + for (int i = 0; i < npw; i++) + { + for (int j = i; j < npw; j++) + { + TestReal re = dist(gen) * overlap_scale; + TestReal im = (i != j) ? 
dist(gen) * overlap_scale : 0.0; + smat[idx(i, j, npw)] = {re, im}; + smat[idx(j, i, npw)] = {re, -im}; + } + smat[idx(i, i, npw)] += TestT(overlap_diag, 0.0); + } + + auto hcopy = hmat; + auto scopy = smat; + e_ref.resize(npw); + ASSERT_EQ(lapackGeneralizedEigen(npw, hcopy, scopy, e_ref.data()), 0); + } + + void run_generalized_and_validate(int npw, + int nband, + int ld_psi, + TestReal overlap_diag, + TestReal overlap_scale, + int seed, + TestReal min_s_minus_identity, + double eig_tol, + double orth_tol, + double res_tol) + { + std::vector hmat, smat; + std::vector prec, e_ref; + build_generalized_problem(npw, overlap_diag, overlap_scale, seed, + hmat, smat, prec, e_ref); + + TestReal max_s_minus_identity = 0.0; + for (int j = 0; j < npw; j++) + for (int i = 0; i < npw; i++) + { + const TestT identity = (i == j) ? TestT(1.0, 0.0) : TestT(0.0, 0.0); + max_s_minus_identity = std::max( + max_s_minus_identity, + static_cast(std::abs(smat[idx(i, j, npw)] - identity))); + } + ASSERT_GT(max_s_minus_identity, min_s_minus_identity); + + std::vector psi(nband * ld_psi, {0.0, 0.0}); + for (int ib = 0; ib < nband; ib++) + { + psi[ib * ld_psi + ib] = {1.0, 0.0}; + for (int ig = npw; ig < ld_psi; ig++) + psi[ib * ld_psi + ig] = {99.0, -77.0}; + } + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_in, int nvec) { + matvec(hmat, psi_in, hpsi_out, npw, ld_in, nvec); + }; + auto spsi_func = [&](const TestT* psi_in, TestT* spsi_out, + int ld_in, int nvec) { + matvec(smat, psi_in, spsi_out, npw, ld_in, nvec); + }; + + std::vector eigens(nband, 0.0); + std::vector ethr(nband, res_tol); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + hsolver::DiagoIterAssist::SCF_ITER = 1; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, ld_psi, npw); + lobpcg.set_nline(10); + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + + for (int ib = 0; ib < nband; ib++) + 
ASSERT_NEAR(eigens[ib], e_ref[ib], eig_tol); + + std::vector hpsi(nband * ld_psi), spsi(nband * ld_psi); + matvec(hmat, psi.data(), hpsi.data(), npw, ld_psi, nband); + matvec(smat, psi.data(), spsi.data(), npw, ld_psi, nband); + for (int i = 0; i < nband; i++) + { + for (int ig = npw; ig < ld_psi; ig++) + EXPECT_EQ(psi[i * ld_psi + ig], TestT(0.0, 0.0)); + + for (int j = 0; j < nband; j++) + { + TestT dot = {0.0, 0.0}; + for (int ig = 0; ig < npw; ig++) + dot += std::conj(psi[i * ld_psi + ig]) * spsi[j * ld_psi + ig]; + EXPECT_NEAR( + std::abs(dot - (i == j ? TestT(1.0, 0.0) : TestT(0.0, 0.0))), + 0.0, orth_tol); + } + + TestReal res2 = 0.0; + for (int ig = 0; ig < npw; ig++) + res2 += std::norm(hpsi[i * ld_psi + ig] - eigens[i] * spsi[i * ld_psi + ig]); + EXPECT_LT(std::sqrt(res2), res_tol); + } + } + +#ifdef __MPI + void run_generalized_band_parallel_and_validate() + { + int nproc = 1; + int rank = 0; + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + if (nproc < 2) + GTEST_SKIP() << "band-parallel LOBPCG test requires at least 2 MPI ranks"; + + const int npw = 18; + const int nband = 10; + const int ld_psi = npw + 3; + const int nband_l = nband / nproc + (rank < nband % nproc ? 
1 : 0); + const int band_start = nband / nproc * rank + std::min(rank, nband % nproc); + + std::vector hmat, smat; + std::vector prec, e_ref; + build_generalized_problem(npw, 1.7, 0.03, 101, hmat, smat, prec, e_ref); + + std::vector psi(nband_l * ld_psi, {0.0, 0.0}); + for (int ib = 0; ib < nband_l; ++ib) + { + const int global_band = band_start + ib; + psi[ib * ld_psi + global_band] = {1.0, 0.0}; + for (int ig = npw; ig < ld_psi; ++ig) + psi[ib * ld_psi + ig] = {99.0, -77.0}; + } + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_in, int nvec) { + matvec(hmat, psi_in, hpsi_out, npw, ld_in, nvec); + }; + auto spsi_func = [&](const TestT* psi_in, TestT* spsi_out, + int ld_in, int nvec) { + matvec(smat, psi_in, spsi_out, npw, ld_in, nvec); + }; + + std::vector eigens(nband_l, 0.0); + std::vector ethr(nband_l, 1e-8); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + hsolver::DiagoIterAssist::SCF_ITER = 1; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband_l, ld_psi, npw); + lobpcg.set_nline(10); + lobpcg.set_max_iter(80); + lobpcg.set_diag_context("parallel-unit-rank-compression"); + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + + for (int ib = 0; ib < nband_l; ++ib) + { + const int global_band = band_start + ib; + ASSERT_NEAR(eigens[ib], e_ref[global_band], 2e-5) + << "global_band=" << global_band; + } + + std::vector hpsi(nband_l * ld_psi), spsi(nband_l * ld_psi); + matvec(hmat, psi.data(), hpsi.data(), npw, ld_psi, nband_l); + matvec(smat, psi.data(), spsi.data(), npw, ld_psi, nband_l); + + for (int ib = 0; ib < nband_l; ++ib) + { + const int global_band = band_start + ib; + for (int ig = npw; ig < ld_psi; ++ig) + EXPECT_EQ(psi[ib * ld_psi + ig], TestT(0.0, 0.0)); + + TestReal res2 = 0.0; + for (int ig = 0; ig < npw; ++ig) + res2 += std::norm(hpsi[ib * ld_psi + ig] + - eigens[ib] * spsi[ib * ld_psi + ig]); + EXPECT_LT(std::sqrt(res2), 2e-4) 
+ << "global_band=" << global_band; + } + + std::vector global_psi(nband * ld_psi, {0.0, 0.0}); + std::vector counts(nproc, 0), displs(nproc, 0); + for (int ip = 0; ip < nproc; ++ip) + { + const int nlocal = nband / nproc + (ip < nband % nproc ? 1 : 0); + counts[ip] = nlocal * ld_psi; + displs[ip] = (nband / nproc * ip + std::min(ip, nband % nproc)) * ld_psi; + } + MPI_Allgatherv(psi.data(), nband_l * ld_psi, MPI_DOUBLE_COMPLEX, + global_psi.data(), counts.data(), displs.data(), + MPI_DOUBLE_COMPLEX, MPI_COMM_WORLD); + + std::vector global_spsi(nband * ld_psi, {0.0, 0.0}); + matvec(smat, global_psi.data(), global_spsi.data(), npw, ld_psi, nband); + for (int i = 0; i < nband; ++i) + { + for (int j = 0; j < nband; ++j) + { + TestT dot = {0.0, 0.0}; + for (int ig = 0; ig < npw; ++ig) + dot += std::conj(global_psi[i * ld_psi + ig]) + * global_spsi[j * ld_psi + ig]; + const TestT target = (i == j) ? TestT(1.0, 0.0) : TestT(0.0, 0.0); + EXPECT_NEAR(std::abs(dot - target), 0.0, 2e-6) + << "S-orth(" << i << "," << j << ")"; + } + } + } + + void run_generalized_band_parallel_operator_count() + { + int nproc = 1; + int rank = 0; + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + if (nproc < 2) + GTEST_SKIP() << "band-parallel LOBPCG test requires at least 2 MPI ranks"; + + const int npw = 18; + const int nband = 10; + const int ld_psi = npw + 3; + const int nband_l = nband / nproc + (rank < nband % nproc ? 
1 : 0); + const int band_start = nband / nproc * rank + std::min(rank, nband % nproc); + + std::vector hmat, smat; + std::vector prec, e_ref; + build_generalized_problem(npw, 1.7, 0.03, 103, hmat, smat, prec, e_ref); + + std::vector psi(nband_l * ld_psi, {0.0, 0.0}); + for (int ib = 0; ib < nband_l; ++ib) + { + const int global_band = band_start + ib; + psi[ib * ld_psi + global_band] = {1.0, 0.0}; + } + + int hpsi_calls = 0; + int spsi_calls = 0; + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_in, int nvec) { + ++hpsi_calls; + matvec(hmat, psi_in, hpsi_out, npw, ld_in, nvec); + }; + auto spsi_func = [&](const TestT* psi_in, TestT* spsi_out, + int ld_in, int nvec) { + ++spsi_calls; + matvec(smat, psi_in, spsi_out, npw, ld_in, nvec); + }; + + std::vector eigens(nband_l, 0.0); + std::vector ethr(nband_l, 0.0); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + hsolver::DiagoIterAssist::SCF_ITER = 1; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband_l, ld_psi, npw); + lobpcg.set_max_iter(2); + lobpcg.set_diag_context("parallel-unit-operator-count"); + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + + EXPECT_EQ(hpsi_calls, 2); + EXPECT_EQ(spsi_calls, 2); + } + +#endif +}; + +// ============================================================================ +// Test cases: various matrix sizes and band counts +// ============================================================================ + +TEST_F(DiagoLobpcgTest, SmallMatrix) +{ + const int npw = 50, nband = 10; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + run_and_validate(npw, nband, hmat, prec, e_ref, 1e-5, 1e-8, 1e-4); +} + +TEST_F(DiagoLobpcgTest, MediumMatrix) +{ + const int npw = 200, nband = 20; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + run_and_validate(npw, nband, hmat, prec, e_ref, 
1e-4, 1e-6, 2e-4); +} + +TEST_F(DiagoLobpcgTest, LargerMatrixFewBands) +{ + const int npw = 400, nband = 12; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + run_and_validate(npw, nband, hmat, prec, e_ref, 1e-4, 1e-6, 3e-4); +} + +TEST_F(DiagoLobpcgTest, ReportsUnconvergedAtMaxIter) +{ + const int npw = 50, nband = 10; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + + std::vector psi(nband * npw, {0.0, 0.0}); + for (int ib = 0; ib < nband; ++ib) + psi[ib * npw + ib] = {1.0, 0.0}; + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_psi, int nvec) { + for (int iv = 0; iv < nvec; ++iv) + for (int i = 0; i < npw; ++i) + { + TestT sum = {0.0, 0.0}; + for (int j = 0; j < npw; ++j) + sum += hmat[idx(i, j, npw)] * psi_in[iv * ld_psi + j]; + hpsi_out[iv * ld_psi + i] = sum; + } + }; + auto spsi_func = [](const TestT* psi_in, TestT* spsi_out, + int ld_psi, int nvec) { + std::copy(psi_in, psi_in + nvec * ld_psi, spsi_out); + }; + + std::vector eigens(nband, 0.0); + std::vector ethr(nband, 0.0); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + const auto old_avg_iter = hsolver::DiagoIterAssist::avg_iter; + hsolver::DiagoIterAssist::SCF_ITER = 1; + hsolver::DiagoIterAssist::avg_iter = 0.0; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, npw, npw); + lobpcg.set_max_iter(1); + lobpcg.set_diag_context("k=2/4, npw=50, nbands=10"); + + testing::internal::CaptureStdout(); + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + const std::string output = testing::internal::GetCapturedStdout(); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + + EXPECT_NE(output.find("DiagoLobpcg::diag(S=I)"), std::string::npos); + EXPECT_NE(output.find("max_iter=1"), std::string::npos); + EXPECT_NE(output.find("notconv="), std::string::npos); + EXPECT_NE(output.find("context={k=2/4, npw=50, nbands=10}"), std::string::npos); + 
const auto avg_iter = hsolver::DiagoIterAssist::avg_iter; + EXPECT_EQ(avg_iter, 1.0); + hsolver::DiagoIterAssist::avg_iter = old_avg_iter; +} + +TEST_F(DiagoLobpcgTest, ThrowsWhenNotconvExceedsLimit) +{ + const int npw = 50, nband = 10; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + + std::vector psi(nband * npw, {0.0, 0.0}); + for (int ib = 0; ib < nband; ++ib) + psi[ib * npw + ib] = {1.0, 0.0}; + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_psi, int nvec) { + matvec(hmat, psi_in, hpsi_out, npw, ld_psi, nvec); + }; + auto spsi_func = [](const TestT* psi_in, TestT* spsi_out, + int ld_psi, int nvec) { + std::copy(psi_in, psi_in + nvec * ld_psi, spsi_out); + }; + + std::vector eigens(nband, 0.0); + std::vector ethr(nband, 0.0); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + const auto old_avg_iter = hsolver::DiagoIterAssist::avg_iter; + hsolver::DiagoIterAssist::SCF_ITER = 1; + hsolver::DiagoIterAssist::avg_iter = 0.0; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, npw, npw); + lobpcg.set_max_iter(1); + lobpcg.set_notconv_max(0); + lobpcg.set_diag_context("k=3/4, npw=50, nbands=10"); + + testing::internal::CaptureStdout(); + try { + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + FAIL() << "Expected runtime_error when notconv exceeds limit"; + } catch (const std::runtime_error& e) { + const std::string output = testing::internal::GetCapturedStdout(); + const std::string msg = e.what(); + EXPECT_NE(msg.find("notconv="), std::string::npos); + EXPECT_NE(msg.find("context={k=3/4, npw=50, nbands=10}"), std::string::npos); + EXPECT_NE(output.find("notconv="), std::string::npos); + } + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; + hsolver::DiagoIterAssist::avg_iter = old_avg_iter; +} + +TEST_F(DiagoLobpcgTest, RejectsNonLocalEthrBandSize) +{ + const int npw = 20, nband = 6; + std::vector hmat; + std::vector prec, e_ref; + 
build_well_conditioned(npw, hmat, prec, e_ref); + + std::vector psi(nband * npw, {0.0, 0.0}); + for (int ib = 0; ib < nband; ++ib) + psi[ib * npw + ib] = {1.0, 0.0}; + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_psi, int nvec) { + matvec(hmat, psi_in, hpsi_out, npw, ld_psi, nvec); + }; + auto spsi_func = [](const TestT* psi_in, TestT* spsi_out, + int ld_psi, int nvec) { + std::copy(psi_in, psi_in + nvec * ld_psi, spsi_out); + }; + + std::vector eigens(nband, 0.0); + std::vector ethr(nband + 1, 1e-6); + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, npw, npw); + lobpcg.set_diag_context("k=1/1, npw=20, nbands=6"); + + try { + lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr); + FAIL() << "Expected invalid_argument for non-local ethr_band size"; + } catch (const std::invalid_argument& e) { + const std::string msg = e.what(); + EXPECT_NE(msg.find("local ethr_band size mismatch"), std::string::npos); + EXPECT_NE(msg.find("size=7"), std::string::npos); + EXPECT_NE(msg.find("required local bands=6"), std::string::npos); + EXPECT_NE(msg.find("global bands=6"), std::string::npos); + EXPECT_NE(msg.find("context={k=1/1, npw=20, nbands=6}"), std::string::npos); + } +} + +TEST_F(DiagoLobpcgTest, GeneralizedNonPositiveOverlapFails) +{ + const int npw = 20, nband = 6; + std::vector hmat; + std::vector prec, e_ref; + build_well_conditioned(npw, hmat, prec, e_ref); + + std::vector smat(npw * npw, {0.0, 0.0}); + for (int i = 0; i < npw; ++i) + smat[idx(i, i, npw)] = {-1.0, 0.0}; + + std::vector psi(nband * npw, {0.0, 0.0}); + for (int ib = 0; ib < nband; ++ib) + psi[ib * npw + ib] = {1.0, 0.0}; + + auto hpsi_func = [&](TestT* psi_in, TestT* hpsi_out, + int ld_psi, int nvec) { + matvec(hmat, psi_in, hpsi_out, npw, ld_psi, nvec); + }; + auto spsi_func = [&](const TestT* psi_in, TestT* spsi_out, + int ld_psi, int nvec) { + matvec(smat, psi_in, spsi_out, npw, ld_psi, nvec); + }; + + std::vector eigens(nband, 0.0); + 
std::vector ethr(nband, 1e-6); + const int old_scf = hsolver::DiagoIterAssist::SCF_ITER; + hsolver::DiagoIterAssist::SCF_ITER = 1; + + hsolver::DiagoLobpcg lobpcg(prec.data()); + lobpcg.init_iter(nband, nband, npw, npw); + EXPECT_THROW(lobpcg.diag(hpsi_func, spsi_func, psi.data(), eigens.data(), ethr), + std::runtime_error); + + hsolver::DiagoIterAssist::SCF_ITER = old_scf; +} + +TEST_F(DiagoLobpcgTest, GeneralizedUsppLikeOverlap) +{ + const int npw = 40, nband = 8; + run_generalized_and_validate(npw, nband, npw + 7, + 1.5, 0.02, 7, 0.1, + 1e-5, 1e-7, 1e-4); +} + +TEST_F(DiagoLobpcgTest, GeneralizedNearlyIdentityOverlap) +{ + const int npw = 60, nband = 10; + run_generalized_and_validate(npw, nband, npw, + 1.05, 0.005, 17, 0.01, + 1e-5, 1e-7, 1e-4); +} + +TEST_F(DiagoLobpcgTest, GeneralizedModerateCouplingOverlap) +{ + const int npw = 80, nband = 12; + run_generalized_and_validate(npw, nband, npw + 5, + 2.0, 0.05, 29, 0.5, + 2e-5, 2e-7, 2e-4); +} + +TEST_F(DiagoLobpcgTest, HegvdIsAccurateForWellConditionedOverlap) +{ + const int npw = 40; + std::vector hmat, smat; + std::vector prec, e_ref; + build_generalized_problem(npw, 1.5, 0.02, 7, hmat, smat, prec, e_ref); + + const HegvdMetrics metrics = solve_generalized_eigenvectors(npw, hmat, smat); + EXPECT_EQ(metrics.info, 0); + EXPECT_LT(metrics.max_rel_residual, 1e-12); + EXPECT_LT(metrics.max_s_orth_error, 1e-12); + EXPECT_LT(metrics.max_abs_coeff, 2.0); +} + +TEST_F(DiagoLobpcgTest, HegvdAccuracyDegradesForNearlySingularOverlap) +{ + const int npw = 20; + std::vector hmat, smat; + std::vector prec, e_ref; + build_nearly_dependent_overlap_problem(npw, 1e-10, hmat, smat, prec, e_ref); + + const HegvdMetrics metrics = solve_generalized_eigenvectors(npw, hmat, smat); + EXPECT_EQ(metrics.info, 0); + EXPECT_GT(metrics.max_rel_residual, 1e-10); + EXPECT_LT(metrics.max_rel_residual, 1e-5); + EXPECT_GT(metrics.max_s_orth_error, 1e-8); + EXPECT_GT(metrics.max_abs_coeff, 1e4); +} + +#ifdef __MPI +TEST_F(DiagoLobpcgTest, 
GeneralizedBandParallelRankCompressedSubspace) +{ + run_generalized_band_parallel_and_validate(); +} + +TEST_F(DiagoLobpcgTest, BandParallelReusesProjectedSearchDirectionProducts) +{ + run_generalized_band_parallel_operator_count(); +} + +#endif + +int main(int argc, char** argv) +{ + int nproc = 1, myrank = 0; + +#ifdef __MPI + int nproc_in_pool, kpar = 1, mypool, rank_in_pool; + setupmpi(argc, argv, nproc, myrank); + divide_pools(nproc, myrank, nproc_in_pool, kpar, mypool, rank_in_pool); + const bool use_band_parallel_world = std::getenv("ABACUS_LOBPCG_TEST_BNDPAR") != nullptr; + MPI_Comm_split(MPI_COMM_WORLD, use_band_parallel_world ? 0 : myrank, 0, &BP_WORLD); + if (use_band_parallel_world) + { + GlobalV::MY_BNDGROUP = myrank; + GlobalV::NPROC_IN_BNDGROUP = nproc; + MPI_Comm_free(&POOL_WORLD); + MPI_Comm_split(MPI_COMM_WORLD, myrank, 0, &POOL_WORLD); + } + GlobalV::NPROC_IN_POOL = nproc; +#endif + + testing::InitGoogleTest(&argc, argv); + ::testing::TestEventListeners& listeners = + ::testing::UnitTest::GetInstance()->listeners(); + if (myrank != 0) + delete listeners.Release(listeners.default_result_printer()); + + int result = RUN_ALL_TESTS(); + if (myrank == 0 && result != 0) + std::cout << "ERROR: some tests are not passed" << std::endl; + +#ifdef __MPI + MPI_Finalize(); +#endif + return result; +} diff --git a/source/source_io/module_parameter/read_input_item_elec_stru.cpp b/source/source_io/module_parameter/read_input_item_elec_stru.cpp index 39f37febc54..7fdb3fa5417 100644 --- a/source/source_io/module_parameter/read_input_item_elec_stru.cpp +++ b/source/source_io/module_parameter/read_input_item_elec_stru.cpp @@ -131,7 +131,7 @@ Then the user has to correct the input file and restart the calculation.)"; }; item.check_value = [](const Input_Item& item, const Parameter& para) { const std::string& ks_solver = para.input.ks_solver; - const std::vector pw_solvers = {"cg", "dav", "bpcg", "dav_subspace"}; + const std::vector pw_solvers = {"cg", "dav", "bpcg", 
"dav_subspace", "lobpcg"}; const std::vector lcao_solvers = { "genelpa", "elpa", @@ -1040,7 +1040,7 @@ Use case: When experimental or high-level theoretical results suggest that the S item.annotation = "threshold for eigenvalues is cg electron iterations"; item.category = "Plane wave related variables"; item.type = "Real"; - item.description = "Only used when you use ks_solver = cg/dav/dav_subspace/bpcg. It indicates the threshold for the first electronic iteration, from the second iteration the pw_diag_thr will be updated automatically. For nscf calculations with planewave basis set, pw_diag_thr should be <= 1e-3."; + item.description = "Only used when you use ks_solver = cg/dav/dav_subspace/bpcg/lobpcg. It indicates the threshold for the first electronic iteration, from the second iteration the pw_diag_thr will be updated automatically. For nscf calculations with planewave basis set, pw_diag_thr should be <= 1e-3."; item.default_value = "0.01"; item.unit = ""; item.availability = ""; @@ -1099,10 +1099,10 @@ Use case: When experimental or high-level theoretical results suggest that the S item.annotation = "max iteration number for cg"; item.category = "Plane wave related variables"; item.type = "Integer"; - item.description = "Only useful when you use ks_solver = cg/dav/dav_subspace/bpcg. It indicates the maximal iteration number for cg/david/dav_subspace/bpcg method."; + item.description = "Only useful when you use ks_solver = cg/dav/dav_subspace/bpcg/lobpcg. 
It indicates the maximal iteration number for cg/david/dav_subspace/bpcg/lobpcg method."; item.default_value = "50"; item.unit = ""; - item.availability = "basis_type==pw, ks_solver==cg/dav/dav_subspace/bpcg"; + item.availability = "basis_type==pw, ks_solver==cg/dav/dav_subspace/bpcg/lobpcg"; read_sync_int(input.pw_diag_nmax); this->add_item(item); } diff --git a/source/source_io/module_parameter/read_input_item_system.cpp b/source/source_io/module_parameter/read_input_item_system.cpp index d746d2b5b12..5cc6b2b32e1 100644 --- a/source/source_io/module_parameter/read_input_item_system.cpp +++ b/source/source_io/module_parameter/read_input_item_system.cpp @@ -340,7 +340,9 @@ void ReadInput::item_system() item.default_value = "1"; read_sync_int(input.bndpar); item.reset_value = [](const Input_Item& item, Parameter& para) { - if (para.input.esolver_type != "sdft" && para.input.ks_solver != "bpcg") + if (para.input.esolver_type != "sdft" + && para.input.ks_solver != "bpcg" + && para.input.ks_solver != "lobpcg") { para.input.bndpar = 1; } diff --git a/source/source_io/module_parameter/read_set_globalv.cpp b/source/source_io/module_parameter/read_set_globalv.cpp index 94c0095c27c..a66483df073 100644 --- a/source/source_io/module_parameter/read_set_globalv.cpp +++ b/source/source_io/module_parameter/read_set_globalv.cpp @@ -61,7 +61,7 @@ void ReadInput::set_globalv(const Input_para& inp, System_para& sys) Parallel_Common::bcast_bool(sys.double_grid); #endif /// set ks_run - if (inp.ks_solver != "bpcg" && inp.bndpar > 1) + if (inp.ks_solver != "bpcg" && inp.ks_solver != "lobpcg" && inp.bndpar > 1) { sys.all_ks_run = false; } diff --git a/source/source_io/test_serial/io_system_variable_test.cpp b/source/source_io/test_serial/io_system_variable_test.cpp index 235f01b8551..edbe0af24d9 100644 --- a/source/source_io/test_serial/io_system_variable_test.cpp +++ b/source/source_io/test_serial/io_system_variable_test.cpp @@ -38,6 +38,16 @@ ModuleIO::ReadInput readinput(0); Parameter 
param; std::string output = ""; +static bool all_ks_run_after_set_globalv(const std::string& ks_solver, const int bndpar) +{ + Parameter local_param; + local_param.input.ks_solver = ks_solver; + local_param.input.bndpar = bndpar; + local_param.sys.all_ks_run = true; + readinput.set_globalv(local_param.inp, local_param.sys); + return local_param.sys.all_ks_run; +} + TEST_F(InputTest, Item_test) { readinput.check_ntype_flag = false; @@ -81,4 +91,12 @@ TEST_F(InputTest, Item_test) } -} \ No newline at end of file +} + +TEST_F(InputTest, BandParallelAllKsRun) +{ + EXPECT_FALSE(all_ks_run_after_set_globalv("cg", 2)); + EXPECT_FALSE(all_ks_run_after_set_globalv("dav", 2)); + EXPECT_TRUE(all_ks_run_after_set_globalv("bpcg", 2)); + EXPECT_TRUE(all_ks_run_after_set_globalv("lobpcg", 2)); +} diff --git a/source/source_psi/psi_prepare.cpp b/source/source_psi/psi_prepare.cpp index adc10eeb3fa..a18b8ee3a6e 100644 --- a/source/source_psi/psi_prepare.cpp +++ b/source/source_psi/psi_prepare.cpp @@ -134,8 +134,12 @@ void PSIPrepare::initialize_psi(Psi>* psi, Psi* psi_cpu = reinterpret_cast*>(psi); Psi* psi_device = kspw_psi; + Psi* bp_global_evc_cpu = nullptr; + Psi* bp_global_evc_device = nullptr; - bool fill = PARAM.inp.ks_solver != "bpcg" || GlobalV::MY_BNDGROUP == 0; + const bool supports_band_parallel = PARAM.inp.ks_solver == "bpcg" || PARAM.inp.ks_solver == "lobpcg"; + const bool lobpcg_band_parallel = PARAM.inp.ks_solver == "lobpcg" && PARAM.inp.bndpar > 1; + bool fill = !supports_band_parallel || GlobalV::MY_BNDGROUP == 0; if (fill) { if (not_equal) @@ -143,6 +147,12 @@ void PSIPrepare::initialize_psi(Psi>* psi, psi_cpu = new Psi(1, nbands_start, nbasis, nbasis, true); psi_device = PARAM.inp.device == "gpu" ? new psi::Psi(psi_cpu[0]) : reinterpret_cast*>(psi_cpu); + if (lobpcg_band_parallel) + { + bp_global_evc_cpu = new Psi(1, PARAM.inp.nbands, nbasis, nbasis, true); + bp_global_evc_device = PARAM.inp.device == "gpu" ? 
new psi::Psi(bp_global_evc_cpu[0]) + : reinterpret_cast*>(bp_global_evc_cpu); + } } else if (PARAM.inp.precision == "single") { @@ -180,18 +190,19 @@ void PSIPrepare::initialize_psi(Psi>* psi, } - if (this->ks_solver == "cg") + if (this->ks_solver == "cg" || this->ks_solver == "lobpcg") { std::vector::type> etatom(nbands_start, 0.0); if (not_equal) { + psi::Psi& evc_out = bp_global_evc_device == nullptr ? *kspw_psi : *bp_global_evc_device; // for diagH_subspace_init, psi_device->get_pointer() and kspw_psi->get_pointer() should be // different hsolver::DiagoIterAssist::diag_subspace_init(p_hamilt, psi_device->get_pointer(), nbands_start, nbasis, - *(kspw_psi), + evc_out, etatom.data()); } else @@ -213,8 +224,11 @@ void PSIPrepare::initialize_psi(Psi>* psi, } } #ifdef __MPI - if (PARAM.inp.ks_solver == "bpcg" && PARAM.inp.bndpar > 1) + if (supports_band_parallel && PARAM.inp.bndpar > 1) { + const T* scatter_source = (lobpcg_band_parallel && bp_global_evc_device != nullptr) + ? bp_global_evc_device->get_pointer() + : psi_cpu->get_pointer(); std::vector sendcounts(PARAM.inp.bndpar); std::vector displs(PARAM.inp.bndpar); MPI_Allgather(&nbands_l, 1, MPI_INT, sendcounts.data(), 1, MPI_INT, BP_WORLD); @@ -227,9 +241,13 @@ void PSIPrepare::initialize_psi(Psi>* psi, } if (GlobalV::MY_BNDGROUP == 0) { + if (lobpcg_band_parallel && scatter_source != kspw_psi->get_pointer()) + { + syncmem_complex_op()(kspw_psi->get_pointer(), scatter_source, sendcounts[0]); + } for (int ip = 1; ip < PARAM.inp.bndpar; ++ip) { - Parallel_Common::send_data(psi_cpu->get_pointer() + displs[ip], sendcounts[ip], ip, 0, BP_WORLD); + Parallel_Common::send_data(scatter_source + displs[ip], sendcounts[ip], ip, 0, BP_WORLD); } } else @@ -245,6 +263,14 @@ void PSIPrepare::initialize_psi(Psi>* psi, { if (not_equal) { + if (lobpcg_band_parallel && bp_global_evc_cpu != nullptr) + { + if (PARAM.inp.device == "gpu") + { + delete bp_global_evc_device; + } + delete bp_global_evc_cpu; + } delete psi_cpu; if 
(PARAM.inp.device == "gpu") {