Coverage for jaxquantum/circuits/library/oscillator.py: 45%
76 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""Oscillator gates."""
3from jaxquantum.core.operators import (displace, basis, destroy, create, num,
4 identity)
5from jaxquantum.circuits.gates import Gate
6from jax.scipy.special import factorial
7import jax.numpy as jnp
8from jaxquantum import Qarray
9from jaxquantum.utils import hermgauss
def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    The exact unitary is ``displace(N, alpha)``. When ``ts`` is supplied, a
    constant drive Hamiltonian spanning ``ts[0]`` .. ``ts[-1]`` is attached as
    well so the gate can be simulated in time.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation.
        c_ops: Optional collapse operators; an empty list is used when None.

    Returns:
        Displacement gate.
    """
    hamiltonian = None
    if ts is not None:
        # Drive strength chosen so that the constant Hamiltonian
        # conj(eps) a + eps a^dag, applied over the whole window, integrates
        # to D(alpha) (assuming U = exp(-i H T)).
        duration = ts[-1] - ts[0]
        eps = 1j * alpha / duration
        lowering = destroy(N)

        def hamiltonian(params):
            # Time-independent drive: the returned H(t) ignores t.
            return lambda t: jnp.conj(eps) * lowering + eps * lowering.dag()

    def unitary(params):
        return displace(N, params["alpha"])

    def collapse_ops(params):
        if c_ops is None:
            return Qarray.from_list([])
        return c_ops

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=unitary,
        gen_Ht=hamiltonian,
        ts=ts,
        gen_c_ops=collapse_ops,
        num_modes=1,
    )
def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Applies D(beta/2) to the oscillator when the qubit is in |g> and
    D(-beta/2) when it is in |e>.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        amp = 1j * beta / delta_t / 2
        a = destroy(N)
        # Oscillator drive for each qubit branch: +beta/2 on |g>, -beta/2 on |e>.
        H_g = jnp.conj(amp) * a + amp * a.dag()
        H_e = jnp.conj(-amp) * a + (-amp) * a.dag()
        # Python's ``+`` binds tighter than ``^`` (tensor product), so each
        # branch must be parenthesized; otherwise the expression parses as
        # gg ^ (H_g + ee) ^ H_e, mixing operators on different spaces.
        H = (gg ^ H_g) + (ee ^ H_e)
        gen_Ht = lambda params: lambda t: H

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )
def CR(N, theta):
    """Conditional rotation gate.

    Builds |g><g| kron exp(-i theta/2 (n + 1)) + |e><e| kron exp(+i theta/2 (n + 1)),
    where n + 1 is the oscillator operator a a^dag.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # a a^dag == n + 1. Forming it as destroy(N) @ create(N) in a truncated
    # space yields 0 (instead of N) on the top Fock level; num(N) + identity(N)
    # is exact on every level and agrees with destroy @ create elsewhere.
    n_plus_one = num(N) + identity(N)

    def gen_U(params):
        # Read the angle from params (like D and CD) so the unitary follows
        # the gate's stored parameters rather than a stale closure value.
        angle = params["theta"]
        return (gg ^ (-1.j * angle / 2 * n_plus_one).expm()) + (
            ee ^ (1.j * angle / 2 * n_plus_one).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )
def _Ph_Loss_Kraus_Op(N, err_prob, l):
    """Returns the Kraus operator for l-photon loss.

    Computes sqrt(err_prob**l / l!) * (1 - err_prob)**(n/2) * a**l in a
    Hilbert space of size N (Qarray ``*`` is assumed to be the operator
    product here).

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        l: Number of photons lost.

    Returns:
        Kraus operator for l-photon loss.
    """
    prefactor = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    # exp(n * log(sqrt(1 - p))) == (1 - p)**(n/2), a diagonal damping factor.
    # NOTE(review): err_prob == 1 gives log(0) = -inf, and 0 * -inf = nan for
    # every zero entry of num(N); confirm callers keep err_prob < 1.
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    return prefactor * damping * destroy(N).powm(l)
def Amp_Damp(N, err_prob, max_l):
    """Amplitude damping channel.

    Uses one Kraus operator per number of lost photons, from 0 up to and
    including ``max_l``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons lost.

    Returns:
        Amplitude damping channel.
    """

    def kraus_map(params):
        operators = []
        for n_lost in range(max_l + 1):
            operators.append(_Ph_Loss_Kraus_Op(N, err_prob, n_lost))
        return Qarray.from_list(operators)

    channel_params = {"err_prob": err_prob, "max_l": max_l}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=channel_params,
        gen_KM=kraus_map,
        num_modes=1,
    )
def _Ph_Gain_Kraus_Op(N, err_prob, l):
    """Returns the Kraus operator for l-photon gain.

    Computes sqrt(err_prob**l / l!) * (a^dag)**l * (1 - err_prob)**(n/2) in a
    Hilbert space of size N (Qarray ``*`` is assumed to be the operator
    product here).

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        l: Number of photons gained.

    Returns:
        Kraus operator for l-photon gain.
    """
    prefactor = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    # exp(n * log(sqrt(1 - p))) == (1 - p)**(n/2), a diagonal damping factor.
    # NOTE(review): err_prob == 1 gives log(0) = -inf, and 0 * -inf = nan for
    # every zero entry of num(N); confirm callers keep err_prob < 1.
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    return prefactor * create(N).powm(l) * damping
def Amp_Gain(N, err_prob, max_l):
    """Amplitude gain channel.

    Uses one Kraus operator per number of gained photons, from 0 up to and
    including ``max_l``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons gained.

    Returns:
        Amplitude gain channel.
    """

    def kraus_map(params):
        operators = []
        for n_gained in range(max_l + 1):
            operators.append(_Ph_Gain_Kraus_Op(N, err_prob, n_gained))
        return Qarray.from_list(operators)

    channel_params = {"err_prob": err_prob, "max_l": max_l}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=channel_params,
        gen_KM=kraus_map,
        num_modes=1,
    )
def _Thermal_Kraus_Op(N, err_prob, n_bar, l, k):
    """Returns the Kraus operator for a thermal channel.

    Computes
    sqrt((err_prob*(1+n_bar))**k * (err_prob*n_bar)**l / (l! k!))
    * (1 - err_prob)**(n/2) * a**k * (a^dag)**l
    in a Hilbert space of size N (Qarray ``*`` is assumed to be the operator
    product here).

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Average photon number.
        l: Number of photons gained.
        k: Number of photons lost.

    Returns:
        Kraus operator for thermal channel.
    """
    prefactor = jnp.sqrt(
        jnp.power(err_prob * (1 + n_bar), k)
        * jnp.power(err_prob * n_bar, l)
        / factorial(l)
        / factorial(k)
    )
    # exp(n * log(sqrt(1 - p))) == (1 - p)**(n/2), a diagonal damping factor.
    # NOTE(review): err_prob == 1 gives log(0) = -inf, and 0 * -inf = nan for
    # every zero entry of num(N); confirm callers keep err_prob < 1.
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    return prefactor * damping * destroy(N).powm(k) * create(N).powm(l)
def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal channel.

    Builds a Kraus operator for every (gained, lost) photon-number pair with
    both counts in 0..max_l, ordered with the gained count as the outer index.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Average photon number.
        max_l: Maximum number of photons gained/lost.

    Returns:
        Thermal channel.
    """

    def kraus_map(params):
        operators = []
        for n_gained in range(max_l + 1):
            for n_lost in range(max_l + 1):
                operators.append(
                    _Thermal_Kraus_Op(N, err_prob, n_bar, n_gained, n_lost)
                )
        return Qarray.from_list(operators)

    channel_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=channel_params,
        gen_KM=kraus_map,
        num_modes=1,
    )
def _Dephasing_Kraus_Op(N, w, phi):
    """Returns the Kraus operator for dephasing with weight w and phase phi.

    Computes sqrt(w) * exp(i * phi * n) in a Hilbert space of size N.

    Args:
        N: Hilbert space dimension.
        w: Weight (probability) of this phase kick.
        phi: Phase applied per photon.

    Returns:
        Kraus operator for a single phase kick.
    """
    return jnp.sqrt(w) * (1.j * phi * num(N)).expm()
def Dephasing_Ch(N, err_prob, max_l):
    """Dephasing channel.

    Averages the phase kick exp(i phi n) over Gaussian-distributed phi using
    a ``max_l``-point Gauss-Hermite quadrature; each quadrature node yields
    one Kraus operator.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Number of Kraus operators (quadrature points).

    Returns:
        Dephasing channel.
    """
    # Gauss-Hermite integrates against exp(-x**2); scaling the nodes by
    # sqrt(2 * err_prob) and the weights by 1/sqrt(pi) turns this into an
    # average over phi ~ Normal(0, variance=err_prob). Assumes hermgauss
    # follows the numpy.polynomial.hermite.hermgauss convention.
    xs, ws = hermgauss(max_l)
    phis = jnp.sqrt(2 * err_prob) * xs
    ws = 1 / jnp.sqrt(jnp.pi) * ws

    kmap = lambda params: Qarray.from_list(
        [_Dephasing_Kraus_Op(N, w, phi) for (w, phi) in zip(ws, phis)]
    )
    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )
def selfKerr(N, K):
    """Self-Kerr interaction gate.

    Builds exp(-i K/2 a^dag a^dag a a).

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    # The operator is independent of K, so build it once.
    kerr_term = a.dag() @ a.dag() @ a @ a
    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (like D and CD) so the unitary follows the gate's
        # stored parameters rather than a stale closure value.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_term).expm(),
        num_modes=1,
    )
def _Reset_Deph_Kraus_Op(N, p, t_rst, chi, l, max_l):
    """Returns the Kraus Operators for dephasing during reset.

    Operator layout by index ``l``:
      * 0: |g><g| kron identity (qubit already in |g>).
      * 1: sqrt(p) |e><e| kron exp(-i chi t_rst n) (qubit stays in |e> for
        the whole reset).
      * l >= 2: c_l |g><e| kron exp(-i chi t_rst (l - 2)/(max_l - 1) n), with
        c_l**2 proportional to -ln(p) p**((l - 2)/(max_l - 1)) and the weights
        for l = 2 .. max_l - 1 normalized to sum to 1 - p.

    Args:
        N: Hilbert space dimension.
        p: Reset error probability.
        t_rst: Reset time.
        chi: cross-Kerr strength between qubit and resonator.
        l: Operator index.
        max_l: Maximum number of operators.

    Returns:
        Kraus operator for dephasing during reset.
    """
    if l == 0:
        ground_proj = basis(2, 0) @ basis(2, 0).dag()
        return ground_proj ^ identity(N)

    if l == 1:
        excited_proj = basis(2, 1) @ basis(2, 1).dag()
        full_phase = (-1.j * chi * t_rst * num(N)).expm()
        # Scalar attaches to the qubit factor before the tensor product.
        return (jnp.sqrt(p) * excited_proj) ^ full_phase

    def unnormalized_weight(idx):
        return -(jnp.log(p) * p ** ((idx - 2) / (max_l - 1))) / ((max_l - 1))

    # NOTE(review): the phase grid uses (l - 2)/(max_l - 1) for
    # l = 2 .. max_l - 1, so it never reaches the full t_rst; normalization
    # keeps the weights complete, but confirm the grid spacing is intended.
    weight_total = jnp.sum(unnormalized_weight(jnp.arange(2, max_l, 1)))
    normalization_factor = (1 - p) / weight_total
    amplitude = jnp.sqrt(unnormalized_weight(l)) * jnp.sqrt(normalization_factor)

    lowering = basis(2, 0) @ basis(2, 1).dag()
    partial_phase = (
        -1.j * chi * t_rst * (l - 2) / (max_l - 1) * num(N)
    ).expm()
    return amplitude * (lowering ^ partial_phase)
def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Dephasing due to imperfect reset between a qubit and a resonator.

    The channel acts on the joint [qubit, resonator] space and is built from
    the Kraus operators with indices 0 .. max_l - 1 of _Reset_Deph_Kraus_Op.

    Args:
        N: Hilbert space dimension.
        p: Reset error probability.
        t_rst: Reset time.
        chi: Dephasing strength.
        max_l: Maximum number of operators.

    Returns:
        Dephasing due to reset channel.
    """
    channel_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l}

    def kraus_map(params):
        operators = []
        for idx in range(max_l):
            operators.append(_Reset_Deph_Kraus_Op(N, p, t_rst, chi, idx, max_l))
        return Qarray.from_list(operators)

    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=channel_params,
        gen_KM=kraus_map,
        num_modes=2,
    )