@@ -164,6 +164,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> torch.Tensor:
164164
165165 def get_matrix (self , theta : Any ) -> torch .Tensor :
166166 """Get the local unitary matrix acting on creation operators."""
167+ # correspond to: U a^+ U^+ = u^T @ a^+
167168 theta = self .inputs_to_tensor (theta )
168169 if self .inv_mode :
169170 theta = - theta
@@ -309,6 +310,7 @@ def _add_noise(self, theta: torch.Tensor, phi: torch.Tensor) -> Tuple[torch.Tens
309310
310311 def get_matrix (self , theta : Any , phi : Any ) -> torch .Tensor :
311312 """Get the local unitary matrix acting on creation operators."""
313+ # correspond to: U a^+ U^+ = u^T @ a^+
312314 theta , phi = self .inputs_to_tensor ([theta , phi ])
313315 cos = torch .cos (theta )
314316 sin = torch .sin (theta )
@@ -359,7 +361,6 @@ def get_transform_xp(self, theta: Any, phi: Any) -> Tuple[torch.Tensor, torch.Te
359361 # correspond to: U a U^+ = (u^*)^T @ a and U^+ a^+ U = u^* @ a^+
360362 matrix_xp = torch .cat ([torch .cat ([matrix .real , - matrix .imag ], dim = - 1 ),
361363 torch .cat ([matrix .imag , matrix .real ], dim = - 1 )], dim = - 2 ).reshape (4 , 4 )
362- matrix_xp = matrix_xp .to (theta .device , theta .dtype )
363364 vector_xp = torch .zeros (4 , 1 , dtype = theta .dtype , device = theta .device )
364365 return matrix_xp , vector_xp
365366
@@ -472,7 +473,8 @@ def __init__(
472473 self .name = 'MZI'
473474
474475 def get_matrix (self , theta : Any , phi : Any ) -> torch .Tensor :
475- """Get the local unitary matrix acting on operators."""
476+ """Get the local unitary matrix acting on creation operators."""
477+ # correspond to: U a^+ U^+ = u^T @ a^+
476478 theta , phi = self .inputs_to_tensor ([theta , phi ])
477479 cos = torch .cos (theta / 2 )
478480 sin = torch .sin (theta / 2 )
@@ -765,6 +767,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> torch.Tensor:
765767
766768 def get_matrix (self , theta : Any ) -> torch .Tensor :
767769 """Get the local unitary matrix acting on creation operators."""
770+ # correspond to: U a^+ U^+ = u^T @ a^+
768771 theta = self .inputs_to_tensor (theta )
769772 cos = torch .cos (theta / 2 ) + 0j
770773 sin = torch .sin (theta / 2 ) + 0j
@@ -848,8 +851,6 @@ def __init__(
848851 wires = list (range (minmax [0 ], minmax [1 ] + 1 ))
849852 super ().__init__ (name = name , nmode = nmode , wires = wires , cutoff = cutoff , den_mat = den_mat , noise = False )
850853 self .minmax = [min (self .wires ), max (self .wires )]
851- # for i in range(len(self.wires) - 1):
852- # assert self.wires[i] + 1 == self.wires[i + 1], 'The wires should be consecutive integers'
853854 if not isinstance (unitary , torch .Tensor ):
854855 unitary = torch .tensor (unitary , dtype = torch .cfloat ).reshape (- 1 , len (self .wires ))
855856 assert unitary .dtype in (torch .cfloat , torch .cdouble )
@@ -879,7 +880,7 @@ def get_matrix_state(self, matrix: torch.Tensor) -> torch.Tensor:
879880 """
880881 nt = len (self .wires )
881882 sqrt = torch .sqrt (torch .arange (self .cutoff , dtype = torch .double , device = matrix .device ))
882- tran_mat = matrix .new_zeros ([self .cutoff ] * 2 * nt )
883+ tran_mat = matrix .new_zeros ([self .cutoff ] * 2 * nt )
883884 tran_mat [tuple ([0 ] * 2 * nt )] = 1.0
884885 for rank in range (nt + 1 , 2 * nt + 1 ):
885886 col_num = rank - nt - 1
@@ -916,10 +917,10 @@ def get_transform_xp(self, matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Te
916917 """Get the local affine symplectic transformation acting on quadrature operators in ``xxpp`` order."""
917918 # correspond to: U a^+ U^+ = u^T @ a^+ and U^+ a U = u @ a
918919 # correspond to: U a U^+ = (u^*)^T @ a and U^+ a^+ U = u^* @ a^+
920+ n = len (self .wires )
919921 matrix_xp = torch .cat ([torch .cat ([matrix .real , - matrix .imag ], dim = - 1 ),
920- torch .cat ([matrix .imag , matrix .real ], dim = - 1 )], dim = - 2 )
921- matrix_xp = matrix_xp .reshape (2 * self .nmode , 2 * self .nmode )
922- vector_xp = torch .zeros (2 * self .nmode , 1 , dtype = matrix .real .dtype , device = matrix .real .device )
922+ torch .cat ([matrix .imag , matrix .real ], dim = - 1 )], dim = - 2 ).reshape (2 * n , 2 * n )
923+ vector_xp = torch .zeros (2 * n , 1 , dtype = matrix .real .dtype , device = matrix .real .device )
923924 return matrix_xp , vector_xp
924925
925926 def update_transform_xp (self ) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -1000,6 +1001,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> Tuple[torch.Tensor, torch.Tens
10001001
10011002 def get_matrix (self , r : Any , theta : Any ) -> torch .Tensor :
10021003 """Get the local matrix acting on annihilation and creation operators."""
1004+ # correspond to: U^+ (a a^+) U = u @ (a a^+)
10031005 r , theta = self .inputs_to_tensor ([r , theta ])
10041006 ch = torch .cosh (r )
10051007 sh = torch .sinh (r )
@@ -1155,6 +1157,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> Tuple[torch.Tensor, torch.Tens
11551157
11561158 def get_matrix (self , r : Any , theta : Any ) -> torch .Tensor :
11571159 """Get the local matrix acting on annihilation and creation operators."""
1160+ # correspond to: U^+ (a a^+) U = u @ (a a^+)
11581161 r , theta = self .inputs_to_tensor ([r , theta ])
11591162 ch = torch .cosh (r )
11601163 sh = torch .sinh (r )
@@ -1321,6 +1324,7 @@ def _add_noise(self, r: torch.Tensor, theta: torch.Tensor) -> Tuple[torch.Tensor
13211324
13221325 def get_matrix (self , r : Any , theta : Any ) -> torch .Tensor :
13231326 """Get the local unitary matrix acting on annihilation and creation operators."""
1327+ # correspond to: U^+ (a a^+) U = u @ (a a^+)
13241328 r , theta = self .inputs_to_tensor ([r , theta ])
13251329 return torch .eye (2 , dtype = r .dtype , device = r .device ) + 0j
13261330
0 commit comments