Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion dgamore/nonlocal_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,24 @@ def fit_oz_spin(q_grid: KGrid, mat: np.ndarray):
chi_mat = chi.map_to_full_bz(config.lattice.q_grid).to_half_niw_range().take_first_wn().mat.real
orb_shape = (config.sys.n_bands,) * 4
oz_coeffs = np.zeros(orb_shape + (2,), dtype=float)
failed_orbitals = []

for idx in np.ndindex(orb_shape):
mat_slice = chi_mat[..., idx[0], idx[1], idx[2], idx[3]].flatten()
try:
coeffs = fit_oz_spin(config.lattice.q_grid, mat_slice) if not np.all(mat_slice == 0) else [0.0, 0.0]
except (ValueError, RuntimeError, opt.OptimizeWarning):
config.logger.warning(f"OZ fit did not converge for orbitals {idx}. Using [-1, -1].")
failed_orbitals.append(idx)
coeffs = [-1.0, -1.0]
oz_coeffs[idx] = coeffs

if failed_orbitals:
one_based = [tuple(o + 1 for o in idx) for idx in failed_orbitals]
config.logger.warning(
f"OZ fit did not converge for {len(failed_orbitals)} orbital combination(s): "
f"{one_based}. Using [-1, -1]."
)

rows = []
for idx in np.ndindex(orb_shape):
rows.append([*idx, *oz_coeffs[idx]])
Expand Down
56 changes: 55 additions & 1 deletion tests/test_nonlocal_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@

import itertools
import os
from unittest import mock

import numpy as np

import dgamore.config as config
import dgamore.nonlocal_sde as nonlocal_sde
from dgamore.hamiltonian import Hamiltonian
from dgamore.interaction import Interaction
from dgamore.local_sde import get_local_hartree_fock
from dgamore.n_point_base import SpinChannel
from dgamore.nonlocal_sde import _init_mu_history, get_hartree_fock
from dgamore.nonlocal_sde import _init_mu_history, get_hartree_fock, perform_ornstein_zernicke_fit

LOCAL_SDE_DATA = f"{os.path.dirname(os.path.abspath(__file__))}/test_data/local_sde"

Expand Down Expand Up @@ -80,3 +82,55 @@ def test_nonlocal_hartree_fock_matches_local_reference():
assert np.allclose(hf_nonlocal, sigma_hf_ref[None, ...])
# the same reference is the local Hartree-Fock, so the two SDE paths agree
assert np.allclose(hf_nonlocal, get_local_hartree_fock(u_loc, occ)[None, ...])


class _ConstantChi:
"""Minimal physical-susceptibility stand-in whose BZ and frequency reductions are identities."""

def __init__(self, mat: np.ndarray):
"""Stores the orbital-resolved matrix that the reduction chain returns unchanged."""
self._mat = mat

def map_to_full_bz(self, grid):
"""Identity unfolding to the full BZ."""
return self

def to_half_niw_range(self):
"""Identity reduction to the half niw range."""
return self

def take_first_wn(self):
"""Identity selection of the first bosonic frequency."""
return self

@property
def mat(self) -> np.ndarray:
"""The backing orbital-resolved matrix."""
return self._mat


def test_ornstein_zernicke_fit_aggregates_nonconverged_warnings(monkeypatch):
"""All non-converging OZ fits collapse into a single aggregated warning instead of one log per orbital."""
config.sys.n_bands = 2
logger = mock.Mock()
monkeypatch.setattr(config, "logger", logger, raising=False)
monkeypatch.setattr(nonlocal_sde.opt, "curve_fit", mock.Mock(side_effect=RuntimeError("forced non-convergence")))

perform_ornstein_zernicke_fit(_ConstantChi(np.ones((2, 2, 1, 2, 2, 2, 2), dtype=np.complex64)))

logger.warning.assert_called_once()
msg = logger.warning.call_args.args[0]
assert "16 orbital combination(s)" in msg
assert "(1, 1, 1, 1)" in msg and "(2, 2, 2, 2)" in msg # 1-based orbital labels, not 0-based


def test_ornstein_zernicke_fit_logs_no_warning_when_all_converge(monkeypatch):
"""A fully converging set of OZ fits emits no warning at all (the aggregation guard stays silent)."""
config.sys.n_bands = 2
logger = mock.Mock()
monkeypatch.setattr(config, "logger", logger, raising=False)
monkeypatch.setattr(nonlocal_sde.opt, "curve_fit", mock.Mock(return_value=(np.array([1.0, 2.0]), None)))

perform_ornstein_zernicke_fit(_ConstantChi(np.ones((2, 2, 1, 2, 2, 2, 2), dtype=np.complex64)))

logger.warning.assert_not_called()
Loading