Coverage for jaxquantum / circuits / library / oscillator.py: 30%
144 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""Oscillator gates."""
3from jaxquantum.core.operators import (displace, basis, destroy, create, num)
4from jaxquantum.circuits.gates import Gate
5from jax.scipy.special import gammaln
6import jax.numpy as jnp
7from jaxquantum import Qarray
8from jaxquantum.utils import hermgauss
9from functools import partial
10import jax
def diag_expm(diag_matrix):
    """Exponentiate a diagonal matrix via its diagonal entries.

    Only the main diagonal of ``diag_matrix`` is read; each entry is
    exponentiated and placed back on the diagonal of a new matrix. For a
    diagonal input this matches the matrix exponential at O(N) cost rather
    than the O(N^3) of a general ``expm``.

    Args:
        diag_matrix: Square 2-D array whose off-diagonal entries are ignored.

    Returns:
        2-D array with ``exp(diag_matrix[i, i])`` on the diagonal and zeros
        elsewhere.
    """
    diagonal_entries = jnp.diagonal(diag_matrix)
    exponentiated = jnp.exp(diagonal_entries)
    return jnp.diag(exponentiated)
def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. When given, a
            time-independent drive Hamiltonian with amplitude
            ``1j * alpha / (ts[-1] - ts[0])`` is attached to the gate.
        c_ops: Optional collapse operators. When omitted, an empty Qarray
            list is used.

    Returns:
        Displacement gate.
    """
    if ts is None:
        gen_Ht = None
    else:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read the amplitude from params, exactly as gen_U does, so the
            # Hamiltonian and the unitary always describe the same alpha.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )
def CD(N, beta, ts=None):
    """Conditional displacement gate.

    The unitary displaces the oscillator by ``+beta / 2`` on the ``g`` branch
    and by ``-beta / 2`` on the ``e`` branch of the two-level system.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. When given, a
            time-independent Hamiltonian with drive amplitude
            ``1j * beta / (ts[-1] - ts[0]) / 2`` (sign flipped on the ``e``
            branch) is attached to the gate.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    if ts is None:
        gen_Ht = None
    else:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read the amplitude from params, as gen_U does.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Parentheses are required: `+` binds tighter than `^`, so
            # without them this parses as `gg ^ (drive_g + ee) ^ drive_e`.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )
def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    The unitary is ``(e g^dag) x D(beta / 2) + (g e^dag) x D(-beta / 2)``,
    i.e. the two-level system is flipped while the oscillator is displaced.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence. It is forwarded to ``Gate.create``, but
            no Hamiltonian generator is supplied for this gate.

    Returns:
        Echoed conditional displacement gate.
    """
    ket_g = basis(2, 0)
    ket_e = basis(2, 1)

    flip_g_to_e = ket_e @ ket_g.dag()
    flip_e_to_g = ket_g @ ket_e.dag()

    # NOTE: Hamiltonian simulation is not implemented for ECD; gen_Ht stays
    # None regardless of ts.

    def gen_U(params):
        half_beta = params["beta"] / 2
        plus_branch = flip_g_to_e ^ displace(N, half_beta)
        minus_branch = flip_e_to_g ^ displace(N, -half_beta)
        return plus_branch + minus_branch

    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=gen_U,
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )
def CR(N, theta):
    """Conditional rotation gate.

    The unitary applies ``exp(-1j * theta / 2 * a^dag a)`` on the ``g``
    branch and ``exp(+1j * theta / 2 * a^dag a)`` on the ``e`` branch of the
    two-level system.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # Build a^dag a once instead of twice on every gen_U call.
    n_osc = create(N) @ destroy(N)

    def gen_U(params):
        # Read the angle from params (as D/CD/ECD do) so the unitary tracks
        # the gate's stored parameters rather than a closed-over value.
        half_theta = params["theta"] / 2
        return (gg ^ (-1.0j * half_theta * n_osc).expm()) + (
            ee ^ (1.0j * half_theta * n_osc).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )
156# --- 2. Optimized Kernels (Using diag_expm) ---
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Build the stacked Kraus operators used by ``Amp_Damp``.

    For each ``l`` in ``0..max_l`` the operator is
    ``sqrt(err_prob**l / l!) * exp(n * log(sqrt(1 - err_prob))) @ a**l``,
    where ``n`` and ``a`` are the data of ``num(N)`` and ``destroy(N)``.

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Error probability entering the prefactor and the
            diagonal factor.
        max_l: Largest power of ``a`` included (static under jit).

    Returns:
        Array with leading axis of length ``max_l + 1``, one operator per
        ``l``.
    """
    number_mat = num(N).data
    lowering_mat = destroy(N).data

    # max_l is static, so the powers are unrolled in Python here and later
    # selected with a traced index inside vmap.
    lowering_powers = jnp.stack(
        [jnp.linalg.matrix_power(lowering_mat, power) for power in range(max_l + 1)]
    )

    half_log_keep = jnp.log(jnp.sqrt(1.0 - err_prob))
    diagonal_factor = diag_expm(number_mat * half_log_keep)

    def kraus_for_order(l):
        # exp(gammaln(l + 1)) evaluates l!.
        weight = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return weight * (diagonal_factor @ lowering_powers[l])

    orders = jnp.arange(max_l + 1)
    return jax.vmap(kraus_for_order)(orders)
def Amp_Damp(N, err_prob, max_l):
    """Create the ``Amp_Damp`` gate, defined through a Kraus map.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        max_l: Highest Kraus index; the map holds ``max_l + 1`` operators.

    Returns:
        Single-mode gate whose Kraus map is generated from its params.
    """

    def kmap(params):
        kraus_stack = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_stack,
            dims=[[N], [N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Build the stacked Kraus operators used by ``Amp_Gain``.

    For each ``l`` in ``0..max_l`` the operator is
    ``sqrt(err_prob**l / l!) * (a^dag)**l @ exp(n * log(sqrt(1 - err_prob)))``,
    where ``n`` and ``a^dag`` are the data of ``num(N)`` and ``create(N)``.

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Error probability entering the prefactor and the
            diagonal factor.
        max_l: Largest power of ``a^dag`` included (static under jit).

    Returns:
        Array with leading axis of length ``max_l + 1``, one operator per
        ``l``.
    """
    n_op = num(N).data
    adag_op = create(N).data

    # jnp.linalg.matrix_power needs a concrete integer exponent, but inside
    # vmap `l` is a tracer. Unroll the powers over the static max_l here and
    # select them by index below, as the Amp_Damp kernel does.
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)]
    )

    log_term = jnp.log(jnp.sqrt(1.0 - err_prob))
    middle_op = diag_expm(n_op * log_term)

    # NOTE(review): with this diagonal factor, sum_l K_l^dag K_l appears to
    # come out as 1 / (1 - err_prob) rather than the identity -- confirm the
    # intended normalisation against the channel definition.
    def compute_op(l):
        # exp(gammaln(l + 1)) evaluates l!.
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (adag_powers[l] @ middle_op)

    ls = jnp.arange(max_l + 1)
    return jax.vmap(compute_op)(ls)
def Amp_Gain(N, err_prob, max_l):
    """Create the ``Amp_Gain`` gate, defined through a Kraus map.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        max_l: Highest Kraus index; the map holds ``max_l + 1`` operators.

    Returns:
        Single-mode gate whose Kraus map is generated from its params.
    """

    def kmap(params):
        kraus_stack = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_stack,
            dims=[[N], [N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Build the stacked Kraus operators used by ``Thermal_Ch``.

    Operators are indexed by a flat index ``idx`` in
    ``0..(max_l + 1)**2 - 1`` with ``l = idx // (max_l + 1)`` and
    ``k = idx % (max_l + 1)``. Each operator is

        sqrt((err_prob * (1 + n_bar))**k * (err_prob * n_bar)**l / (k! * l!))
        * exp(n * log(sqrt(1 - err_prob))) @ a**k @ (a^dag)**l

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Error probability.
        n_bar: Enters the two prefactor terms as ``n_bar`` and
            ``1 + n_bar``.
        max_l: Largest power of ``a`` and of ``a^dag`` (static under jit).

    Returns:
        Array with leading axis of length ``(max_l + 1)**2``.
    """
    lowering_mat = destroy(N).data
    raising_mat = create(N).data
    number_mat = num(N).data

    # max_l is static, so both power tables are unrolled in Python.
    n_orders = max_l + 1
    lowering_powers = jnp.stack(
        [jnp.linalg.matrix_power(lowering_mat, power) for power in range(n_orders)]
    )
    raising_powers = jnp.stack(
        [jnp.linalg.matrix_power(raising_mat, power) for power in range(n_orders)]
    )

    half_log_keep = jnp.log(jnp.sqrt(1.0 - err_prob))
    diagonal_factor = diag_expm(number_mat * half_log_keep)

    def kraus_for_flat_index(idx):
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        # exp(gammaln(x + 1)) evaluates x!.
        l_factorial = jnp.exp(gammaln(l + 1))
        k_factorial = jnp.exp(gammaln(k + 1))

        k_weight = jnp.power(err_prob * (1.0 + n_bar), k)
        l_weight = jnp.power(err_prob * n_bar, l)

        scale = jnp.sqrt((k_weight * l_weight) / (k_factorial * l_factorial))
        return scale * (diagonal_factor @ lowering_powers[k] @ raising_powers[l])

    flat_indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(kraus_for_flat_index)(flat_indices)
def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Create the ``Thermal_Ch`` gate, defined through a Kraus map.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        n_bar: Passed to the Kraus-map kernel.
        max_l: Highest index per operator family; the map holds
            ``(max_l + 1)**2`` operators.

    Returns:
        Single-mode gate whose Kraus map is generated from its params.
    """

    def kmap(params):
        kraus_stack = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        return Qarray.create(
            kraus_stack,
            dims=[[N], [N]],
            bdims=((params["max_l"] + 1) ** 2,),
        )

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Build the stacked Kraus operators used by ``Dephasing_Ch``.

    One operator is produced per ``(w, phi)`` pair:
    ``sqrt(w) * exp(1j * phi * n)``, with ``n`` the data of ``num(N)``.

    Args:
        N: Hilbert space dimension (static under jit).
        ws: 1-D array of weights.
        phis: 1-D array of phases, paired elementwise with ``ws``.
        max_l: Not used in the body; kept as a static argument of the
            kernel's signature.

    Returns:
        Array with one operator per entry of ``ws`` / ``phis`` along the
        leading axis.
    """
    number_mat = num(N).data

    def weighted_phase_op(weight, angle):
        return jnp.sqrt(weight) * diag_expm(1.0j * angle * number_mat)

    return jax.vmap(weighted_phase_op)(ws, phis)
def Dephasing_Ch(N, err_prob, max_l):
    """Create the ``Dephasing_Ch`` gate, defined through a Kraus map.

    The nodes and weights come from ``hermgauss(max_l)``. Each Kraus operator
    is ``sqrt(w) * exp(1j * phi * n)`` with ``phi = sqrt(2 * err_prob) * x``
    and ``w = w_raw / sqrt(pi)``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability; sets the scale of the phases.
        max_l: Number of nodes, and therefore of Kraus operators.

    Returns:
        Single-mode gate whose Kraus map is generated from its params.
    """
    # Nodes and weights depend only on max_l, so compute them once.
    xs, ws_raw = hermgauss(max_l)
    ws = 1 / jnp.sqrt(jnp.pi) * ws_raw

    def kmap(params):
        # Derive the phases from params, as the other channel gates do, so
        # the Kraus map follows the gate's stored err_prob.
        phis = jnp.sqrt(2 * params["err_prob"]) * xs
        return Qarray.create(
            _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"]),
            dims=[[N], [N]],
            bdims=(params["max_l"],),
        )

    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l, "N": N},
        gen_KM=kmap,
        num_modes=1,
    )
def selfKerr(N, K):
    """Self-Kerr gate.

    The unitary is ``exp(-1j * K / 2 * a^dag a^dag a a)``.

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient, stored in the gate params under ``"Kerr"``.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    # The quartic operator does not depend on K, so build it once.
    quartic_op = a.dag() @ a.dag() @ a @ a

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (as D/CD/ECD do) so the unitary tracks the
        # gate's stored parameters rather than a closed-over value.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * quartic_op).expm(),
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Build the stacked Kraus operators used by ``Dephasing_Reset``.

    With ``g = basis(2, 0)``, ``e = basis(2, 1)`` and ``n = num(N)``, the
    operator at index ``l`` in ``0..max_l`` is:

    * ``l == 0``: ``kron(g g^dag, I_N)``
    * ``l == 1``: ``sqrt(p) * kron(e e^dag, exp(-1j * chi * t_rst * n))``
    * ``l >= 2``: ``sqrt(w_l * c) * kron(g e^dag, exp(-1j * chi * t_rst *
      (l - 2) / (max_l - 1) * n))``, where
      ``w_l = -log(p) * p**((l - 2) / (max_l - 1)) / (max_l - 1)`` and
      ``c = (1 - p) / sum_{l=2..max_l} w_l``.

    Args:
        N: Oscillator Hilbert space dimension (static under jit).
        p: Enters as ``sqrt(p)`` in the ``l == 1`` operator and through
            ``w_l`` and ``c`` for ``l >= 2``.
        t_rst: Multiplies ``chi`` in the oscillator phase.
        chi: Multiplies ``t_rst`` in the oscillator phase.
        max_l: Highest operator index (static under jit).

    Returns:
        Array with leading axis of length ``max_l + 1``.
    """
    ket_g = basis(2, 0).data
    ket_e = basis(2, 1).data
    proj_g = ket_g @ jnp.conj(ket_g.T)
    proj_e = ket_e @ jnp.conj(ket_e.T)
    e_to_g = ket_g @ jnp.conj(ket_e.T)

    number_mat = num(N).data
    identity_osc = jnp.eye(N)

    def step_weight(step):
        # Same expression for the normalisation sum and the per-index weight.
        return -(jnp.log(p) * jnp.power(p, (step - 2) / (max_l - 1))) / (max_l - 1)

    higher_steps = jnp.arange(2, max_l + 1)
    weight_norm = (1 - p) / jnp.sum(step_weight(higher_steps))

    def kraus_for_index(l):
        def index_zero_op(_):
            return jnp.kron(proj_g, identity_osc)

        def index_one_op(_):
            osc_phase = diag_expm(-1.0j * chi * t_rst * number_mat)
            return jnp.sqrt(p) * jnp.kron(proj_e, osc_phase)

        def higher_index_op(_):
            scale = jnp.sqrt(step_weight(l)) * jnp.sqrt(weight_norm)
            phase_coeff = -1.0j * chi * t_rst * (l - 2) / (max_l - 1)
            osc_phase = diag_expm(phase_coeff * number_mat)
            return scale * jnp.kron(e_to_g, osc_phase)

        def nonzero_index_op(_):
            return jax.lax.cond(l == 1, index_one_op, higher_index_op, operand=None)

        return jax.lax.cond(l == 0, index_zero_op, nonzero_index_op, operand=None)

    all_indices = jnp.arange(max_l + 1)
    return jax.vmap(kraus_for_index)(all_indices)
def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Create the ``Dephasing_Reset`` gate, defined through a Kraus map.

    The gate acts on a two-level system and an oscillator (dims ``[2, N]``).

    Args:
        N: Oscillator Hilbert space dimension.
        p: Passed to the Kraus-map kernel.
        t_rst: Passed to the Kraus-map kernel.
        chi: Passed to the Kraus-map kernel.
        max_l: Highest Kraus index; the map holds ``max_l + 1`` operators.

    Returns:
        Two-mode gate whose Kraus map is generated from its params.
    """

    def kmap(params):
        kraus_stack = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"],
            params["p"],
            params["t_rst"],
            params["chi"],
            params["max_l"],
        )
        return Qarray.create(
            kraus_stack,
            dims=[[2, N], [2, N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kmap,
        num_modes=2,
    )