Skip to content

Commit 624172b

Browse files
authored
Version 4.1.0 (#85)
* debug 'to' function; resolve conflict * debug photonic UAnyGate * remove old sample_sc_mcmc * debug omega matrix * update pyproject * debug quadrature_to_ladder and ladder_to_quadrature * update pytest * update * version 4.1.0
1 parent 077808e commit 624172b

16 files changed

Lines changed: 154 additions & 101 deletions

src/deepquantum/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
DeepQuantum can be directly imported.
44
"""
55

6-
__version__ = '4.0.0'
6+
__version__ = '4.1.0'
77

88

99
from . import ansatz

src/deepquantum/circuit.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,14 @@ def __add__(self, rhs: 'QubitCircuit') -> 'QubitCircuit':
119119
def to(self, arg: Any) -> 'QubitCircuit':
120120
"""Set dtype or device of the ``QubitCircuit``."""
121121
self.init_state.to(arg)
122-
self.operators.to(arg)
123-
self.observables.to(arg)
122+
if arg in (torch.float, torch.double):
123+
for op in self.operators:
124+
op.to(arg)
125+
for ob in self.observables:
126+
ob.to(arg)
127+
else:
128+
self.operators.to(arg)
129+
self.observables.to(arg)
124130
return self
125131

126132
# pylint: disable=arguments-renamed

src/deepquantum/operation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ def __init__(
411411
# MBQC
412412
self.nodes = copy(self.wires)
413413

414+
def to(self, arg: Any) -> 'Layer':
415+
"""Set dtype or device of the ``Layer``."""
416+
for gate in self.gates:
417+
gate.to(arg)
418+
return self
419+
414420
def get_unitary(self) -> torch.Tensor:
415421
"""Get the global unitary matrix."""
416422
u = None

src/deepquantum/photonic/circuit.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ def set_init_state(self, init_state: Any) -> None:
132132
if self.mps:
133133
assert self.backend == 'fock' and not self.basis, \
134134
'Only support MPS for Fock backend with Fock state tensor.'
135-
assert self.cutoff is not None, \
136-
'Please set the cutoff.'
135+
assert self.cutoff is not None, 'Please set the cutoff.'
137136
self.init_state = MatrixProductState(nsite=self.nmode, state=init_state, chi=self.chi,
138137
qudit=self.cutoff, normalize=False)
139138
else:
@@ -191,8 +190,14 @@ def __add__(self, rhs: 'QumodeCircuit') -> 'QumodeCircuit':
191190
def to(self, arg: Any) -> 'QumodeCircuit':
192191
"""Set dtype or device of the ``QumodeCircuit``."""
193192
self.init_state.to(arg)
194-
self.operators.to(arg)
195-
self.measurements.to(arg)
193+
if arg in (torch.float, torch.double):
194+
for op in self.operators:
195+
op.to(arg)
196+
for op_m in self.measurements:
197+
op_m.to(arg)
198+
else:
199+
self.operators.to(arg)
200+
self.measurements.to(arg)
196201
if self.backend == 'bosonic' and isinstance(self._bosonic_states, list):
197202
for bs in self._bosonic_states:
198203
bs.to(arg)
@@ -1058,12 +1063,8 @@ def _get_prob_gaussian_base(
10581063
prob = p_vac * torontonian(sub_mat, sub_gamma)
10591064
return abs(prob.real).squeeze()
10601065

1061-
def _get_prob_mps(
1062-
self,
1063-
final_state: Any,
1064-
wires: Union[int, List[int], None] = None
1065-
) -> torch.Tensor:
1066-
"""Get the probability for the given bit string in mps mode.
1066+
def _get_prob_mps(self, final_state: Any, wires: Union[int, List[int], None] = None) -> torch.Tensor:
1067+
"""Get the probability of the given bit string for MPS.
10671068
10681069
Args:
10691070
final_state (Any): The final Fock basis state.
@@ -1077,13 +1078,11 @@ def _get_prob_mps(
10771078
else:
10781079
wires = self._convert_indices(wires)
10791080
assert len(final_state) == len(wires)
1080-
idx = 0
10811081
state = copy(self.state)
10821082
if self.state[0].ndim == 3:
10831083
state = [site.unsqueeze(0) for site in state]
1084-
for i in wires:
1085-
state[i] = state[i][:, :, [final_state[idx]], :]
1086-
idx += 1
1084+
for i, wire in enumerate(wires):
1085+
state[wire] = state[wire][..., [final_state[i]], :]
10871086
return inner_product_mps(state, state).real
10881087

10891088
def measure(
@@ -1145,6 +1144,7 @@ def _measure_fock(
11451144
assert not mcmc, "Final states have been calculated, we don't need mcmc!"
11461145
return self._measure_dict(shots, with_prob, wires)
11471146
elif isinstance(self.state, List):
1147+
assert not mcmc, "Final states have been calculated, we don't need mcmc!"
11481148
return self._measure_mps(shots, with_prob, wires)
11491149
else:
11501150
assert False, 'Check your forward function or input!'
@@ -1305,7 +1305,7 @@ def _measure_mps(
13051305
with_prob: bool = False,
13061306
wires: Union[int, List[int], None] = None
13071307
) -> List[Dict]:
1308-
"""Measure the final state according to mps state."""
1308+
"""Measure the final state according to MPS."""
13091309
if wires is None:
13101310
wires = self.wires
13111311
wires = sorted(self._convert_indices(wires))
@@ -1430,29 +1430,30 @@ def _generate_rand_sample(self, detector: str = 'pnrd'):
14301430
sample = torch.randint(0, self.cutoff, [nmode])
14311431
return sample
14321432

1433-
def _generate_chain_sample(
1434-
self,
1435-
wires: Union[int, List[int], None] = None
1436-
) -> torch.Tensor:
1437-
"""Generate random sample via chain rule.
1433+
def _generate_chain_sample(self, wires: Union[int, List[int], None] = None) -> torch.Tensor:
1434+
"""Generate random samples via chain rule.
1435+
1436+
Args:
1437+
wires (int, List[int] or None, optional): The wires to measure. It can be an integer or a list of
1438+
integers specifying the indices of the wires. Default: ``None`` (which means all wires are
1439+
measured)
14381440
14391441
Returns:
1440-
torch.Tensor: sample tensor of shape (batch, len(wires)).
1442+
torch.Tensor: Tensor of shape (batch, nwire).
14411443
"""
14421444
if wires is None:
14431445
wires = self.wires
14441446
wires = sorted(self._convert_indices(wires))
14451447
sample = []
1446-
mps_state = copy(self.state)
1447-
if mps_state[0].ndim == 3:
1448-
mps_state = [site.unsqueeze(0) for site in mps_state]
1448+
mps = copy(self.state)
1449+
if mps[0].ndim == 3:
1450+
mps = [site.unsqueeze(0) for site in mps]
14491451
for i in wires:
1450-
p = vmap(get_prob_mps)(mps_state, wire=i)
1452+
p = vmap(get_prob_mps)(mps, wire=i)
14511453
sample_single_wire = torch.multinomial(p, num_samples=1)
14521454
sample.append(sample_single_wire)
1453-
index = sample_single_wire.reshape(-1).view(-1, 1, 1, 1)\
1454-
.expand(-1, mps_state[i].shape[-3], -1, mps_state[i].shape[-1])
1455-
mps_state[i] = torch.gather(mps_state[i], dim=2, index=index)
1455+
index = sample_single_wire.reshape(-1, 1, 1, 1).expand(-1, mps[i].shape[-3], -1, mps[i].shape[-1])
1456+
mps[i] = torch.gather(mps[i], dim=2, index=index)
14561457
sample = torch.stack(sample, dim=-1).squeeze(1)
14571458
return sample
14581459

src/deepquantum/photonic/gate.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> torch.Tensor:
164164

165165
def get_matrix(self, theta: Any) -> torch.Tensor:
166166
"""Get the local unitary matrix acting on creation operators."""
167+
# correspond to: U a^+ U^+ = u^T @ a^+
167168
theta = self.inputs_to_tensor(theta)
168169
if self.inv_mode:
169170
theta = -theta
@@ -309,6 +310,7 @@ def _add_noise(self, theta: torch.Tensor, phi: torch.Tensor) -> Tuple[torch.Tens
309310

310311
def get_matrix(self, theta: Any, phi: Any) -> torch.Tensor:
311312
"""Get the local unitary matrix acting on creation operators."""
313+
# correspond to: U a^+ U^+ = u^T @ a^+
312314
theta, phi = self.inputs_to_tensor([theta, phi])
313315
cos = torch.cos(theta)
314316
sin = torch.sin(theta)
@@ -359,7 +361,6 @@ def get_transform_xp(self, theta: Any, phi: Any) -> Tuple[torch.Tensor, torch.Te
359361
# correspond to: U a U^+ = (u^*)^T @ a and U^+ a^+ U = u^* @ a^+
360362
matrix_xp = torch.cat([torch.cat([matrix.real, -matrix.imag], dim=-1),
361363
torch.cat([matrix.imag, matrix.real], dim=-1)], dim=-2).reshape(4, 4)
362-
matrix_xp = matrix_xp.to(theta.device, theta.dtype)
363364
vector_xp = torch.zeros(4, 1, dtype=theta.dtype, device=theta.device)
364365
return matrix_xp, vector_xp
365366

@@ -472,7 +473,8 @@ def __init__(
472473
self.name = 'MZI'
473474

474475
def get_matrix(self, theta: Any, phi: Any) -> torch.Tensor:
475-
"""Get the local unitary matrix acting on operators."""
476+
"""Get the local unitary matrix acting on creation operators."""
477+
# correspond to: U a^+ U^+ = u^T @ a^+
476478
theta, phi = self.inputs_to_tensor([theta, phi])
477479
cos = torch.cos(theta / 2)
478480
sin = torch.sin(theta / 2)
@@ -765,6 +767,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> torch.Tensor:
765767

766768
def get_matrix(self, theta: Any) -> torch.Tensor:
767769
"""Get the local unitary matrix acting on creation operators."""
770+
# correspond to: U a^+ U^+ = u^T @ a^+
768771
theta = self.inputs_to_tensor(theta)
769772
cos = torch.cos(theta / 2) + 0j
770773
sin = torch.sin(theta / 2) + 0j
@@ -848,8 +851,6 @@ def __init__(
848851
wires = list(range(minmax[0], minmax[1] + 1))
849852
super().__init__(name=name, nmode=nmode, wires=wires, cutoff=cutoff, den_mat=den_mat, noise=False)
850853
self.minmax = [min(self.wires), max(self.wires)]
851-
# for i in range(len(self.wires) - 1):
852-
# assert self.wires[i] + 1 == self.wires[i + 1], 'The wires should be consecutive integers'
853854
if not isinstance(unitary, torch.Tensor):
854855
unitary = torch.tensor(unitary, dtype=torch.cfloat).reshape(-1, len(self.wires))
855856
assert unitary.dtype in (torch.cfloat, torch.cdouble)
@@ -879,7 +880,7 @@ def get_matrix_state(self, matrix: torch.Tensor) -> torch.Tensor:
879880
"""
880881
nt = len(self.wires)
881882
sqrt = torch.sqrt(torch.arange(self.cutoff, dtype=torch.double, device=matrix.device))
882-
tran_mat = matrix.new_zeros([self.cutoff] * 2 * nt)
883+
tran_mat = matrix.new_zeros([self.cutoff] * 2 * nt)
883884
tran_mat[tuple([0] * 2 * nt)] = 1.0
884885
for rank in range(nt + 1, 2 * nt + 1):
885886
col_num = rank - nt - 1
@@ -916,10 +917,10 @@ def get_transform_xp(self, matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Te
916917
"""Get the local affine symplectic transformation acting on quadrature operators in ``xxpp`` order."""
917918
# correspond to: U a^+ U^+ = u^T @ a^+ and U^+ a U = u @ a
918919
# correspond to: U a U^+ = (u^*)^T @ a and U^+ a^+ U = u^* @ a^+
920+
n = len(self.wires)
919921
matrix_xp = torch.cat([torch.cat([matrix.real, -matrix.imag], dim=-1),
920-
torch.cat([matrix.imag, matrix.real], dim=-1)], dim=-2)
921-
matrix_xp = matrix_xp.reshape(2 * self.nmode, 2 * self.nmode)
922-
vector_xp = torch.zeros(2 * self.nmode, 1, dtype=matrix.real.dtype, device=matrix.real.device)
922+
torch.cat([matrix.imag, matrix.real], dim=-1)], dim=-2).reshape(2 * n, 2 * n)
923+
vector_xp = torch.zeros(2 * n, 1, dtype=matrix.real.dtype, device=matrix.real.device)
923924
return matrix_xp, vector_xp
924925

925926
def update_transform_xp(self) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -1000,6 +1001,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> Tuple[torch.Tensor, torch.Tens
10001001

10011002
def get_matrix(self, r: Any, theta: Any) -> torch.Tensor:
10021003
"""Get the local matrix acting on annihilation and creation operators."""
1004+
# correspond to: U^+ (a a^+) U = u @ (a a^+)
10031005
r, theta = self.inputs_to_tensor([r, theta])
10041006
ch = torch.cosh(r)
10051007
sh = torch.sinh(r)
@@ -1155,6 +1157,7 @@ def inputs_to_tensor(self, inputs: Any = None) -> Tuple[torch.Tensor, torch.Tens
11551157

11561158
def get_matrix(self, r: Any, theta: Any) -> torch.Tensor:
11571159
"""Get the local matrix acting on annihilation and creation operators."""
1160+
# correspond to: U^+ (a a^+) U = u @ (a a^+)
11581161
r, theta = self.inputs_to_tensor([r, theta])
11591162
ch = torch.cosh(r)
11601163
sh = torch.sinh(r)
@@ -1321,6 +1324,7 @@ def _add_noise(self, r: torch.Tensor, theta: torch.Tensor) -> Tuple[torch.Tensor
13211324

13221325
def get_matrix(self, r: Any, theta: Any) -> torch.Tensor:
13231326
"""Get the local unitary matrix acting on annihilation and creation operators."""
1327+
# correspond to: U^+ (a a^+) U = u @ (a a^+)
13241328
r, theta = self.inputs_to_tensor([r, theta])
13251329
return torch.eye(2, dtype=r.dtype, device=r.device) + 0j
13261330

src/deepquantum/photonic/operation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,12 @@ def __init__(
372372
self.ntau = ntau
373373
self.gates = nn.Sequential()
374374

375+
def to(self, arg: Any) -> 'Delay':
376+
"""Set dtype or device of the ``Delay``."""
377+
for gate in self.gates:
378+
gate.to(arg)
379+
return self
380+
375381
def init_para(self, inputs: Any = None) -> None:
376382
"""Initialize the parameters."""
377383
count = 0

src/deepquantum/photonic/qmath.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@
44

55
import itertools
66
import warnings
7-
from collections import defaultdict
8-
from copy import deepcopy
9-
from typing import Callable, Dict, Generator, List, Optional, Tuple
7+
from typing import Dict, Generator, List, Optional, Tuple
108

11-
import numpy as np
129
import torch
1310
from torch import vmap
1411
from torch.distributions.multivariate_normal import MultivariateNormal
15-
from tqdm import tqdm
1612

1713
import deepquantum.photonic as dqp
1814
from ..qmath import is_unitary
@@ -199,7 +195,7 @@ def xxpp_to_xpxp(matrix: torch.Tensor) -> torch.Tensor:
199195
"""Transform the representation in ``xxpp`` ordering to the representation in ``xpxp`` ordering."""
200196
nmode = matrix.shape[-2] // 2
201197
# transformation matrix
202-
t = torch.zeros([2 * nmode] * 2, dtype=matrix.dtype, device=matrix.device)
198+
t = matrix.new_zeros([2 * nmode] * 2)
203199
for i in range(2 * nmode):
204200
if i % 2 == 0:
205201
t[i][i // 2] = 1
@@ -216,7 +212,7 @@ def xpxp_to_xxpp(matrix: torch.Tensor) -> torch.Tensor:
216212
"""Transform the representation in ``xpxp`` ordering to the representation in ``xxpp`` ordering."""
217213
nmode = matrix.shape[-2] // 2
218214
# transformation matrix
219-
t = torch.zeros([2 * nmode] * 2, dtype=matrix.dtype, device=matrix.device)
215+
t = matrix.new_zeros([2 * nmode] * 2)
220216
for i in range(2 * nmode):
221217
if i < nmode:
222218
t[i][2 * i] = 1
@@ -229,30 +225,48 @@ def xpxp_to_xxpp(matrix: torch.Tensor) -> torch.Tensor:
229225
return t @ matrix
230226

231227

232-
def quadrature_to_ladder(matrix: torch.Tensor) -> torch.Tensor:
233-
"""Transform the representation in ``xxpp`` ordering to the representation in ``aa^+`` ordering."""
234-
nmode = matrix.shape[-2] // 2
235-
matrix = matrix + 0j
236-
identity = torch.eye(nmode, dtype=matrix.dtype, device=matrix.device)
228+
def quadrature_to_ladder(tensor: torch.Tensor, symplectic: bool = False) -> torch.Tensor:
229+
"""Transform the representation in ``xxpp`` ordering to the representation in ``aaa^+a^+`` ordering.
230+
231+
Args:
232+
tensor (torch.Tensor): The input tensor in ``xxpp`` ordering.
233+
symplectic (bool, optional): Whether the transformation is applied for symplectic matrix or Gaussian state.
234+
Default: ``False`` (which means covariance matrix or displacement vector)
235+
"""
236+
nmode = tensor.shape[-2] // 2
237+
tensor = tensor + 0j
238+
identity = torch.eye(nmode, dtype=tensor.dtype, device=tensor.device)
237239
omega = torch.cat([torch.cat([identity, identity * 1j], dim=-1),
238-
torch.cat([identity, identity * -1j], dim=-1)]) * dqp.kappa / dqp.hbar ** 0.5
239-
if matrix.shape[-1] == 2 * nmode:
240-
return omega @ matrix @ omega.mH
241-
elif matrix.shape[-1] == 1:
242-
return omega @ matrix
240+
torch.cat([identity, identity * -1j], dim=-1)])
241+
if tensor.shape[-1] == 2 * nmode:
242+
if symplectic:
243+
return omega @ tensor @ omega.mH / 2 # inversed omega
244+
else:
245+
return omega @ tensor @ omega.mH * dqp.kappa**2 / dqp.hbar
246+
elif tensor.shape[-1] == 1:
247+
return omega @ tensor * dqp.kappa / dqp.hbar**0.5
243248

244249

245-
def ladder_to_quadrature(matrix: torch.Tensor) -> torch.Tensor:
246-
"""Transform the representation in ``aa^+`` ordering to the representation in ``xxpp`` ordering."""
247-
nmode = matrix.shape[-2] // 2
248-
matrix = matrix + 0j
249-
identity = torch.eye(nmode, dtype=matrix.dtype, device=matrix.device)
250+
def ladder_to_quadrature(tensor: torch.Tensor, symplectic: bool = False) -> torch.Tensor:
251+
"""Transform the representation in ``aaa^+a^+`` ordering to the representation in ``xxpp`` ordering.
252+
253+
Args:
254+
tensor (torch.Tensor): The input tensor in ``aaa^+a^+`` ordering.
255+
symplectic (bool, optional): Whether the transformation is applied for symplectic matrix or Gaussian state.
256+
Default: ``False`` (which means covariance matrix or displacement vector)
257+
"""
258+
nmode = tensor.shape[-2] // 2
259+
tensor = tensor + 0j
260+
identity = torch.eye(nmode, dtype=tensor.dtype, device=tensor.device)
250261
omega = torch.cat([torch.cat([identity, identity], dim=-1),
251-
torch.cat([identity * -1j, identity * 1j], dim=-1)]) * dqp.hbar ** 0.5 / (2 * dqp.kappa)
252-
if matrix.shape[-1] == 2 * nmode:
253-
return (omega @ matrix @ omega.mH).real
254-
elif matrix.shape[-1] == 1:
255-
return (omega @ matrix).real
262+
torch.cat([identity * -1j, identity * 1j], dim=-1)])
263+
if tensor.shape[-1] == 2 * nmode:
264+
if symplectic:
265+
return (omega @ tensor @ omega.mH).real / 2 # inversed omega
266+
else:
267+
return (omega @ tensor @ omega.mH).real * dqp.hbar / (4 * dqp.kappa**2)
268+
elif tensor.shape[-1] == 1:
269+
return (omega @ tensor).real * dqp.hbar**0.5 / (2 * dqp.kappa)
256270

257271

258272
def _photon_number_mean_var_gaussian(cov: torch.Tensor, mean: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

0 commit comments

Comments
 (0)