From 3a2a1e6eca0b69deb09be3c66ae1eb04dc730d1e Mon Sep 17 00:00:00 2001 From: eyjafjallac Date: Thu, 2 Apr 2026 21:29:05 +0800 Subject: [PATCH 1/5] Add wannierization functions to `wannier.py` - Implemented `projective_wannierization` for transforming target bands using seeding states in momentum space. - Added `wannier_projection` for projective wannierization with real-space localized seed states, including Fourier transformation to momentum space. - Included parameter validation and warnings for singular value decomposition (SVD) thresholds. --- src/qrg/wannier.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/qrg/wannier.py diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py new file mode 100644 index 0000000..2989e7d --- /dev/null +++ b/src/qrg/wannier.py @@ -0,0 +1,96 @@ +import warnings +from qten.linalg.tensors import Tensor +from qten.linalg.decompose import svd +from qten.symbolics.state_space import MomentumSpace, IndexSpace +from qten.symbolics.hilbert_space import HilbertSpace +from qten.geometries.fourier import fourier_transform + +def projective_wannierization( + eigenvectors: Tensor, + seeds: Tensor, + svd_threshold: float = 1e-1 +) -> Tensor: + """ + Perform projective wannierization on the target bands with the seeding states in momentum space. + + Parameters + ---------- + eigenvectors : Tensor + Target bands/eigenvectors. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Seed states in momentum space. Expected shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + svd_threshold : float + Warn if the minimum singular value drops below this, indicating linearly dependent seeds + or poor overlap with target bands. + + Returns + ------- + Tensor + Wannierized states with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + """ + if eigenvectors.rank() != 3 or seeds.rank() != 3: + raise ValueError("Both eigenvectors and seeds must be rank-3 Tensors.") + + # 1. Compute the overlap matrix for each momentum sector + # P_k = \psi_k^\dagger S_k + # Resulting shape: (MomentumSpace, IndexSpace_bands, IndexSpace_seeds) + overlap = eigenvectors.h(-2, -1) @ seeds + + # 2. Perform SVD on the overlap matrix + U, S, Vh = svd(overlap) + + # Check for linear dependence / poor projection + min_svd_val = S.data.min().item() + if min_svd_val < svd_threshold: + warnings.warn( + f"Precarious wannier projection with minimum svd value of {min_svd_val:.4g}" + ) + + # 3. Construct the unitary transformation matrix + # M_k = U_k V_k^\dagger + unitary = U @ Vh + + # 4. Rotate the target bands into the Wannier gauge + # W_k = \psi_k M_k + wannier_states = eigenvectors @ unitary + + return wannier_states + +def wannier_projection( + eigenvectors: Tensor, + seeds: Tensor, + svd_threshold: float = 1e-1 +) -> Tensor: + """ + Perform projective wannierization using real-space localized seed states. + + Parameters + ---------- + eigenvectors : Tensor + Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Seed states localized in real space with shape `(HilbertSpace_local, IndexSpace)`. + svd_threshold : float + SVD warning threshold. + + Returns + ------- + Tensor + Wannierized states in momentum space. + """ + if not isinstance(eigenvectors.dims[0], MomentumSpace): + raise TypeError("The first dimension of the eigenvectors must be a MomentumSpace.") + + kspace = eigenvectors.dims[0] + outspace = eigenvectors.dims[1] + inspace_local = seeds.dims[0] + + # Perform Fourier transform on local seeds to move them to momentum space + # f shape: (MomentumSpace, HilbertSpace_out, HilbertSpace_in_local) + f = fourier_transform(kspace, outspace, inspace_local, device=eigenvectors.device) + + # Map the seeds to crystal momentum seeds + # f @ local_seeds -> (MomentumSpace, HilbertSpace_out, IndexSpace) + crystal_seeds = f @ seeds + + return projective_wannierization(eigenvectors, crystal_seeds, svd_threshold) From 41d446fbb847e2741ab3bb236b1b7e948407e668 Mon Sep 17 00:00:00 2001 From: eyjafjallac Date: Thu, 2 Apr 2026 21:39:33 +0800 Subject: [PATCH 2/5] Remove test_smoke.py file as it is no longer needed for the project. --- tests/test_smoke.py | 5 ----- 1 file changed, 5 deletions(-) delete mode 100644 tests/test_smoke.py diff --git a/tests/test_smoke.py b/tests/test_smoke.py deleted file mode 100644 index c674d14..0000000 --- a/tests/test_smoke.py +++ /dev/null @@ -1,5 +0,0 @@ -import qrg - - -def test_package_imports() -> None: - assert qrg.__version__ From 828748539d42115cdaecfc1da8d3ced71638a9f0 Mon Sep 17 00:00:00 2001 From: eyjafjallac Date: Thu, 2 Apr 2026 22:05:38 +0800 Subject: [PATCH 3/5] Add mypy configuration and enhance wannier functions - Updated `pyproject.toml` to include mypy overrides for the `sympy` module to ignore missing imports. - Modified `tox.ini` to change the mypy command to use `uv run`. - Enhanced type annotations in `wannier.py` for better clarity and type safety. - Added a new test suite in `test_wannier.py` to validate the functionality of the wannierization methods, including edge cases and expected behaviors. --- pyproject.toml | 4 + src/qrg/wannier.py | 53 ++++++----- tests/test_wannier.py | 203 ++++++++++++++++++++++++++++++++++++++++++ tox.ini | 2 +- 4 files changed, 238 insertions(+), 24 deletions(-) create mode 100644 tests/test_wannier.py diff --git a/pyproject.toml b/pyproject.toml index 923300c..002720e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,3 +135,7 @@ files = ["src", "tests"] strict = true warn_unused_ignores = true disallow_any_generics = true + +[[tool.mypy.overrides]] +module = ["sympy", "sympy.*"] +ignore_missing_imports = true diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py index 2989e7d..7a1447e 100644 --- a/src/qrg/wannier.py +++ b/src/qrg/wannier.py @@ -1,15 +1,16 @@ import warnings +from typing import Any, cast + +from qten.geometries.fourier import fourier_transform +from qten.linalg.decompose import svd from qten.linalg.tensors import Tensor -from qten.linalg.decompose import svd -from qten.symbolics.state_space import MomentumSpace, IndexSpace from qten.symbolics.hilbert_space import HilbertSpace -from qten.geometries.fourier import fourier_transform +from qten.symbolics.state_space import MomentumSpace + def projective_wannierization( - eigenvectors: Tensor, - seeds: Tensor, - svd_threshold: float = 1e-1 -) -> Tensor: + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: """ Perform projective wannierization on the target bands with the seeding states in momentum space. @@ -30,37 +31,38 @@ def projective_wannierization( """ if eigenvectors.rank() != 3 or seeds.rank() != 3: raise ValueError("Both eigenvectors and seeds must be rank-3 Tensors.") - + # 1. Compute the overlap matrix for each momentum sector # P_k = \psi_k^\dagger S_k # Resulting shape: (MomentumSpace, IndexSpace_bands, IndexSpace_seeds) overlap = eigenvectors.h(-2, -1) @ seeds - + # 2. Perform SVD on the overlap matrix U, S, Vh = svd(overlap) - + # Check for linear dependence / poor projection min_svd_val = S.data.min().item() if min_svd_val < svd_threshold: warnings.warn( - f"Precarious wannier projection with minimum svd value of {min_svd_val:.4g}" + f"Precarious wannier projection with minimum svd value of {min_svd_val:.4g}", + UserWarning, + stacklevel=2, ) - + # 3. Construct the unitary transformation matrix # M_k = U_k V_k^\dagger unitary = U @ Vh - + # 4. Rotate the target bands into the Wannier gauge # W_k = \psi_k M_k wannier_states = eigenvectors @ unitary - - return wannier_states + + return cast(Tensor[Any], wannier_states) + def wannier_projection( - eigenvectors: Tensor, - seeds: Tensor, - svd_threshold: float = 1e-1 -) -> Tensor: + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: """ Perform projective wannierization using real-space localized seed states. @@ -80,17 +82,22 @@ def wannier_projection( """ if not isinstance(eigenvectors.dims[0], MomentumSpace): raise TypeError("The first dimension of the eigenvectors must be a MomentumSpace.") - + kspace = eigenvectors.dims[0] outspace = eigenvectors.dims[1] inspace_local = seeds.dims[0] - + if not isinstance(outspace, HilbertSpace) or not isinstance(inspace_local, HilbertSpace): + raise TypeError( + "The second dimension of eigenvectors and first dimension " + "of seeds must be HilbertSpace." + ) + # Perform Fourier transform on local seeds to move them to momentum space # f shape: (MomentumSpace, HilbertSpace_out, HilbertSpace_in_local) f = fourier_transform(kspace, outspace, inspace_local, device=eigenvectors.device) - + # Map the seeds to crystal momentum seeds # f @ local_seeds -> (MomentumSpace, HilbertSpace_out, IndexSpace) crystal_seeds = f @ seeds - + return projective_wannierization(eigenvectors, crystal_seeds, svd_threshold) diff --git a/tests/test_wannier.py b/tests/test_wannier.py new file mode 100644 index 0000000..aab49ac --- /dev/null +++ b/tests/test_wannier.py @@ -0,0 +1,203 @@ +from dataclasses import dataclass +from typing import Any + +import pytest +import sympy as sy +import torch +from qten.geometries.boundary import PeriodicBoundary +from qten.geometries.fourier import fourier_transform +from qten.geometries.spatials import Lattice, Offset +from qten.linalg.tensors import Tensor +from qten.symbolics.hilbert_space import HilbertSpace, U1Basis +from qten.symbolics.state_space import IndexSpace, brillouin_zone +from sympy import ImmutableDenseMatrix + +from qrg.wannier import projective_wannierization, wannier_projection + + +@dataclass(frozen=True) +class Orb: + name: str + + +def _state(r: Offset[Any], orb: str = "s") -> U1Basis: + return U1Basis(coef=sy.Integer(1), base=(r, Orb(orb))) + + +def _build_1d_spaces() -> tuple[Lattice, Any, HilbertSpace]: + lattice = Lattice( + basis=ImmutableDenseMatrix([[1]]), + boundaries=PeriodicBoundary(ImmutableDenseMatrix.diag(2)), + unit_cell={"r": ImmutableDenseMatrix([0])}, + ) + k_space = brillouin_zone(lattice.dual) + r0 = Offset(rep=ImmutableDenseMatrix([0]), space=lattice.affine) + r_half = Offset(rep=ImmutableDenseMatrix([sy.Rational(1, 2)]), space=lattice.affine) + bloch_space = HilbertSpace.new([_state(r0, "a"), _state(r_half, "b")]) + return lattice, k_space, bloch_space + + +def test_wannier_projection_matches_explicit_crystal_seed_pipeline() -> None: + """Test local-seed projection matches explicit crystal-seed workflow.""" + # Minimal 1D lattice and Brillouin zone, similar to the notebook flow. + _, k_space, bloch_space = _build_1d_spaces() + local_space = bloch_space + + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + # (K, B, I): one target band at each k with deterministic phase structure. + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + + # (B_local, I): one local seed orbital. + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(local_space, seed_space), + ) + + # Notebook-like pathway: local seeds -> Fourier seeds -> projective wannierization. + crystal_seeds = fourier_transform(k_space, bloch_space, local_space) @ local_seeds + expected = projective_wannierization(eigenvectors=eigenvectors, seeds=crystal_seeds) + actual = wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) + + assert actual.dims == expected.dims + assert torch.allclose(actual.data, expected.data) + + # Result should remain orthonormal within the selected band subspace. + overlap = actual.h(-2, -1) @ actual + assert torch.allclose( + overlap.data, + torch.ones((k_space.dim, 1, 1), dtype=torch.complex128), + ) + + +def test_projective_wannierization_rejects_non_rank3_tensors() -> None: + """Test rank validation raises when inputs are not rank-3 tensors.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors_rank2 = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, band_space), + ) + seeds_rank3 = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, seed_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, seed_space), + ) + with pytest.raises(ValueError, match="rank-3"): + projective_wannierization(eigenvectors=eigenvectors_rank2, seeds=seeds_rank3) + + eigenvectors_rank3 = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + seeds_rank2 = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + with pytest.raises(ValueError, match="rank-3"): + projective_wannierization(eigenvectors=eigenvectors_rank3, seeds=seeds_rank2) + + +def test_wannier_projection_rejects_non_momentum_first_dimension() -> None: + """Test wannier_projection rejects non-MomentumSpace first tensor dim.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + bad_k_space = IndexSpace.linear(2) + eigenvectors = Tensor( + data=torch.ones((bad_k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(bad_k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + with pytest.raises(TypeError, match="MomentumSpace"): + wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) + + good_eigenvectors = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + wannier_projection(eigenvectors=good_eigenvectors, seeds=local_seeds) + + +def test_projective_wannierization_warns_on_poor_overlap() -> None: + """Test poor seed-band overlap emits the precarious projection warning.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + # Build nearly orthogonal eigenvector/seed overlap to trigger warning. + eigenvectors = Tensor( + data=torch.tensor( + [ + [[1.0], [0.0]], + [[1.0], [0.0]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + tiny = 1.0e-8 + seeds = Tensor( + data=torch.tensor( + [ + [[tiny], [1.0]], + [[tiny], [1.0]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, seed_space), + ) + + with pytest.warns(UserWarning, match="Precarious wannier projection"): + _ = projective_wannierization( + eigenvectors=eigenvectors, + seeds=seeds, + svd_threshold=1.0e-3, + ) + + +def test_wannier_projection_projector_is_gauge_invariant() -> None: + """Test projector is invariant between equivalent seed construction routes.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[0.0], [1.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + + w_local = wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) + w_crystal = projective_wannierization( + eigenvectors=eigenvectors, + seeds=fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds, + ) + + p_local = w_local @ w_local.h(-2, -1) + p_crystal = w_crystal @ w_crystal.h(-2, -1) + assert torch.allclose(p_local.data, p_crystal.data) diff --git a/tox.ini b/tox.ini index ff351d1..155ab65 100644 --- a/tox.ini +++ b/tox.ini @@ -32,4 +32,4 @@ package = skip dependency_groups = dev commands = - mypy --no-incremental src tests + uv run --active mypy --no-incremental src tests From 646362f57c9dabbd74ef70a898bc69d56f329bc0 Mon Sep 17 00:00:00 2001 From: eyjafjallac Date: Thu, 2 Apr 2026 22:22:00 +0800 Subject: [PATCH 4/5] Refactor wannier functions and update mypy configuration - Renamed `projective_wannierization` to `wannierize_k` and `wannier_projection` to `wannierize_r` for improved clarity. - Updated `tox.ini` to modify the mypy command by removing the tests directory from the command. - Added type stubs for the new function names in `wannier.pyi`. - Adjusted tests in `test_wannier.py` to reflect the new function names and ensure proper validation. --- src/qrg/wannier.py | 6 +++--- src/qrg/wannier.pyi | 10 ++++++++++ tests/test_wannier.py | 32 ++++++++++++++++---------------- tox.ini | 2 +- 4 files changed, 30 insertions(+), 20 deletions(-) create mode 100644 src/qrg/wannier.pyi diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py index 7a1447e..62baa79 100644 --- a/src/qrg/wannier.py +++ b/src/qrg/wannier.py @@ -8,7 +8,7 @@ from qten.symbolics.state_space import MomentumSpace -def projective_wannierization( +def wannierize_k( eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 ) -> Tensor[Any]: """ @@ -60,7 +60,7 @@ def projective_wannierization( return cast(Tensor[Any], wannier_states) -def wannier_projection( +def wannierize_r( eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 ) -> Tensor[Any]: """ @@ -100,4 +100,4 @@ def wannier_projection( # f @ local_seeds -> (MomentumSpace, HilbertSpace_out, IndexSpace) crystal_seeds = f @ seeds - return projective_wannierization(eigenvectors, crystal_seeds, svd_threshold) + return wannierize_k(eigenvectors, crystal_seeds, svd_threshold) diff --git a/src/qrg/wannier.pyi b/src/qrg/wannier.pyi new file mode 100644 index 0000000..7be0fdc --- /dev/null +++ b/src/qrg/wannier.pyi @@ -0,0 +1,10 @@ +from typing import Any + +from qten.linalg.tensors import Tensor + +def wannierize_k( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... +def wannierize_r( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... diff --git a/tests/test_wannier.py b/tests/test_wannier.py index aab49ac..b84807d 100644 --- a/tests/test_wannier.py +++ b/tests/test_wannier.py @@ -12,7 +12,7 @@ from qten.symbolics.state_space import IndexSpace, brillouin_zone from sympy import ImmutableDenseMatrix -from qrg.wannier import projective_wannierization, wannier_projection +from qrg.wannier import wannierize_k, wannierize_r @dataclass(frozen=True) @@ -37,7 +37,7 @@ def _build_1d_spaces() -> tuple[Lattice, Any, HilbertSpace]: return lattice, k_space, bloch_space -def test_wannier_projection_matches_explicit_crystal_seed_pipeline() -> None: +def test_wannierize_r_matches_explicit_crystal_seed_pipeline() -> None: """Test local-seed projection matches explicit crystal-seed workflow.""" # Minimal 1D lattice and Brillouin zone, similar to the notebook flow. _, k_space, bloch_space = _build_1d_spaces() @@ -66,8 +66,8 @@ def test_wannier_projection_matches_explicit_crystal_seed_pipeline() -> None: # Notebook-like pathway: local seeds -> Fourier seeds -> projective wannierization. crystal_seeds = fourier_transform(k_space, bloch_space, local_space) @ local_seeds - expected = projective_wannierization(eigenvectors=eigenvectors, seeds=crystal_seeds) - actual = wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) + expected = wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds) + actual = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) assert actual.dims == expected.dims assert torch.allclose(actual.data, expected.data) @@ -80,7 +80,7 @@ def test_wannier_projection_matches_explicit_crystal_seed_pipeline() -> None: ) -def test_projective_wannierization_rejects_non_rank3_tensors() -> None: +def test_wannierize_k_rejects_non_rank3_tensors() -> None: """Test rank validation raises when inputs are not rank-3 tensors.""" _, k_space, bloch_space = _build_1d_spaces() band_space = IndexSpace.linear(1) @@ -95,7 +95,7 @@ def test_projective_wannierization_rejects_non_rank3_tensors() -> None: dims=(k_space, bloch_space, seed_space), ) with pytest.raises(ValueError, match="rank-3"): - projective_wannierization(eigenvectors=eigenvectors_rank2, seeds=seeds_rank3) + wannierize_k(eigenvectors=eigenvectors_rank2, seeds=seeds_rank3) eigenvectors_rank3 = Tensor( data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), @@ -106,11 +106,11 @@ def test_projective_wannierization_rejects_non_rank3_tensors() -> None: dims=(bloch_space, seed_space), ) with pytest.raises(ValueError, match="rank-3"): - projective_wannierization(eigenvectors=eigenvectors_rank3, seeds=seeds_rank2) + wannierize_k(eigenvectors=eigenvectors_rank3, seeds=seeds_rank2) -def test_wannier_projection_rejects_non_momentum_first_dimension() -> None: - """Test wannier_projection rejects non-MomentumSpace first tensor dim.""" +def test_wannierize_r_rejects_non_momentum_first_dimension() -> None: + """Test wannierize_r rejects non-MomentumSpace first tensor dim.""" _, k_space, bloch_space = _build_1d_spaces() band_space = IndexSpace.linear(1) seed_space = IndexSpace.linear(1) @@ -125,16 +125,16 @@ def test_wannier_projection_rejects_non_momentum_first_dimension() -> None: dims=(bloch_space, seed_space), ) with pytest.raises(TypeError, match="MomentumSpace"): - wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) + wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) good_eigenvectors = Tensor( data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), dims=(k_space, bloch_space, band_space), ) - wannier_projection(eigenvectors=good_eigenvectors, seeds=local_seeds) + wannierize_r(eigenvectors=good_eigenvectors, seeds=local_seeds) -def test_projective_wannierization_warns_on_poor_overlap() -> None: +def test_wannierize_k_warns_on_poor_overlap() -> None: """Test poor seed-band overlap emits the precarious projection warning.""" _, k_space, bloch_space = _build_1d_spaces() band_space = IndexSpace.linear(1) @@ -164,14 +164,14 @@ def test_projective_wannierization_warns_on_poor_overlap() -> None: ) with pytest.warns(UserWarning, match="Precarious wannier projection"): - _ = projective_wannierization( + _ = wannierize_k( eigenvectors=eigenvectors, seeds=seeds, svd_threshold=1.0e-3, ) -def test_wannier_projection_projector_is_gauge_invariant() -> None: +def test_wannierize_r_projector_is_gauge_invariant() -> None: """Test projector is invariant between equivalent seed construction routes.""" _, k_space, bloch_space = _build_1d_spaces() band_space = IndexSpace.linear(1) @@ -192,8 +192,8 @@ def test_wannier_projection_projector_is_gauge_invariant() -> None: dims=(bloch_space, seed_space), ) - w_local = wannier_projection(eigenvectors=eigenvectors, seeds=local_seeds) - w_crystal = projective_wannierization( + w_local = wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds) + w_crystal = wannierize_k( eigenvectors=eigenvectors, seeds=fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds, ) diff --git a/tox.ini b/tox.ini index 155ab65..dcc9bad 100644 --- a/tox.ini +++ b/tox.ini @@ -32,4 +32,4 @@ package = skip dependency_groups = dev commands = - uv run --active mypy --no-incremental src tests + uv run --active mypy --no-incremental src From acfd8c809a5ad9e5a9d9b478b3724ded27208619 Mon Sep 17 00:00:00 2001 From: eyjafjallac Date: Sat, 4 Apr 2026 00:23:49 +0800 Subject: [PATCH 5/5] Add projective wannierization function and corresponding tests - Introduced `projective_wannierization` to handle automatic seed-space dispatch for wannierization. - Updated type stubs in `wannier.pyi` to include the new function. - Added tests in `test_wannier.py` to validate the functionality and error handling of the new method, ensuring it matches existing wannierization functions. --- src/qrg/wannier.py | 34 +++++++++++++++++++++++++++ src/qrg/wannier.pyi | 3 +++ tests/test_wannier.py | 53 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/qrg/wannier.py b/src/qrg/wannier.py index 62baa79..bf3dff7 100644 --- a/src/qrg/wannier.py +++ b/src/qrg/wannier.py @@ -101,3 +101,37 @@ def wannierize_r( crystal_seeds = f @ seeds return wannierize_k(eigenvectors, crystal_seeds, svd_threshold) + + +def projective_wannierization( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: + """ + Perform projective wannierization with automatic seed-space dispatch. + + Parameters + ---------- + eigenvectors : Tensor + Target bands with shape `(MomentumSpace, HilbertSpace, IndexSpace)`. + seeds : Tensor + Either crystal-momentum seeds `(MomentumSpace, HilbertSpace, IndexSpace)` + or local real-space seeds `(HilbertSpace_local, IndexSpace)`. + svd_threshold : float + SVD warning threshold. + + Returns + ------- + Tensor + Wannierized states in momentum space. + """ + if seeds.rank() == 3: + if not isinstance(seeds.dims[0], MomentumSpace): + raise TypeError("Rank-3 seeds must have MomentumSpace as the first dimension.") + return wannierize_k(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) + + if seeds.rank() == 2: + if not isinstance(seeds.dims[0], HilbertSpace): + raise TypeError("Rank-2 seeds must have HilbertSpace as the first dimension.") + return wannierize_r(eigenvectors=eigenvectors, seeds=seeds, svd_threshold=svd_threshold) + + raise ValueError("Seeds must be rank-2 (local seeds) or rank-3 (momentum seeds).") diff --git a/src/qrg/wannier.pyi b/src/qrg/wannier.pyi index 7be0fdc..baff027 100644 --- a/src/qrg/wannier.pyi +++ b/src/qrg/wannier.pyi @@ -8,3 +8,6 @@ def wannierize_k( def wannierize_r( eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 ) -> Tensor[Any]: ... +def projective_wannierization( + eigenvectors: Tensor[Any], seeds: Tensor[Any], svd_threshold: float = 1e-1 +) -> Tensor[Any]: ... diff --git a/tests/test_wannier.py b/tests/test_wannier.py index b84807d..298f012 100644 --- a/tests/test_wannier.py +++ b/tests/test_wannier.py @@ -12,7 +12,7 @@ from qten.symbolics.state_space import IndexSpace, brillouin_zone from sympy import ImmutableDenseMatrix -from qrg.wannier import wannierize_k, wannierize_r +from qrg.wannier import projective_wannierization, wannierize_k, wannierize_r @dataclass(frozen=True) @@ -201,3 +201,54 @@ def test_wannierize_r_projector_is_gauge_invariant() -> None: p_local = w_local @ w_local.h(-2, -1) p_crystal = w_crystal @ w_crystal.h(-2, -1) assert torch.allclose(p_local.data, p_crystal.data) + + +def test_projective_wannierization_matches_explicit_paths() -> None: + """Test auto-dispatch reproduces explicit k-space and local-space APIs.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.tensor( + [ + [[2**-0.5], [2**-0.5]], + [[2**-0.5], [-(2**-0.5)]], + ], + dtype=torch.complex128, + ), + dims=(k_space, bloch_space, band_space), + ) + local_seeds = Tensor( + data=torch.tensor([[1.0], [0.0]], dtype=torch.complex128), + dims=(bloch_space, seed_space), + ) + crystal_seeds = fourier_transform(k_space, bloch_space, bloch_space) @ local_seeds + + assert torch.allclose( + projective_wannierization(eigenvectors=eigenvectors, seeds=local_seeds).data, + wannierize_r(eigenvectors=eigenvectors, seeds=local_seeds).data, + ) + assert torch.allclose( + projective_wannierization(eigenvectors=eigenvectors, seeds=crystal_seeds).data, + wannierize_k(eigenvectors=eigenvectors, seeds=crystal_seeds).data, + ) + + +def test_projective_wannierization_rejects_invalid_seed_rank() -> None: + """Test auto-dispatch raises when seeds rank is neither 2 nor 3.""" + _, k_space, bloch_space = _build_1d_spaces() + band_space = IndexSpace.linear(1) + seed_space = IndexSpace.linear(1) + + eigenvectors = Tensor( + data=torch.ones((k_space.dim, bloch_space.dim, band_space.dim), dtype=torch.complex128), + dims=(k_space, bloch_space, band_space), + ) + bad_seeds = Tensor( + data=torch.ones((seed_space.dim,), dtype=torch.complex128), + dims=(seed_space,), + ) + + with pytest.raises(ValueError, match="rank-2|rank-3"): + projective_wannierization(eigenvectors=eigenvectors, seeds=bad_seeds)