Coverage for jaxquantum / circuits / library / oscillator.py: 31%
144 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 20:38 +0000
1"""Oscillator gates."""
3from jaxquantum.core.operators import (displace, basis, destroy, create, num,
4 identity)
5from jaxquantum.circuits.gates import Gate
6from jax.scipy.special import factorial, gammaln
7import jax.numpy as jnp
8from jaxquantum import Qarray
9from jaxquantum.utils import hermgauss
10from functools import partial
11from jax import jit
12import jax
def diag_expm(diag_matrix):
    """Exponentiate a diagonal matrix elementwise.

    Since ``exp(diag(d)) == diag(exp(d))``, only the N diagonal entries are
    exponentiated. That costs O(N) exponentials instead of the O(N^3) of a
    dense matrix exponential. Any off-diagonal entries of ``diag_matrix`` are
    ignored.

    Args:
        diag_matrix: Square matrix, assumed to be diagonal.

    Returns:
        Diagonal matrix whose entries are ``exp`` of the input's diagonal.
    """
    entries = jnp.diagonal(diag_matrix)
    exponentiated = jnp.exp(entries)
    return jnp.diag(exponentiated)
def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. The drive is
            spread uniformly over ``ts[-1] - ts[0]``.
        c_ops: Optional collapse operators.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read alpha from params (not the construction-time value) so the
            # Hamiltonian stays consistent with gen_U if params are updated.
            # With amp = i*alpha/delta_t, exp(-i H delta_t) equals
            # exp(alpha a^dag - conj(alpha) a), i.e. displace(N, alpha).
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )
def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Implements ``|g><g| (x) D(beta/2) + |e><e| (x) D(-beta/2)``.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. The drive is
            spread uniformly over ``ts[-1] - ts[0]``.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read beta from params so the Hamiltonian matches gen_U.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # The parentheses are required. In Python ``+`` binds tighter than
            # ``^`` (used here, as in gen_U, to combine qubit and oscillator
            # operators). Written without them, ``gg ^ X + ee ^ Y`` would parse
            # as ``gg ^ (X + ee) ^ Y``.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )
def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    Implements ``|e><g| (x) D(beta/2) + |g><e| (x) D(-beta/2)``. This equals
    the CD gate followed by a qubit bit flip (X (x) I).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence, forwarded to ``Gate.create``. Note that no
            Hamiltonian generator is supplied (``gen_Ht=None``).

    Returns:
        Echoed conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    eg = e @ g.dag()
    ge = g @ e.dag()

    # NOTE: intentionally no gen_Ht. An earlier draft used
    # ``eg ^ (...) + ge ^ (...)``. That mis-parses (``+`` binds tighter than
    # ``^``), and even when parenthesized it is not Hermitian, because eg and
    # ge are not. A time-domain model needs a Hermitian drive, e.g. the CD
    # Hamiltonian followed by a qubit flip. TODO if pulse-level ECD is needed.
    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=lambda params: (eg ^ displace(N, params["beta"] / 2))
        + (ge ^ displace(N, -params["beta"] / 2)),
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )
def CR(N, theta):
    """Conditional rotation gate.

    Implements ``|g><g| (x) exp(-i theta/2 a^dag a)
    + |e><e| (x) exp(+i theta/2 a^dag a)``.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # a^dag a is built once here rather than on every gen_U call.
    n_op = create(N) @ destroy(N)

    def gen_U(params):
        # Read theta from params (as D/CD/ECD do) so the unitary tracks
        # updates to the gate's parameters instead of the closure value.
        theta_p = params["theta"]
        return (gg ^ (-1.0j * theta_p / 2 * n_op).expm()) + (
            ee ^ (1.0j * theta_p / 2 * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )
# --- 2. Optimized Kraus-map kernels (diagonal operators built elementwise, no dense expm) ---
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Stacked Kraus operators of the amplitude-damping map.

    K_l = sqrt(err_prob**l / l!) * (1 - err_prob)**(n/2) @ a**l, for l = 0..max_l.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Damping probability.
        max_l: Highest Kraus index kept (static).

    Returns:
        Array stacking the max_l + 1 Kraus operators along axis 0
        (shape (max_l + 1, N, N), assuming destroy(N).data is N x N).
    """
    n_op = num(N).data
    a_op = destroy(N).data

    # (1 - err_prob)**(n/2) as a diagonal matrix. This uses power directly
    # rather than exp(n * log(sqrt(1 - err_prob))), which keeps
    # err_prob == 1 well defined: 0**0 == 1, whereas 0 * log(0) is nan.
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    # jnp.linalg.matrix_power needs a concrete integer exponent, so it cannot
    # be vmapped over a traced l. max_l is static, so unroll the powers here.
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, l) for l in range(max_l + 1)])

    ls = jnp.arange(max_l + 1)
    prefactors = jnp.sqrt(jnp.power(err_prob, ls) / jnp.exp(gammaln(ls + 1)))
    return prefactors[:, None, None] * (middle_op @ a_powers)
def Amp_Damp(N, err_prob, max_l):
    """Amplitude-damping channel gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        max_l: Highest Kraus index; the map has max_l + 1 Kraus operators.

    Returns:
        Gate whose Kraus map is produced by ``_Amp_Damp_Kraus_Map_JIT``.
    """

    def kmap(params):
        kraus_ops = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops,
            dims=[[N], [N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Stacked Kraus operators of the amplification map.

    K_l = sqrt(err_prob**l / l!) * (a^dag)**l @ (1 - err_prob)**(n/2), for l = 0..max_l.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Error probability.
        max_l: Highest Kraus index kept (static).

    Returns:
        Array stacking the max_l + 1 Kraus operators along axis 0
        (shape (max_l + 1, N, N), assuming create(N).data is N x N).
    """
    n_op = num(N).data
    adag_op = create(N).data

    # (1 - err_prob)**(n/2) as a diagonal matrix. Computing it via power keeps
    # err_prob == 1 finite (0**0 == 1); the exp/log form gives 0 * -inf = nan.
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    # matrix_power needs a concrete integer exponent, so it cannot be vmapped
    # over a traced l. max_l is static, so unroll the powers at trace time.
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, l) for l in range(max_l + 1)]
    )

    ls = jnp.arange(max_l + 1)
    prefactors = jnp.sqrt(jnp.power(err_prob, ls) / jnp.exp(gammaln(ls + 1)))
    return prefactors[:, None, None] * (adag_powers @ middle_op)
def Amp_Gain(N, err_prob, max_l):
    """Amplification channel gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        max_l: Highest Kraus index; the map has max_l + 1 Kraus operators.

    Returns:
        Gate whose Kraus map is produced by ``_Amp_Gain_Kraus_Map_JIT``.
    """

    def kmap(params):
        kraus_ops = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops,
            dims=[[N], [N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Stacked Kraus operators of the thermal map.

    For l, k in 0..max_l:
        K_{l,k} = sqrt((err_prob*(1+n_bar))**k * (err_prob*n_bar)**l / (k! l!))
                  * (1 - err_prob)**(n/2) @ a**k @ (a^dag)**l
    The operators are flattened along axis 0 with index l * (max_l + 1) + k.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Error probability.
        n_bar: Thermal occupation parameter.
        max_l: Highest power kept for both a and a^dag (static).

    Returns:
        Array stacking the (max_l + 1)**2 Kraus operators along axis 0.
    """
    a_op = destroy(N).data
    adag_op = create(N).data
    n_op = num(N).data

    # max_l is static, so these loops unroll at trace time. This gives
    # matrix_power the concrete integer exponents it requires.
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])
    adag_powers = jnp.stack([jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # (1 - err_prob)**(n/2) as a diagonal matrix. Computing it via power keeps
    # err_prob == 1 finite (0**0 == 1); the exp/log form gives 0 * -inf = nan.
    n_diag = jnp.real(jnp.diagonal(n_op))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_single_op(idx):
        # Flat index -> (l, k), with k varying fastest.
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        fact_l = jnp.exp(gammaln(l + 1))
        fact_k = jnp.exp(gammaln(k + 1))

        term_k = jnp.power(err_prob * (1.0 + n_bar), k)
        term_l = jnp.power(err_prob * n_bar, l)

        prefactor = jnp.sqrt((term_k * term_l) / (fact_k * fact_l))
        # Dynamic indexing into the precomputed stacks is vmap-safe.
        op_k = a_powers[k]
        op_l = adag_powers[l]

        return prefactor * (middle_op @ op_k @ op_l)

    indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(compute_single_op)(indices)
def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal channel gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability passed to the Kraus-map kernel.
        n_bar: Thermal occupation parameter passed to the kernel.
        max_l: Highest power per ladder operator; the map has
            (max_l + 1)**2 Kraus operators.

    Returns:
        Gate whose Kraus map is produced by ``_Thermal_Ch_Kraus_Map_JIT``.
    """

    def kmap(params):
        kraus_ops = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        n_kraus = (params["max_l"] + 1) ** 2
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Weighted phase-rotation Kraus operators ``sqrt(w_j) * exp(i phi_j n)``.

    Args:
        N: Hilbert space dimension (static).
        ws: Quadrature weights, one per Kraus operator.
        phis: Rotation angles, paired elementwise with ``ws``.
        max_l: Accepted as a static argument but not used in the body; the
            number of operators is set by the length of ``ws``/``phis``.

    Returns:
        Array stacking one operator per (w, phi) pair along axis 0.
    """
    n_op = num(N).data

    def weighted_rotation(weight, angle):
        # The number operator's exponential is diagonal, so diag_expm suffices.
        rotation = diag_expm(1.0j * angle * n_op)
        return jnp.sqrt(weight) * rotation

    return jax.vmap(weighted_rotation, in_axes=(0, 0))(ws, phis)
def Dephasing_Ch(N, err_prob, max_l):
    """Dephasing channel built from a Gauss-Hermite quadrature over phases.

    Kraus operators are ``sqrt(w_j) exp(i phi_j n)``, where
    ``phi_j = sqrt(2 * err_prob) * x_j``, ``w_j = w_raw_j / sqrt(pi)``, and
    ``(x_j, w_raw_j) = hermgauss(max_l)``. This assumes ``hermgauss`` is the
    physicists' rule with weight exp(-x^2), as in NumPy (TODO confirm). Under
    that assumption the map averages the phase rotation over
    phi ~ Normal(0, err_prob).

    Args:
        N: Hilbert space dimension.
        err_prob: Phase-variance parameter.
        max_l: Number of quadrature points (= number of Kraus operators).

    Returns:
        Dephasing channel gate.
    """
    # The nodes and weights depend only on max_l, so compute them once.
    xs, ws_raw = hermgauss(max_l)
    ws = 1 / jnp.sqrt(jnp.pi) * ws_raw

    def kmap(params):
        # Scale the nodes with params["err_prob"], not the construction-time
        # value, so the channel tracks parameter updates like the other
        # channels in this module.
        phis = jnp.sqrt(2 * params["err_prob"]) * xs
        return Qarray.create(
            _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"]),
            dims=[[N], [N]],
            bdims=(params["max_l"],),
        )

    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l, "N": N},
        gen_KM=kmap,
        num_modes=1,
    )
def selfKerr(N, K):
    """Self-Kerr evolution gate, ``exp(-i K/2 a^dag a^dag a a)``.

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient, stored as ``params["Kerr"]``.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    # The Kerr operator does not depend on K, so build it once.
    kerr_op = a.dag() @ a.dag() @ a @ a
    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (not the closure) so updates to the gate's
        # parameters take effect, consistent with D/CD/ECD.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )
@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Stacked Kraus operators for a qubit reset with oscillator dephasing.

    The operators, for l = 0..max_l, are:
      * l == 0: |g><g| (x) I
      * l == 1: sqrt(p) |e><e| (x) exp(-i chi t_rst n)
      * l >= 2: c_l |g><e| (x) exp(-i chi t_rst (l-2)/(max_l-1) n)
    Here c_l**2 is proportional to -log(p) p**((l-2)/(max_l-1)) / (max_l-1).
    These weights are normalized to sum to 1 - p, so together with the l == 1
    term the |e> population is fully accounted for.

    Args:
        N: Oscillator Hilbert space dimension (static).
        p: Weight of the l == 1 (no-flip) term.
        t_rst: Reset duration.
        chi: Dispersive coupling scaling the accumulated phase.
        max_l: Highest Kraus index (static). Presumably max_l >= 2 is
            intended, since max_l - 1 appears as a divisor.

    Returns:
        Array stacking the max_l + 1 Kraus operators along axis 0.
    """
    ket_g = basis(2, 0).data
    ket_e = basis(2, 1).data
    proj_g = ket_g @ jnp.conj(ket_g.T)
    proj_e = ket_e @ jnp.conj(ket_e.T)
    lower = ket_g @ jnp.conj(ket_e.T)

    n_op = num(N).data
    eye_n = jnp.eye(N)

    def raw_weight(l):
        # Unnormalized squared amplitude of the l >= 2 reset terms.
        return -(jnp.log(p) * jnp.power(p, (l - 2) / (max_l - 1))) / (max_l - 1)

    normalization_factor = (1 - p) / jnp.sum(raw_weight(jnp.arange(2, max_l + 1)))

    def kraus_for(l):
        def no_flip_ground(_):
            return jnp.kron(proj_g, eye_n)

        def stay_excited(_):
            phase = diag_expm(-1.0j * chi * t_rst * n_op)
            return jnp.sqrt(p) * jnp.kron(proj_e, phase)

        def reset_at_step(_):
            amplitude = jnp.sqrt(raw_weight(l)) * jnp.sqrt(normalization_factor)
            exponent = -1.0j * chi * t_rst * (l - 2) / (max_l - 1)
            phase = diag_expm(exponent * n_op)
            return amplitude * jnp.kron(lower, phase)

        # Branch 0 -> l == 0, 1 -> l == 1, 2 -> every l >= 2.
        branch = jnp.minimum(l, 2)
        return jax.lax.switch(
            branch, [no_flip_ground, stay_excited, reset_at_step], None
        )

    return jax.vmap(kraus_for)(jnp.arange(max_l + 1))
def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Qubit-reset channel with oscillator dephasing, on a [2, N] system.

    Args:
        N: Oscillator Hilbert space dimension.
        p: Weight parameter forwarded to the Kraus-map kernel.
        t_rst: Reset duration forwarded to the kernel.
        chi: Coupling forwarded to the kernel.
        max_l: Highest Kraus index; the map has max_l + 1 Kraus operators.

    Returns:
        Gate whose Kraus map is produced by ``_Dephasing_Reset_Kraus_Map_JIT``.
    """

    def kmap(params):
        kraus_ops = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"], params["p"], params["t_rst"], params["chi"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops,
            dims=[[2, N], [2, N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kmap,
        num_modes=2,
    )