Skip to content

Commit c9fb181

Browse files
committed
Implement GMRES-based method to calculate gradient of CTMRG routine
1 parent 4a852c1 commit c9fb181

File tree

2 files changed

+186
-8
lines changed

2 files changed

+186
-8
lines changed

varipeps/ctmrg/routine.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -985,16 +985,47 @@ def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, s
985985
)[1]
986986
)
987987

988-
def cond_func(carry):
989-
_, _, _, converged, count, config, state = carry
988+
old_method = False
990989

991-
return jnp.logical_not(converged) & (count < config.ad_custom_max_steps)
990+
if old_method:
991+
def cond_func(carry):
992+
_, _, _, converged, count, config, state = carry
992993

993-
_, _, env_fixed_point, converged, end_count, _, _ = while_loop(
994-
cond_func,
995-
_ctmrg_rev_while_body,
996-
(vjp_env, new_unitcell_bar, new_unitcell_bar, False, 0, config, state),
997-
)
994+
return jnp.logical_not(converged) & (count < config.ad_custom_max_steps)
995+
996+
_, _, env_fixed_point, converged, end_count, _, _ = while_loop(
997+
cond_func,
998+
_ctmrg_rev_while_body,
999+
(vjp_env, new_unitcell_bar, new_unitcell_bar, False, 0, config, state),
1000+
)
1001+
else:
1002+
def f(w):
1003+
new_w = vjp_env((w, jnp.array(0, dtype=jnp.float64)))[0]
1004+
1005+
new_w = new_w.replace_unique_tensors(
1006+
[
1007+
t_old.__sub__(t_new, checks=False)
1008+
for t_old, t_new in zip(
1009+
w.get_unique_tensors(),
1010+
new_w.get_unique_tensors(),
1011+
strict=True,
1012+
)
1013+
]
1014+
)
1015+
1016+
return new_w
1017+
1018+
is_gpu = jax.default_backend() == "gpu"
1019+
1020+
env_fixed_point, end_count = jax.scipy.sparse.linalg.gmres(
1021+
f,
1022+
new_unitcell_bar,
1023+
new_unitcell_bar,
1024+
solve_method="batched" if is_gpu else "incremental",
1025+
atol=config.ad_custom_convergence_eps,
1026+
)
1027+
1028+
converged = True
9981029

9991030
(t_bar,) = vjp_peps_tensors((env_fixed_point, jnp.array(0, dtype=jnp.float64)))
10001031

varipeps/peps/tensor.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,50 @@ def __add__(
10211021
tensor_conj=self.tensor_conj,
10221022
)
10231023

1024+
def __sub__(
1025+
self: T_PEPS_Tensor, other: T_PEPS_Tensor, *, checks: bool = True
1026+
) -> T_PEPS_Tensor:
1027+
"""
1028+
Subtract the environment tensors of two PEPS tensors.
1029+
1030+
Args:
1031+
other (:obj:`~varipeps.peps.PEPS_Tensor`):
1032+
Other PEPS tensor object which should be subtracted to this one.
1033+
Keyword args:
1034+
checks (:obj:`bool`):
1035+
Enable checks that the addition of the two tensor objects makes
1036+
sense. Maybe disabled for jax transformations.
1037+
Returns:
1038+
:obj:`~varipeps.peps.PEPS_Tensor`:
1039+
New instance with the subtracted env tensors.
1040+
"""
1041+
if checks and (
1042+
self.tensor is not other.tensor
1043+
or self.d != other.d
1044+
or self.D != other.D
1045+
or self.chi != other.chi
1046+
):
1047+
raise ValueError(
1048+
"Both PEPS tensors must have the same tensor, d, D and chi values."
1049+
)
1050+
1051+
return PEPS_Tensor(
1052+
tensor=self.tensor,
1053+
C1=self.C1 - other.C1,
1054+
C2=self.C2 - other.C2,
1055+
C3=self.C3 - other.C3,
1056+
C4=self.C4 - other.C4,
1057+
T1=self.T1 - other.T1,
1058+
T2=self.T2 - other.T2,
1059+
T3=self.T3 - other.T3,
1060+
T4=self.T4 - other.T4,
1061+
d=self.d,
1062+
D=self.D,
1063+
chi=self.chi,
1064+
max_chi=self.max_chi,
1065+
tensor_conj=self.tensor_conj,
1066+
)
1067+
10241068
@classmethod
10251069
def zeros_like(cls: Type[T_PEPS_Tensor], t: T_PEPS_Tensor) -> T_PEPS_Tensor:
10261070
"""
@@ -2683,6 +2727,58 @@ def __add__(
26832727
tensor_conj=self.tensor_conj,
26842728
)
26852729

2730+
def __sub__(
2731+
self: T_PEPS_Tensor_Split_Transfer,
2732+
other: T_PEPS_Tensor_Split_Transfer,
2733+
*,
2734+
checks: bool = True,
2735+
) -> T_PEPS_Tensor_Split_Transfer:
2736+
"""
2737+
Subtract the environment tensors of two PEPS tensors.
2738+
2739+
Args:
2740+
other (:obj:`~peps_ad.peps.PEPS_Tensor_Split_Transfer`):
2741+
Other PEPS tensor object which should be subtracted to this one.
2742+
Keyword args:
2743+
checks (:obj:`bool`):
2744+
Enable checks that the addition of the two tensor objects makes
2745+
sense. Maybe disabled for jax transformations.
2746+
Returns:
2747+
:obj:`~peps_ad.peps.PEPS_Tensor_Split_Transfer`:
2748+
New instance with the subtracted env tensors.
2749+
"""
2750+
if checks and (
2751+
self.tensor is not other.tensor
2752+
or self.d != other.d
2753+
or self.D != other.D
2754+
or self.chi != other.chi
2755+
):
2756+
raise ValueError(
2757+
"Both PEPS tensors must have the same tensor, d, D and chi values."
2758+
)
2759+
2760+
return type(self)(
2761+
tensor=self.tensor,
2762+
C1=self.C1 - other.C1,
2763+
C2=self.C2 - other.C2,
2764+
C3=self.C3 - other.C3,
2765+
C4=self.C4 - other.C4,
2766+
T1_ket=self.T1_ket - other.T1_ket,
2767+
T1_bra=self.T1_bra - other.T1_bra,
2768+
T2_ket=self.T2_ket - other.T2_ket,
2769+
T2_bra=self.T2_bra - other.T2_bra,
2770+
T3_ket=self.T3_ket - other.T3_ket,
2771+
T3_bra=self.T3_bra - other.T3_bra,
2772+
T4_ket=self.T4_ket - other.T4_ket,
2773+
T4_bra=self.T4_bra - other.T4_bra,
2774+
d=self.d,
2775+
D=self.D,
2776+
chi=self.chi,
2777+
max_chi=self.max_chi,
2778+
interlayer_chi=self.interlayer_chi,
2779+
tensor_conj=self.tensor_conj,
2780+
)
2781+
26862782
@classmethod
26872783
def zeros_like(
26882784
cls: Type[T_PEPS_Tensor_Split_Transfer], t: T_PEPS_Tensor_Split_Transfer
@@ -3671,6 +3767,57 @@ def __add__(self, other, *, checks: bool = True):
36713767
max_chi=self.max_chi,
36723768
)
36733769

3770+
def __sub__(self, other, *, checks: bool = True):
3771+
"""
3772+
Subtract the environment tensors of two PEPS tensors.
3773+
3774+
Args:
3775+
other (:obj:`~varipeps.peps.PEPS_Tensor_Triangular`):
3776+
Other PEPS tensor object which should be subtracted to this one.
3777+
Keyword args:
3778+
checks (:obj:`bool`):
3779+
Enable checks that the addition of the two tensor objects makes
3780+
sense. Maybe disabled for jax transformations.
3781+
Returns:
3782+
:obj:`~varipeps.peps.PEPS_Tensor_Triangular`:
3783+
New instance with the subtracted env tensors.
3784+
"""
3785+
if checks and (
3786+
self.tensor is not other.tensor
3787+
or self.d != other.d
3788+
or self.D != other.D
3789+
or self.chi != other.chi
3790+
):
3791+
raise ValueError(
3792+
"Both PEPS tensors must have the same tensor, d, D and chi values."
3793+
)
3794+
3795+
return type(self)(
3796+
tensor=self.tensor,
3797+
C1=self.C1 - other.C1,
3798+
C2=self.C2 - other.C2,
3799+
C3=self.C3 - other.C3,
3800+
C4=self.C4 - other.C4,
3801+
C5=self.C5 - other.C5,
3802+
C6=self.C6 - other.C6,
3803+
T1a=self.T1a - other.T1a,
3804+
T1b=self.T1b - other.T1b,
3805+
T2a=self.T2a - other.T2a,
3806+
T2b=self.T2b - other.T2b,
3807+
T3a=self.T3a - other.T3a,
3808+
T3b=self.T3b - other.T3b,
3809+
T4a=self.T4a - other.T4a,
3810+
T4b=self.T4b - other.T4b,
3811+
T5a=self.T5a - other.T5a,
3812+
T5b=self.T5b - other.T5b,
3813+
T6a=self.T6a - other.T6a,
3814+
T6b=self.T6b - other.T6b,
3815+
d=self.d,
3816+
D=self.D,
3817+
chi=self.chi,
3818+
max_chi=self.max_chi,
3819+
)
3820+
36743821
@classmethod
36753822
def zeros_like(cls, t):
36763823
"""

0 commit comments

Comments
 (0)