Skip to content

Commit 807a8c6

Browse files
committed
Cut off multiplets in truncated SVD
1 parent d56ff1f commit 807a8c6

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

varipeps/ctmrg/projectors.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import jax.numpy as jnp
55
from jax import jit, checkpoint
6+
from jax.lax import scan, cond
67

78
from varipeps.peps import PEPS_Tensor
89
from varipeps.contractions import apply_contraction, apply_contraction_jitted
@@ -108,11 +109,36 @@ def _truncated_SVD(
108109
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
109110
U, S, Vh = gauge_fixed_svd(matrix)
110111

112+
if len(S) > chi:
113+
gaps = (S[:chi] - S[1 : chi + 1]) / S[0]
114+
111115
# Truncate the singular values
112116
S = S[:chi]
113117
U = U[:, :chi]
114118
Vh = Vh[:chi, :]
115119

120+
if len(S) > chi:
121+
122+
def fix_multiplets(carry, x):
123+
S_elem, gap = x
124+
(already_found,) = carry
125+
126+
trunc_cond = gap > truncation_eps
127+
already_found = jnp.logical_or(trunc_cond, already_found)
128+
129+
result = cond(
130+
already_found, lambda x: x, lambda x: jnp.zeros_like(x), S_elem
131+
)
132+
133+
return (already_found,), result
134+
135+
_, S = scan(
136+
fix_multiplets,
137+
(jnp.zeros((), dtype=bool),),
138+
(S, gaps),
139+
reverse=True,
140+
)
141+
116142
relevant_S_values = (S / S[0]) > truncation_eps
117143
S_inv_sqrt = jnp.where(
118144
relevant_S_values, 1 / jnp.sqrt(jnp.where(relevant_S_values, S, 1)), 0

0 commit comments

Comments
 (0)