@@ -199,6 +199,7 @@ def do_absorption_step_triangular(
199199 [working_tensor_obj ],
200200 [T1_proj_left , T1_proj_right ],
201201 )
202+ new_T1 /= jnp .linalg .norm (new_T1 )
202203
203204 new_T1_matrix = new_T1 .reshape (
204205 new_T1 .shape [0 ] * new_T1 .shape [1 ] * new_T1 .shape [2 ],
@@ -231,6 +232,7 @@ def do_absorption_step_triangular(
231232 [working_tensor_obj ],
232233 [T2_proj_left , T2_proj_right ],
233234 )
235+ new_T2 /= jnp .linalg .norm (new_T2 )
234236
235237 new_T2_matrix = new_T2 .reshape (
236238 new_T2 .shape [0 ] * new_T2 .shape [1 ] * new_T2 .shape [2 ],
@@ -263,6 +265,7 @@ def do_absorption_step_triangular(
263265 [working_tensor_obj ],
264266 [T3_proj_left , T3_proj_right ],
265267 )
268+ new_T3 /= jnp .linalg .norm (new_T3 )
266269
267270 new_T3_matrix = new_T3 .reshape (
268271 new_T3 .shape [0 ] * new_T3 .shape [1 ] * new_T3 .shape [2 ],
@@ -295,6 +298,7 @@ def do_absorption_step_triangular(
295298 [working_tensor_obj ],
296299 [T4_proj_left , T4_proj_right ],
297300 )
301+ new_T4 /= jnp .linalg .norm (new_T4 )
298302
299303 new_T4_matrix = new_T4 .reshape (
300304 new_T4 .shape [0 ] * new_T4 .shape [1 ] * new_T4 .shape [2 ],
@@ -327,6 +331,7 @@ def do_absorption_step_triangular(
327331 [working_tensor_obj ],
328332 [T5_proj_left , T5_proj_right ],
329333 )
334+ new_T5 /= jnp .linalg .norm (new_T5 )
330335
331336 new_T5_matrix = new_T5 .reshape (
332337 new_T5 .shape [0 ] * new_T5 .shape [1 ] * new_T5 .shape [2 ],
@@ -359,6 +364,7 @@ def do_absorption_step_triangular(
359364 [working_tensor_obj ],
360365 [T6_proj_left , T6_proj_right ],
361366 )
367+ new_T6 /= jnp .linalg .norm (new_T6 )
362368
363369 new_T6_matrix = new_T6 .reshape (
364370 new_T6 .shape [0 ] * new_T6 .shape [1 ] * new_T6 .shape [2 ],
0 commit comments