Coverage for jaxquantum / circuits / library / oscillator.py: 30%

144 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import (displace, basis, destroy, create, num) 

4from jaxquantum.circuits.gates import Gate 

5from jax.scipy.special import gammaln 

6import jax.numpy as jnp 

7from jaxquantum import Qarray 

8from jaxquantum.utils import hermgauss 

9from functools import partial 

10import jax 

11 

def diag_expm(diag_matrix):
    """Matrix exponential of a diagonal matrix in O(N) rather than O(N^3).

    Only the main diagonal of ``diag_matrix`` is read. Off-diagonal entries are
    ignored, so the result equals ``expm(diag_matrix)`` exactly when the input
    is diagonal.

    Args:
        diag_matrix: Square matrix that is assumed to be diagonal.

    Returns:
        Diagonal matrix whose entries are ``exp`` of the input's diagonal.
    """
    exponentiated = jnp.exp(jnp.diagonal(diag_matrix))
    return jnp.diag(exponentiated)

16 

17 

18 

def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. When given, a
            constant drive over ``ts[-1] - ts[0]`` is attached as ``gen_Ht``.
        c_ops: Optional collapse operators.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read alpha from params (as gen_U does) so the drive always matches
            # the recorded gate parameters rather than a stale closure value.
            # H = i (alpha a^dag - alpha^* a) / T, so exp(-i H T) equals
            # exp(alpha a^dag - alpha^* a), assuming the solver evolves with
            # exp(-i H t) and displace() uses that standard convention.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )

48 

49 

def CD(N, beta, ts=None):
    """Conditional displacement gate.

    U = |g><g| (x) D(beta/2) + |e><e| (x) D(-beta/2).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. When given, a
            constant qubit-conditioned drive over ``ts[-1] - ts[0]`` is
            attached as ``gen_Ht``.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def gen_Ht(params):
            # Read beta from params (as gen_U does) instead of a closure value.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Each tensor product must be parenthesised: in Python `+` binds
            # tighter than `^`, so `gg ^ X + ee ^ Y` would parse as
            # `gg ^ (X + ee) ^ Y` and add an N x N operator to a 2 x 2 projector.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

88 

89 

def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    U = |e><g| (x) D(beta/2) + |g><e| (x) D(-beta/2).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence. It is forwarded to ``Gate.create``, but no
            Hamiltonian generator is provided for this gate (``gen_Ht=None``).

    Returns:
        Echoed conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    eg = e @ g.dag()
    ge = g @ e.dag()

    # NOTE(review): no gen_Ht is defined for ECD. An earlier commented-out
    # draft built eg (x) X + ge (x) (-X) from the CD drive X. Because X is
    # Hermitian, that operator is anti-Hermitian, and the draft also
    # mis-parenthesised `^` against `+`. It was removed rather than kept as
    # dead code.
    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=lambda params: (eg ^ displace(N, params["beta"] / 2))
        + (ge ^ displace(N, -params["beta"] / 2)),
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )

128 

def CR(N, theta):
    """Conditional rotation gate.

    U = |g><g| (x) exp(-i theta/2 n) + |e><e| (x) exp(+i theta/2 n),
    with n the oscillator number operator.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    n_op = num(N)

    def gen_U(params):
        # Read theta from params (like D/CD/ECD) so the unitary always agrees
        # with the recorded gate parameters rather than a closure value.
        half_angle = params["theta"] / 2
        return (gg ^ (-1.0j * half_angle * n_op).expm()) + (
            ee ^ (1.0j * half_angle * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )

154 

155 

156# --- 2. Optimized Kernels (Using diag_expm) --- 

157 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Build truncated amplitude-damping Kraus operators.

    K_l = sqrt(p**l / l!) * (1 - p)**(n/2) @ a**l   for l = 0, ..., max_l,
    with p = ``err_prob`` and n the number operator.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Damping probability p.
        max_l: Largest power of ``a`` kept (static); max_l + 1 operators.

    Returns:
        The Kraus operators stacked along a new leading axis of length max_l + 1.
    """
    a_op = destroy(N).data

    # jnp.linalg.matrix_power needs a concrete exponent, so the powers are
    # built with a static Python loop and indexed inside vmap.
    a_powers = jnp.stack(
        [jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])

    # (1 - p)^(n/2) on the diagonal, taken as a direct power. The previous
    # exp(n * log(sqrt(1 - p))) form evaluated 0 * -inf = nan in the vacuum
    # entry when p == 1; here 0**0 == 1, giving the correct full-damping map.
    n_diag = jnp.real(jnp.diagonal(num(N).data))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (middle_op @ a_powers[l])

    return jax.vmap(compute_op)(jnp.arange(max_l + 1))

177 

def Amp_Damp(N, err_prob, max_l):
    """Amplitude-damping channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Damping probability.
        max_l: Largest number of lost excitations represented; the channel
            has max_l + 1 Kraus operators.

    Returns:
        Gate carrying the Kraus map as ``gen_KM``.
    """

    def kmap(params):
        kraus_ops = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops, dims=[[N], [N]], bdims=(params["max_l"] + 1,)
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

191 

192 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Build truncated amplification (gain) Kraus operators.

    K_l = sqrt(p**l / l!) * (a^dag)**l @ (1 - p)**(n/2)   for l = 0, ..., max_l,
    with p = ``err_prob`` and n the number operator.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Gain error probability p.
        max_l: Largest power of ``a^dag`` kept (static); max_l + 1 operators.

    Returns:
        The Kraus operators stacked along a new leading axis of length max_l + 1.
    """
    adag_op = create(N).data

    # Precompute (a^dag)^l with a static Python loop. jnp.linalg.matrix_power
    # requires a concrete integer exponent and raises when given the traced
    # `l` that jax.vmap passes to compute_op.
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # (1 - p)^(n/2) as a direct power: avoids 0 * log(0) = nan when p == 1.
    n_diag = jnp.real(jnp.diagonal(num(N).data))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (adag_powers[l] @ middle_op)

    return jax.vmap(compute_op)(jnp.arange(max_l + 1))

209 

def Amp_Gain(N, err_prob, max_l):
    """Amplification (gain) channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Gain error probability.
        max_l: Largest number of added excitations represented; the channel
            has max_l + 1 Kraus operators.

    Returns:
        Gate carrying the Kraus map as ``gen_KM``.
    """

    def kmap(params):
        kraus_ops = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops, dims=[[N], [N]], bdims=(params["max_l"] + 1,)
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

223 

224 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Build truncated thermal-channel Kraus operators.

    For k, l in 0..max_l:
        K_{l,k} = sqrt((p (1 + n_bar))**k (p n_bar)**l / (k! l!))
                  * (1 - p)**(n/2) @ a**k @ (a^dag)**l,
    with p = ``err_prob``, flattened at index l * (max_l + 1) + k.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Error probability p.
        n_bar: Thermal occupation.
        max_l: Largest power of ``a`` and ``a^dag`` kept (static).

    Returns:
        The Kraus operators stacked along a new leading axis of length
        (max_l + 1) ** 2.
    """
    a_op = destroy(N).data
    adag_op = create(N).data

    # Static-exponent powers (matrix_power cannot take a vmap tracer).
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])
    adag_powers = jnp.stack([jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # (1 - p)^(n/2) as a direct power: exp(n * log(sqrt(1 - p))) produced
    # 0 * -inf = nan in the vacuum entry when p == 1.
    n_diag = jnp.real(jnp.diagonal(num(N).data))
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), n_diag))

    def compute_single_op(idx):
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        fact_l = jnp.exp(gammaln(l + 1))
        fact_k = jnp.exp(gammaln(k + 1))

        term_k = jnp.power(err_prob * (1.0 + n_bar), k)
        term_l = jnp.power(err_prob * n_bar, l)

        prefactor = jnp.sqrt((term_k * term_l) / (fact_k * fact_l))
        return prefactor * (middle_op @ a_powers[k] @ adag_powers[l])

    indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(compute_single_op)(indices)

256 

def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal channel on a single oscillator mode.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Thermal occupation.
        max_l: Largest power of ``a`` and ``a^dag`` kept; the channel has
            (max_l + 1) ** 2 Kraus operators.

    Returns:
        Gate carrying the Kraus map as ``gen_KM``.
    """

    def kmap(params):
        kraus_ops = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        n_ops = (params["max_l"] + 1) ** 2
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_ops,))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

270 

271 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Build phase-rotation Kraus operators sqrt(w_i) * exp(i phi_i n).

    Args:
        N: Hilbert space dimension (static).
        ws: Quadrature weights, one per Kraus operator.
        phis: Rotation angles, paired element-wise with ``ws``.
        max_l: Number of quadrature nodes (static; not read in the body).

    Returns:
        The Kraus operators stacked along a new leading axis (one per weight).
    """
    number_op = num(N).data

    def weighted_rotation(weight, angle):
        # The number operator is diagonal, so diag_expm gives its exact
        # exponential in O(N).
        return jnp.sqrt(weight) * diag_expm(1.0j * angle * number_op)

    return jax.vmap(weighted_rotation)(ws, phis)

280 

def Dephasing_Ch(N, err_prob, max_l):
    """Gaussian dephasing channel on a single oscillator mode.

    The random rotation angle is integrated with ``max_l``-point Gauss-Hermite
    quadrature. Each node phi_i with weight w_i yields one Kraus operator
    sqrt(w_i) exp(i phi_i n).

    Args:
        N: Hilbert space dimension.
        err_prob: Variance of the rotation angle (assuming ``hermgauss``
            follows the physicists' e^{-x^2} convention; TODO confirm).
        max_l: Number of quadrature nodes / Kraus operators.

    Returns:
        Gate carrying the Kraus map as ``gen_KM``.
    """
    nodes, raw_weights = hermgauss(max_l)
    # Rescale the e^{-x^2} rule to a normal density: phi = sqrt(2 var) x and
    # w = w_raw / sqrt(pi).
    phis = jnp.sqrt(2*err_prob) * nodes
    ws = 1/jnp.sqrt(jnp.pi) * raw_weights

    def kmap(params):
        kraus_ops = _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"])
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(params["max_l"],))

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Dephasing_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

298 

299 

def selfKerr(N, K):
    """Self-Kerr gate U = exp(-i K/2 a^dag a^dag a a).

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient, stored in the gate as ``params["Kerr"]``.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    kerr_op = a.dag() @ a.dag() @ a @ a
    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (like D/CD/ECD) rather than the closure so the
        # unitary always agrees with the recorded gate parameters.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )

309 

310 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Build Kraus operators for a qubit reset that dephases a coupled oscillator.

    Returns max_l + 1 operators on the qubit (x) oscillator space, stacked along
    a new leading axis:
      * l == 0: |g><g| (x) I. A qubit already in |g> is left unchanged.
      * l == 1: sqrt(p) |e><e| (x) exp(-i chi t_rst n). The qubit stays in |e>
        and the oscillator picks up the full phase chi * t_rst per excitation.
      * l >= 2: c_l |g><e| (x) exp(-i chi t_rst s_l n), with
        s_l = (l - 2) / (max_l - 1). The qubit is reset to |g> after the
        oscillator has accrued a fraction s_l of that phase.
    The weights satisfy c_l**2 proportional to -ln(p) p**s_l / (max_l - 1),
    rescaled so that sum_{l>=2} c_l**2 = 1 - p. The operators therefore satisfy
    sum_l K_l^dag K_l = |g><g| (x) I + (p + 1 - p) |e><e| (x) I = I.

    Args:
        N: Oscillator Hilbert space dimension (static).
        p: Weight of the l == 1 (qubit stays excited) operator.
        t_rst: Reset duration.
        chi: Dispersive coupling multiplying t_rst * n in the phase.
        max_l: Index of the last Kraus operator (static).

    NOTE(review): p == 0 (log(0)) and p == 1 (0 / 0 in the normalization)
    produce NaN, and max_l < 2 leaves no l >= 2 operators, so the map is not
    trace preserving unless p == 1. Callers presumably use 0 < p < 1 and
    max_l >= 2; confirm.
    """
    g = basis(2, 0).data
    e = basis(2, 1).data
    gg = g @ jnp.conj(g.T)
    ee = e @ jnp.conj(e.T)
    ge = g @ jnp.conj(e.T)

    n_op = num(N).data
    I_N = jnp.eye(N)

    # Unnormalized weights for l = 2..max_l. This is a Riemann sum (step
    # 1/(max_l - 1)) of -ln(p) p**s, whose integral over s in [0, 1] is 1 - p.
    # The sum is rescaled so the discrete weights total exactly 1 - p.
    ls_all = jnp.arange(2, max_l+1)
    norm_terms = -(jnp.log(p) * jnp.power(p, (ls_all - 2) / (max_l - 1))) / (max_l - 1)
    normalization_factor = (1 - p) / jnp.sum(norm_terms)

    def compute_op(l):
        def branch_0(_):
            return jnp.kron(gg, I_N)

        def branch_1(_):
            # The number operator is diagonal, so diag_expm is the exact expm in O(N).
            op_osc = diag_expm(-1.0j * chi * t_rst * n_op)
            return jnp.sqrt(p) * jnp.kron(ee, op_osc)

        def branch_rest(_):
            # Same weight formula as norm_terms, evaluated at this l.
            term_val = -(jnp.log(p) * jnp.power(p, (l - 2) / (max_l - 1))) / (max_l - 1)
            prefactor = jnp.sqrt(term_val) * jnp.sqrt(normalization_factor)

            exponent = -1.0j * chi * t_rst * (l - 2) / (max_l - 1)
            # The number operator is diagonal, so diag_expm is the exact expm in O(N).
            op_osc = diag_expm(exponent * n_op)
            return prefactor * jnp.kron(ge, op_osc)

        # Under jax.vmap, lax.cond with a batched predicate becomes a select,
        # so every branch is evaluated for every l. branch_rest therefore also
        # runs for l = 0, 1, and its values there are discarded.
        return jax.lax.cond(
            l == 0,
            branch_0,
            lambda _: jax.lax.cond(l == 1, branch_1, branch_rest, operand=None),
            operand=None
        )

    ls = jnp.arange(max_l+1)
    return jax.vmap(compute_op)(ls)

353 

def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Qubit reset channel that dephases a dispersively coupled oscillator.

    Args:
        N: Oscillator Hilbert space dimension.
        p: Weight of the Kraus operator in which the qubit stays excited.
        t_rst: Reset duration.
        chi: Dispersive coupling.
        max_l: Index of the last Kraus operator; there are max_l + 1 in total.

    Returns:
        Two-mode (qubit, oscillator) gate carrying the Kraus map as ``gen_KM``.
    """

    def kmap(params):
        kraus_ops = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"], params["p"], params["t_rst"], params["chi"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops, dims=[[2, N], [2, N]], bdims=(params["max_l"]+1,)
        )

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kmap,
        num_modes=2,
    )