|
3 | 3 |
|
4 | 4 | import jax.numpy as jnp |
5 | 5 | from jax import jit, checkpoint |
| 6 | +from jax.lax import scan, cond |
6 | 7 |
|
7 | 8 | from varipeps.peps import PEPS_Tensor |
8 | 9 | from varipeps.contractions import apply_contraction, apply_contraction_jitted |
@@ -108,11 +109,36 @@ def _truncated_SVD( |
108 | 109 | ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: |
109 | 110 | U, S, Vh = gauge_fixed_svd(matrix) |
110 | 111 |
|
| 112 | + if len(S) > chi: |
| 113 | + gaps = (S[:chi] - S[1 : chi + 1]) / S[0] |
| 114 | + |
111 | 115 | # Truncate the singular values |
112 | 116 | S = S[:chi] |
113 | 117 | U = U[:, :chi] |
114 | 118 | Vh = Vh[:chi, :] |
115 | 119 |
|
| 120 | + if len(S) > chi: |
| 121 | + |
| 122 | + def fix_multiplets(carry, x): |
| 123 | + S_elem, gap = x |
| 124 | + (already_found,) = carry |
| 125 | + |
| 126 | + trunc_cond = gap > truncation_eps |
| 127 | + already_found = jnp.logical_or(trunc_cond, already_found) |
| 128 | + |
| 129 | + result = cond( |
| 130 | + already_found, lambda x: x, lambda x: jnp.zeros_like(x), S_elem |
| 131 | + ) |
| 132 | + |
| 133 | + return (already_found,), result |
| 134 | + |
| 135 | + _, S = scan( |
| 136 | + fix_multiplets, |
| 137 | + (jnp.zeros((), dtype=bool),), |
| 138 | + (S, gaps), |
| 139 | + reverse=True, |
| 140 | + ) |
| 141 | + |
116 | 142 | relevant_S_values = (S / S[0]) > truncation_eps |
117 | 143 | S_inv_sqrt = jnp.where( |
118 | 144 | relevant_S_values, 1 / jnp.sqrt(jnp.where(relevant_S_values, S, 1)), 0 |
|
0 commit comments