diff --git a/dgamore/nonlocal_sde.py b/dgamore/nonlocal_sde.py index e430f54..670d9cd 100644 --- a/dgamore/nonlocal_sde.py +++ b/dgamore/nonlocal_sde.py @@ -336,16 +336,24 @@ def fit_oz_spin(q_grid: KGrid, mat: np.ndarray): chi_mat = chi.map_to_full_bz(config.lattice.q_grid).to_half_niw_range().take_first_wn().mat.real orb_shape = (config.sys.n_bands,) * 4 oz_coeffs = np.zeros(orb_shape + (2,), dtype=float) + failed_orbitals = [] for idx in np.ndindex(orb_shape): mat_slice = chi_mat[..., idx[0], idx[1], idx[2], idx[3]].flatten() try: coeffs = fit_oz_spin(config.lattice.q_grid, mat_slice) if not np.all(mat_slice == 0) else [0.0, 0.0] except (ValueError, RuntimeError, opt.OptimizeWarning): - config.logger.warning(f"OZ fit did not converge for orbitals {idx}. Using [-1, -1].") + failed_orbitals.append(idx) coeffs = [-1.0, -1.0] oz_coeffs[idx] = coeffs + if failed_orbitals: + one_based = [tuple(o + 1 for o in idx) for idx in failed_orbitals] + config.logger.warning( + f"OZ fit did not converge for {len(failed_orbitals)} orbital combination(s): " + f"{one_based}. Using [-1, -1]." + ) + rows = [] for idx in np.ndindex(orb_shape): rows.append([*idx, *oz_coeffs[idx]]) diff --git a/tests/test_nonlocal_sde.py b/tests/test_nonlocal_sde.py index 8cd87b6..ebb6b67 100644 --- a/tests/test_nonlocal_sde.py +++ b/tests/test_nonlocal_sde.py @@ -6,15 +6,17 @@ import itertools import os +from unittest import mock import numpy as np import dgamore.config as config +import dgamore.nonlocal_sde as nonlocal_sde from dgamore.hamiltonian import Hamiltonian from dgamore.interaction import Interaction from dgamore.local_sde import get_local_hartree_fock from dgamore.n_point_base import SpinChannel -from dgamore.nonlocal_sde import _init_mu_history, get_hartree_fock +from dgamore.nonlocal_sde import _init_mu_history, get_hartree_fock, perform_ornstein_zernicke_fit LOCAL_SDE_DATA = f"{os.path.dirname(os.path.abspath(__file__))}/test_data/local_sde" @@ -80,3 +82,55 @@ def test_nonlocal_hartree_fock_matches_local_reference(): assert np.allclose(hf_nonlocal, sigma_hf_ref[None, ...]) # the same reference is the local Hartree-Fock, so the two SDE paths agree assert np.allclose(hf_nonlocal, get_local_hartree_fock(u_loc, occ)[None, ...]) + + +class _ConstantChi: + """Minimal physical-susceptibility stand-in whose BZ and frequency reductions are identities.""" + + def __init__(self, mat: np.ndarray): + """Stores the orbital-resolved matrix that the reduction chain returns unchanged.""" + self._mat = mat + + def map_to_full_bz(self, grid): + """Identity unfolding to the full BZ.""" + return self + + def to_half_niw_range(self): + """Identity reduction to the half niw range.""" + return self + + def take_first_wn(self): + """Identity selection of the first bosonic frequency.""" + return self + + @property + def mat(self) -> np.ndarray: + """The backing orbital-resolved matrix.""" + return self._mat + + +def test_ornstein_zernicke_fit_aggregates_nonconverged_warnings(monkeypatch): + """All non-converging OZ fits collapse into a single aggregated warning instead of one log per orbital.""" + config.sys.n_bands = 2 + logger = mock.Mock() + monkeypatch.setattr(config, "logger", logger, raising=False) + monkeypatch.setattr(nonlocal_sde.opt, "curve_fit", mock.Mock(side_effect=RuntimeError("forced non-convergence"))) + + perform_ornstein_zernicke_fit(_ConstantChi(np.ones((2, 2, 1, 2, 2, 2, 2), dtype=np.complex64))) + + logger.warning.assert_called_once() + msg = logger.warning.call_args.args[0] + assert "16 orbital combination(s)" in msg + assert "(1, 1, 1, 1)" in msg and "(2, 2, 2, 2)" in msg # 1-based orbital labels, not 0-based + + +def test_ornstein_zernicke_fit_logs_no_warning_when_all_converge(monkeypatch): + """A fully converging set of OZ fits emits no warning at all (the aggregation guard stays silent).""" + config.sys.n_bands = 2 + logger = mock.Mock() + monkeypatch.setattr(config, "logger", logger, raising=False) + monkeypatch.setattr(nonlocal_sde.opt, "curve_fit", mock.Mock(return_value=(np.array([1.0, 2.0]), None))) + + perform_ornstein_zernicke_fit(_ConstantChi(np.ones((2, 2, 1, 2, 2, 2, 2), dtype=np.complex64))) + + logger.warning.assert_not_called()