diff --git a/pyproject.toml b/pyproject.toml index 923300c..002720e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,3 +135,7 @@ files = ["src", "tests"] strict = true warn_unused_ignores = true disallow_any_generics = true + +[[tool.mypy.overrides]] +module = ["sympy", "sympy.*"] +ignore_missing_imports = true diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py new file mode 100644 index 0000000..bf3dff7 --- /dev/null +++ b/src/qrg/wannier.py @@ -0,0 +1,137 @@ +import warnings +from typing import Any, cast + +from qten.geometries.fourier import fourier_transform +from qten.linalg.decompose import svd +from qten.linalg.tensors import Tensor +from qten.symbolics.hilbert_space import HilbertSpace +from qten.symbolics.state_space import MomentumSpace + + +def wannierize_k( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: + """ + Perform projective wannierization on the target bands with the seeding states in momentum space. + + Parameters + ---------- + eigenvectors : Tensor + Target bands/eigenvectors. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Seed states in momentum space. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + svd_threshold : float + Warn if the minimum singular value drops below this, indicating linearly dependent seeds + or poor overlap with target bands. + + Returns + ------- + Tensor + Wannierized states with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + """ + if eigenvectors.rank() != 3 or seeds.rank() != 3: + raise ValueError("Both eigenvectors and seeds must be rank-3 Tensors.") + + # 1. Compute the overlap matrix for each momentum sector + # P_k = \psi_k^\dagger S_k + # Resulting shape: (MomentumSpace, IndexSpace_bands, IndexSpace_seeds) + overlap = eigenvectors.h(-2, -1) @ seeds + + # 2. Perform SVD on the overlap matrix + U, S, Vh = svd(overlap) + + # Check for linear dependence / poor projection + min_svd_val = S.data.min().item() + if min_svd_val < svd_threshold: + warnings.warn( + f"Precarious wannier projection with minimum svd value of {min_svd_val:.4g}", + UserWarning, + stacklevel=2, + ) + + # 3. Construct the unitary transformation matrix + # M_k = U_k V_k^\dagger + unitary = U @ Vh + + # 4. Rotate the target bands into the Wannier gauge + # W_k = \psi_k M_k + wannier_states = eigenvectors @ unitary + + return cast(Tensor[Any], wannier_states) + + +def wannierize_r( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: + """ + Perform projective wannierization using real-space localized seed states. + + Parameters + ---------- + eigenvectors : Tensor + Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Seed states localized in real space with shape `(HilbertSpace_local, IndexSpace)`. + svd_threshold : float + SVD warning threshold. + + Returns + ------- + Tensor + Wannierized states in momentum space. + """ + if not isinstance(eigenvectors.dims[0], MomentumSpace): + raise TypeError("The first dimension of the eigenvectors must be a MomentumSpace.") + + kspace = eigenvectors.dims[0] + outspace = eigenvectors.dims[1] + inspace_local = seeds.dims[0] + if not isinstance(outspace, HilbertSpace) or not isinstance(inspace_local, HilbertSpace): + raise TypeError( + "The second dimension of eigenvectors and first dimension " + "of seeds must be HilbertSpace." + ) + + # Perform Fourier transform on local seeds to move them to momentum space + # f shape: (MomentumSpace, HilbertSpace_out, HilbertSpace_in_local) + f = fourier_transform(kspace, outspace, inspace_local, device=eigenvectors.device) + + # Map the seeds to crystal momentum seeds + # f @ local_seeds -> (MomentumSpace, HilbertSpace_out, IndexSpace) + crystal_seeds = f @ seeds + + return wannierize_k(eigenvectors, crystal_seeds, svd_threshold) + + +def projective_wannierization( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: + """ + Perform projective wannierization with automatic seed-space dispatch. + + Parameters + ---------- + eigenvectors : Tensor + Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Either crystal-momentum seeds `(MomentumSpace, HilbertSpace, IndexSpace)` + or local real-space seeds `(HilbertSpace_local, IndexSpace)`. + svd_threshold : float + SVD warning threshold. + + Returns + ------- + Tensor + Wannierized states in momentum space. + """ + if seeds.rank() == 3: + if not isinstance(seeds.dims[0], MomentumSpace): + raise TypeError("Rank-3 seeds must have MomentumSpace as the first dimension.") + return wannierize_k(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) + + if seeds.rank() == 2: + if not isinstance(seeds.dims[0], HilbertSpace): + raise TypeError("Rank-2 seeds must have HilbertSpace as the first dimension.") + return wannierize_r(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) + + raise ValueError("Seeds must be rank-2 (local seeds) or rank-3 (momentum seeds).") diff --git a/src/qrg/wannier.pyi b/src/qrg/wannier.pyi new file mode 100644 index 0000000..baff027 --- /dev/null +++ b/src/qrg/wannier.pyi @@ -0,0 +1,13 @@ +from typing import Any + +from qten.linalg.tensors import Tensor + +def wannierize_k( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... +def wannierize_r( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... +def projective_wannierization( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... diff --git a/tests/test_smoke.py b/tests/test_smoke.py deleted file mode 100644 index c674d14..0000000 --- a/tests/test_smoke.py +++ /dev/null @@ -1,5 +0,0 @@ -import qrg - - -def test_package_imports() -> None: - assert qrg.__version__ diff --git a/tests/test_wannier.py b/tests/test_wannier.py new file mode 100644 index 0000000..298f012 --- /dev/null +++ b/tests/test_wannier.py @@ -0,0 +1,254 @@ +from dataclasses import dataclass +from typing import Any + +import pytest +import sympy as sy +import torch +from qten.geometries.boundary import PeriodicBoundary +from qten.geometries.fourier import fourier_transform +from qten.geometries.spatials import Lattice, Offset +from qten.linalg.tensors import Tensor +from qten.symbolics.hilbert_space import HilbertSpace, U1Basis +from qten.symbolics.state_space import IndexSpace, brillouin_zone +from sympy import ImmutableDenseMatrix + +from qrg.wannier import projective_wannierization, wannierize_k, wannierize_r + + +@dataclass(frozen=True) +class Orb: + name: str + + +def _state(r: Offset[Any], orb: str = "s") -> U1Basis: + return U1Basis(coef=sy.Integer(1), base=(r, Orb(orb))) + + +def _build_1d_spaces() -> tuple[Lattice, Any, HilbertSpace]: + lattice = Lattice( + basis=ImmutableDenseMatrix([[1]]), + boundaries=PeriodicBoundary(ImmutableDenseMatrix.diag(2)), + unit_cell={"r": ImmutableDenseMatrix([0])}, + ) + k_space = brillouin_zone(lattice.dual) + r0 = Offset(rep=ImmutableDenseMatrix([0]), space=lattice.affine) + r_half = Offset(rep=ImmutableDenseMatrix([sy.Rational(1, 2)]), space=lattice.affine) + bloch_space = HilbertSpace.new([_state(r0, "a"), _state(r_half, "b")]) + return lattice, k_space, bloch_space + + +def test_wannierize_r_matches_explicit_crystal_seed_pipeline() -> None: + """Test local-seed projection matches explicit crystal-seed workflow.""" + # Minimal 1D lattice and Brillouin zone, similar to the notebook flow. + _, k_space, bloch_space = _build_1d_spaces() + local_space = bloch_space + + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + # (K, B, I): one target band at each k with deterministic phase structure. + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + + # (B_local, I): one local seed orbital. + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(local_space, seed_space), + ) + + # Notebook-like pathway: local seeds -> Fourier seeds -> projective wannierization. + crystal_seeds = fourier_transform(k_space, bloch_space, local_space) @ local_seeds + expected = wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds) + actual = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) + + assert actual.dims == expected.dims + assert torch.allclose(actual.data, expected.data) + + # Result should remain orthonormal within the selected band subspace. + overlap = actual.h(-2, -1) @ actual + assert torch.allclose( + overlap.data, + torch.ones((k_space.dim, 1, 1), dtype=torch.complex128), + ) + + +def test_wannierize_k_rejects_non_rank3_tensors() -> None: + """Test rank validation raises when inputs are not rank-3 tensors.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors_rank2 = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, band_space), + ) + seeds_rank3 = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, seed_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, seed_space), + ) + with pytest.raises(ValueError, match="rank-3"): + wannierize_k(eigenvectors=eigenvectors_rank2, seeds=seeds_rank3) + + eigenvectors_rank3 = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + seeds_rank2 = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + with pytest.raises(ValueError, match="rank-3"): + wannierize_k(eigenvectors=eigenvectors_rank3, seeds=seeds_rank2) + + +def test_wannierize_r_rejects_non_momentum_first_dimension() -> None: + """Test wannierize_r rejects non-MomentumSpace first tensor dim.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + bad_k_space = IndexSpace.linear(2) + eigenvectors = Tensor( + data=torch.ones((bad_k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(bad_k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + with pytest.raises(TypeError, match="MomentumSpace"): + wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) + + good_eigenvectors = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + wannierize_r(eigenvectors=good_eigenvectors, seeds=local_seeds) + + +def test_wannierize_k_warns_on_poor_overlap() -> None: + """Test poor seed-band overlap emits the precarious projection warning.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + # Build nearly orthogonal eigenvector/seed overlap to trigger warning. + eigenvectors = Tensor( + data=torch.tensor( + [ + [[1.0], [0.0]], + [[1.0], [0.0]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + tiny = 1.0e-8 + seeds = Tensor( + data=torch.tensor( + [ + [[tiny], [1.0]], + [[tiny], [1.0]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, seed_space), + ) + + with pytest.warns(UserWarning, match="Precarious wannier projection"): + _ = wannierize_k( + eigenvectors=eigenvectors, + seeds=seeds, + svd_threshold=1.0e-3, + ) + + +def test_wannierize_r_projector_is_gauge_invariant() -> None: + """Test projector is invariant between equivalent seed construction routes.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[0.0], [1.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + + w_local = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) + w_crystal = wannierize_k( + eigenvectors=eigenvectors, + seeds=fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds, + ) + + p_local = w_local @ w_local.h(-2, -1) + p_crystal = w_crystal @ w_crystal.h(-2, -1) + assert torch.allclose(p_local.data, p_crystal.data) + + +def test_projective_wannierization_matches_explicit_paths() -> None: + """Test auto-dispatch reproduces explicit k-space and local-space APIs.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + crystal_seeds = fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds + + assert torch.allclose( + projective_wannierization(eigenvectors=eigenvectors, seeds=local_seeds).data, + wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds).data, + ) + assert torch.allclose( + projective_wannierization(eigenvectors=eigenvectors, seeds=crystal_seeds).data, + wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds).data, + ) + + +def test_projective_wannierization_rejects_invalid_seed_rank() -> None: + """Test auto-dispatch raises when seeds rank is neither 2 nor 3.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + bad_seeds = Tensor( + data=torch.ones((seed_space.dim,), dtype=torch.complex128), + dims=(seed_space,), + ) + + with pytest.raises(ValueError, match="rank-2|rank-3"): + projective_wannierization(eigenvectors=eigenvectors, seeds=bad_seeds) diff --git a/tox.ini b/tox.ini index ff351d1..dcc9bad 100644 --- a/tox.ini +++ b/tox.ini @@ -32,4 +32,4 @@ package = skip dependency_groups = dev commands = - mypy --no-incremental src tests + uv run --active mypy --no-incremental src