11from functools import partial
22import enum
33
4+ import numpy as np
5+ from scipy .sparse .linalg import LinearOperator , eigs
6+
47import jax .numpy as jnp
58from jax import jit , custom_vjp , vjp , tree_util
69from jax .lax import cond , while_loop
@@ -945,6 +948,51 @@ def _ctmrg_rev_while_body(carry):
945948 return vjp_env , initial_bar , bar_fixed_point , converged , count , config , state
946949
947950
951+ def _ctmrg_rev_arnoldi (vjp_operator , initial_v ):
952+ v , v_treedev = jax .tree .flatten (initial_v )
953+ v_flat = np .concatenate ([i .reshape (- 1 ) for i in v ])
954+
955+ def matvec (vec ):
956+ new_vec = [None ] * len (v )
957+ for i in range (len (v )):
958+ i_start = sum (j .size for j in v [:i ])
959+ elems = slice (i_start , i_start + v [i ].size )
960+ new_vec [i ] = vec [elems ].astype (v [i ].dtype ).reshape (v [i ].shape )
961+ new_vec = jax .tree .unflatten (v_treedev , new_vec )
962+
963+ new_vec = vjp_operator ((new_vec , jnp .array (0 , dtype = jnp .float64 )))[0 ]
964+
965+ new_vec , _ = jax .tree .flatten (new_vec )
966+ new_vec = np .concatenate ([i .reshape (- 1 ) for i in new_vec ])
967+
968+ return np .append (new_vec + vec [- 1 ] * v_flat , vec [- 1 ])
969+
970+ lin_op = LinearOperator (
971+ (v_flat .shape [0 ] + 1 , v_flat .shape [0 ] + 1 ),
972+ matvec = matvec ,
973+ )
974+
975+ _ , vec = eigs (
976+ lin_op , k = 1 , v0 = np .append (v_flat , np .array (1 , dtype = v_flat .dtype )), which = "LM"
977+ )
978+
979+ vec = vec .reshape (- 1 )
980+
981+ if np .abs (vec [- 1 ]) >= 1e-10 :
982+ vec /= vec [- 1 ]
983+
984+ result = [None ] * len (v )
985+ for i in range (len (v )):
986+ i_start = sum (j .size for j in v [:i ])
987+ elems = slice (i_start , i_start + v [i ].size )
988+ result [i ] = vec [elems ].astype (v [i ].dtype ).reshape (v [i ].shape )
989+
990+ if np .abs (vec [- 1 ]) < 1e-10 :
991+ return jax .tree .unflatten (v_treedev , result ), False
992+
993+ return jax .tree .unflatten (v_treedev , result ), True
994+
995+
948996@jit
949997def _ctmrg_rev_workhorse (peps_tensors , new_unitcell , new_unitcell_bar , config , state ):
950998 if new_unitcell .is_triangular_peps ():
@@ -988,6 +1036,7 @@ def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, s
9881036 old_method = False
9891037
9901038 if old_method :
1039+
9911040 def cond_func (carry ):
9921041 _ , _ , _ , converged , count , config , state = carry
9931042
@@ -999,30 +1048,50 @@ def cond_func(carry):
9991048 (vjp_env , new_unitcell_bar , new_unitcell_bar , False , 0 , config , state ),
10001049 )
10011050 else :
1002- def f (w ):
1003- new_w = vjp_env ((w , jnp .array (0 , dtype = jnp .float64 )))[0 ]
1004-
1005- new_w = new_w .replace_unique_tensors (
1006- [
1007- t_old .__sub__ (t_new , checks = False )
1008- for t_old , t_new in zip (
1009- w .get_unique_tensors (),
1010- new_w .get_unique_tensors (),
1011- strict = True ,
1012- )
1013- ]
1014- )
1051+ env_fixed_point , arnoldi_worked = jax .pure_callback (
1052+ _ctmrg_rev_arnoldi ,
1053+ jax .eval_shape (lambda x : (x , True ), new_unitcell_bar ),
1054+ vjp (
1055+ lambda u : do_absorption_step (peps_tensors , u , config , state ),
1056+ new_unitcell ,
1057+ )[1 ],
1058+ new_unitcell_bar ,
1059+ )
1060+ end_count = 0
1061+
1062+ debug_print ("Arnoldi: {}" , arnoldi_worked )
1063+
1064+ def run_gmres (v , e ):
1065+ def f_gmres (w ):
1066+ new_w = vjp_env ((w , jnp .array (0 , dtype = jnp .float64 )))[0 ]
1067+
1068+ new_w = new_w .replace_unique_tensors (
1069+ [
1070+ t_old .__sub__ (t_new , checks = False )
1071+ for t_old , t_new in zip (
1072+ w .get_unique_tensors (),
1073+ new_w .get_unique_tensors (),
1074+ strict = True ,
1075+ )
1076+ ]
1077+ )
10151078
1016- return new_w
1079+ return new_w
10171080
1018- is_gpu = jax .default_backend () == "gpu"
1081+ is_gpu = jax .default_backend () == "gpu"
10191082
1020- env_fixed_point , end_count = jax .scipy .sparse .linalg .gmres (
1021- f ,
1022- new_unitcell_bar ,
1023- new_unitcell_bar ,
1024- solve_method = "batched" if is_gpu else "incremental" ,
1025- atol = config .ad_custom_convergence_eps ,
1083+ v , e = jax .scipy .sparse .linalg .gmres (
1084+ f_gmres ,
1085+ new_unitcell_bar ,
1086+ new_unitcell_bar ,
1087+ solve_method = "batched" if is_gpu else "incremental" ,
1088+ atol = config .ad_custom_convergence_eps ,
1089+ )
1090+
1091+ return v , e
1092+
1093+ env_fixed_point , end_count = jax .lax .cond (
1094+ arnoldi_worked , lambda x , e : (x , e ), run_gmres , env_fixed_point , end_count
10261095 )
10271096
10281097 converged = True
0 commit comments