Coverage for jaxquantum / circuits / library / oscillator.py: 30%
145 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 21:05 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 21:05 +0000
1"""Oscillator gates."""
3from jaxquantum.core.operators import (displace, basis, destroy, create, num,
4 identity)
5from jaxquantum.circuits.gates import Gate
6from jax.scipy.special import factorial, gammaln
7import jax.numpy as jnp
8from jaxquantum import Qarray
9from jaxquantum.utils import hermgauss
10from functools import partial
11from jax import jit
12import jax
def diag_expm(diag_matrix):
    """Matrix exponential of a diagonal matrix in O(N) time.

    The exponential of a diagonal matrix is diagonal with entries exp(d_i),
    so the diagonal is exponentiated elementwise instead of running a general
    O(N^3) expm. Only the diagonal of ``diag_matrix`` is read; any
    off-diagonal entries are ignored.

    Args:
        diag_matrix: 2-D array, expected to be diagonal.

    Returns:
        Diagonal matrix holding the elementwise exponential of the input's
        diagonal.
    """
    diag_entries = jnp.diagonal(diag_matrix)
    exp_entries = jnp.exp(diag_entries)
    return jnp.diag(exp_entries)
def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. When given, a
            constant drive Hamiltonian ``H = amp* a + amp a^dag`` with
            ``amp = 1j * alpha / (ts[-1] - ts[0])`` is generated; assuming
            evolution ``exp(-i H T)`` over ``T = ts[-1] - ts[0]``, this
            reproduces ``exp(alpha a^dag - alpha* a)``.
        c_ops: Optional collapse operators.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _drive_ht(params):
            # Read alpha from params (as gen_U does) so the Hamiltonian
            # follows the gate's stored parameters rather than the value
            # captured when the gate was constructed.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

        gen_Ht = _drive_ht

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )
def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Implements ``|g><g| (x) D(beta/2) + |e><e| (x) D(-beta/2)``.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. When given, a
            constant Hamiltonian ``|g><g| (x) H(amp) + |e><e| (x) H(-amp)``
            with ``H(x) = x* a + x a^dag`` and
            ``amp = 1j * beta / (ts[-1] - ts[0]) / 2`` is generated.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _drive_ht(params):
            # Use params["beta"] so the Hamiltonian tracks the stored params.
            amp = 1j * params["beta"] / delta_t / 2
            h_g = jnp.conj(amp) * a + amp * a.dag()
            h_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Parentheses are required: `+` binds tighter than `^` in Python,
            # so `gg ^ h_g + ee ^ h_e` would parse as gg ^ (h_g + ee) ^ h_e.
            return lambda t: (gg ^ h_g) + (ee ^ h_e)

        gen_Ht = _drive_ht

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )
def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    Implements ``|e><g| (x) D(beta/2) + |g><e| (x) D(-beta/2)``: the qubit
    is flipped while the oscillator is displaced by +-beta/2 depending on
    the qubit's initial state.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. It is only
            forwarded to ``Gate.create``; no Hamiltonian generator is defined
            for this gate.

    Returns:
        Echoed conditional displacement gate.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)
    flip_g_to_e = ket_e @ ket_g.dag()
    flip_e_to_g = ket_g @ ket_e.dag()

    def _unitary(params):
        half_beta = params["beta"] / 2
        plus_branch = flip_g_to_e ^ displace(N, half_beta)
        minus_branch = flip_e_to_g ^ displace(N, -half_beta)
        return plus_branch + minus_branch

    # NOTE: gen_Ht is intentionally None — a Hamiltonian assembled from the
    # |e><g| and |g><e| terms separately would not be Hermitian.
    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=_unitary,
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )
def CR(N, theta):
    """Conditional rotation gate.

    Implements ``|g><g| (x) exp(-i theta/2 n) + |e><e| (x) exp(+i theta/2 n)``
    with ``n = a^dag a``.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # Photon-number operator a^dag a, built once instead of on every call.
    n_op = create(N) @ destroy(N)

    # Read theta from params (like the other gates in this module) so that
    # updated gate parameters are honoured instead of the construction value.
    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=lambda params: (gg ^ (-1.0j * params["theta"] / 2 * n_op).expm())
        + (ee ^ (1.0j * params["theta"] / 2 * n_op).expm()),
        num_modes=2,
    )
# --- Noise channels: jitted Kraus-operator kernels and their Gate wrappers ---
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Kraus operators of the amplitude-damping (photon-loss) channel.

    For l = 0..max_l:

        E_l = sqrt(err_prob**l / l!) * (1 - err_prob)**(n/2) @ a**l

    Args:
        N: Hilbert space dimension (static).
        err_prob: Loss probability.
        max_l: Highest number of lost photons kept (static).

    Returns:
        Array stacking E_0..E_max_l along axis 0 (assuming ``num(N).data`` and
        ``destroy(N).data`` are N x N matrices, shape (max_l + 1, N, N)).
    """
    n_op = num(N).data
    a_op = destroy(N).data

    # Powers are built with Python ints: jnp.linalg.matrix_power needs a
    # static exponent, so it cannot take the traced `l` inside vmap.
    a_powers = jnp.stack(
        [jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])

    # (1 - err_prob)**(n/2) as a diagonal matrix, O(N). jnp.power yields
    # 0**0 = 1, so err_prob == 1 keeps the vacuum entry at 1 instead of the
    # NaN produced by exp(0 * log(0)).
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (middle_op @ a_powers[l])

    ls = jnp.arange(max_l + 1)
    return jax.vmap(compute_op)(ls)
def Amp_Damp(N, err_prob, max_l):
    """Amplitude-damping (photon-loss) channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Loss probability.
        max_l: Highest number of lost photons kept; the channel has
            ``max_l + 1`` Kraus operators.

    Returns:
        Gate whose Kraus map is generated from its ``params``.
    """
    def kmap(params):
        kraus_ops = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Kraus operators of the amplification (photon-gain) channel.

    For l = 0..max_l:

        E_l = sqrt(err_prob**l / l!) * (a^dag)**l @ (1 - err_prob)**(n/2)

    Args:
        N: Hilbert space dimension (static).
        err_prob: Gain probability.
        max_l: Highest number of added photons kept (static).

    Returns:
        Array stacking E_0..E_max_l along axis 0.
    """
    n_op = num(N).data
    adag_op = create(N).data

    # jnp.linalg.matrix_power requires a concrete integer exponent; calling it
    # with the traced `l` inside vmap fails. Precompute all powers with
    # Python ints and index them instead (as the sibling kernels do).
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # (1 - err_prob)**(n/2) as a diagonal matrix; jnp.power gives 0**0 = 1,
    # avoiding the NaN of exp(0 * log(0)) when err_prob == 1.
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (adag_powers[l] @ middle_op)

    ls = jnp.arange(max_l + 1)
    return jax.vmap(compute_op)(ls)
def Amp_Gain(N, err_prob, max_l):
    """Amplification (photon-gain) channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Gain probability.
        max_l: Highest number of added photons kept; the channel has
            ``max_l + 1`` Kraus operators.

    Returns:
        Gate whose Kraus map is generated from its ``params``.
    """
    def kmap(params):
        kraus_ops = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Kraus operators of a thermal (loss + gain) channel.

    For l, k in 0..max_l:

        E_{l,k} = sqrt((err_prob*(1 + n_bar))**k * (err_prob*n_bar)**l / (k! l!))
                  * (1 - err_prob)**(n/2) @ a**k @ (a^dag)**l

    flattened along axis 0 with ``idx = l * (max_l + 1) + k``.

    NOTE(review): the operators are not renormalized here, so whether
    sum E^dag E = I holds depends on the intended channel model — TODO
    confirm against the reference definition.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Channel error probability.
        n_bar: Thermal occupation number.
        max_l: Highest loss/gain order kept (static).

    Returns:
        Array stacking the (max_l + 1)**2 operators along axis 0.
    """
    a_op = destroy(N).data
    adag_op = create(N).data
    n_op = num(N).data

    # Precompute powers with Python ints; matrix_power needs a static exponent
    # and cannot take the traced indices used inside vmap.
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])
    adag_powers = jnp.stack([jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # (1 - err_prob)**(n/2) as a diagonal matrix; jnp.power gives 0**0 = 1,
    # avoiding the NaN of exp(0 * log(0)) when err_prob == 1.
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_single_op(idx):
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        fact_l = jnp.exp(gammaln(l + 1))
        fact_k = jnp.exp(gammaln(k + 1))

        term_k = jnp.power(err_prob * (1.0 + n_bar), k)
        term_l = jnp.power(err_prob * n_bar, l)

        prefactor = jnp.sqrt((term_k * term_l) / (fact_k * fact_l))
        return prefactor * (middle_op @ a_powers[k] @ adag_powers[l])

    indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(compute_single_op)(indices)
def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal (loss + gain) channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Channel error probability.
        n_bar: Thermal occupation number.
        max_l: Highest loss/gain order kept; the channel has
            ``(max_l + 1)**2`` Kraus operators.

    Returns:
        Gate whose Kraus map is generated from its ``params``.
    """
    def kmap(params):
        kraus_ops = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        n_kraus = (params["max_l"] + 1) ** 2
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Kraus operators ``sqrt(w_i) * exp(i phi_i n)`` of a dephasing channel.

    Args:
        N: Hilbert space dimension (static).
        ws: Quadrature weights, one per Kraus operator.
        phis: Phase angles, one per Kraus operator (same length as ``ws``).
        max_l: Unused in the body; kept as a static argument for interface
            parity with the other kernels.

    Returns:
        Array stacking one operator per (w, phi) pair along axis 0.
    """
    n_op = num(N).data

    def weighted_rotation(weight, phase):
        # n_op is diagonal, so diag_expm computes exp(i*phase*n) in O(N).
        rotation = diag_expm(1.0j * phase * n_op)
        return jnp.sqrt(weight) * rotation

    return jax.vmap(weighted_rotation)(ws, phis)
def Dephasing_Ch(N, err_prob, max_l):
    """Gaussian random-phase dephasing channel on a single oscillator mode.

    The random phase is integrated with ``max_l`` Gauss-Hermite nodes. This
    assumes ``hermgauss`` returns physicists' nodes/weights (weight function
    exp(-x**2), as ``numpy.polynomial.hermite.hermgauss`` does). Under that
    assumption the weights divided by sqrt(pi) sum to 1, and the phases
    ``sqrt(2 * err_prob) * x_i`` sample a zero-mean Gaussian with variance
    ``err_prob``.

    Args:
        N: Hilbert space dimension.
        err_prob: Phase variance.
        max_l: Number of quadrature nodes, i.e. number of Kraus operators
            (note: ``max_l``, not ``max_l + 1`` as in the other channels).

    Returns:
        Gate whose Kraus map is generated from its ``params``.
    """
    xs, ws_raw = hermgauss(max_l)
    ws = 1 / jnp.sqrt(jnp.pi) * ws_raw

    def kmap(params):
        # Phases are derived from params["err_prob"] on every call (like the
        # other channels' kmaps) so updated gate parameters take effect.
        phis = jnp.sqrt(2 * params["err_prob"]) * xs
        return Qarray.create(
            _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"]),
            dims=[[N], [N]],
            bdims=(params["max_l"],),
        )

    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l, "N": N},
        gen_KM=kmap,
        num_modes=1,
    )
def selfKerr(N, K):
    """Self-Kerr gate ``exp(-i K/2 a^dag a^dag a a)``.

    Args:
        N: Hilbert space dimension.
        K: Kerr strength (stored in the gate params under ``"Kerr"``).

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    kerr_op = a.dag() @ a.dag() @ a @ a
    # Read the Kerr strength from params (like the other gates) so updated
    # gate parameters are honoured instead of the construction value.
    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Kraus operators for a qubit reset that dephases a coupled oscillator.

    Operators act on the joint qubit (x) oscillator space and are indexed by
    l = 0..max_l:

    * l == 0: ``|g><g| (x) I``.
    * l == 1: ``sqrt(p) |e><e| (x) exp(-i chi t_rst n)``.
    * l >= 2: ``c_l |g><e| (x) exp(-i chi t_rst x_l n)`` with
      ``x_l = (l - 2) / (max_l - 1)``. Presumably this models the qubit
      jumping e -> g after a fraction x_l of the reset window, so the
      oscillator picks up a proportionally smaller phase — TODO confirm.

    The squared weights ``c_l**2`` are proportional to
    ``-log(p) * p**x_l / (max_l - 1)`` and rescaled so they sum to ``1 - p``.
    For real chi and t_rst the |e><e| block of sum K^dag K is then
    ``p + (1 - p) = 1``.

    Args:
        N: Oscillator Hilbert space dimension (static).
        p: Weight of the l == 1 branch (see above); ``log(p)`` is taken,
            so 0 < p < 1 is presumably expected.
        t_rst: Reset duration.
        chi: Dispersive coupling rate.
        max_l: Index of the last Kraus operator (static). Must be >= 2:
            ``max_l - 1`` appears in denominators.

    Returns:
        Array stacking the max_l + 1 joint-space operators along axis 0.
    """
    g = basis(2, 0).data
    e = basis(2, 1).data
    gg = g @ jnp.conj(g.T)
    ee = e @ jnp.conj(e.T)
    ge = g @ jnp.conj(e.T)

    n_op = num(N).data
    I_N = jnp.eye(N)

    # Unnormalized weights for l = 2..max_l, used only to compute their sum.
    ls_all = jnp.arange(2, max_l+1)
    norm_terms = -(jnp.log(p) * jnp.power(p, (ls_all - 2) / (max_l - 1))) / (max_l - 1)
    # Rescales the l >= 2 weights so they sum to (1 - p).
    normalization_factor = (1 - p) / jnp.sum(norm_terms)

    def compute_op(l):
        def branch_0(_):
            # Qubit already in g: oscillator untouched.
            return jnp.kron(gg, I_N)

        def branch_1(_):
            # n_op is diagonal, so diag_expm computes the phase in O(N).
            op_osc = diag_expm(-1.0j * chi * t_rst * n_op)
            return jnp.sqrt(p) * jnp.kron(ee, op_osc)

        def branch_rest(_):
            # Same per-l weight as norm_terms above, then normalized.
            term_val = -(jnp.log(p) * jnp.power(p, (l - 2) / (max_l - 1))) / (max_l - 1)
            prefactor = jnp.sqrt(term_val) * jnp.sqrt(normalization_factor)

            # Oscillator phase accumulated over fraction (l-2)/(max_l-1)
            # of the reset window.
            exponent = -1.0j * chi * t_rst * (l - 2) / (max_l - 1)
            op_osc = diag_expm(exponent * n_op)
            return prefactor * jnp.kron(ge, op_osc)

        # Three-way dispatch on l: 0, 1, or >= 2.
        return jax.lax.cond(
            l == 0,
            branch_0,
            lambda _: jax.lax.cond(l == 1, branch_1, branch_rest, operand=None),
            operand=None
        )

    ls = jnp.arange(max_l+1)
    return jax.vmap(compute_op)(ls)
def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Qubit-reset channel that dephases a coupled oscillator.

    Acts on the joint qubit (x) oscillator space with dims ``[2, N]``.

    Args:
        N: Oscillator Hilbert space dimension.
        p: Weight of the "qubit stays excited" Kraus branch.
        t_rst: Reset duration.
        chi: Dispersive coupling rate.
        max_l: Index of the last Kraus operator; the channel has
            ``max_l + 1`` Kraus operators.

    Returns:
        Two-mode gate whose Kraus map is generated from its ``params``.
    """
    def kmap(params):
        kraus_ops = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"], params["p"], params["t_rst"], params["chi"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(kraus_ops, dims=[[2, N], [2, N]], bdims=(n_kraus,))

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kmap,
        num_modes=2,
    )