Coverage for jaxquantum / circuits / library / oscillator.py: 30%

145 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-28 21:05 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import (displace, basis, destroy, create, num, 

4 identity) 

5from jaxquantum.circuits.gates import Gate 

6from jax.scipy.special import factorial, gammaln 

7import jax.numpy as jnp 

8from jaxquantum import Qarray 

9from jaxquantum.utils import hermgauss 

10from functools import partial 

11from jax import jit 

12import jax 

13 

def diag_expm(diag_matrix):
    """Matrix exponential of a diagonal matrix, computed elementwise.

    For ``M = diag(d)`` we have ``expm(M) = diag(exp(d))``, so only the N
    diagonal entries need exponentiating (O(N) instead of a dense O(N^3)
    ``expm``). Off-diagonal entries of ``diag_matrix`` are ignored.

    Args:
        diag_matrix: 2-D array whose diagonal holds the exponents.

    Returns:
        Square diagonal array with ``exp`` applied to each diagonal entry.
    """
    exponents = jnp.diagonal(diag_matrix)
    exponentiated = jnp.exp(exponents)
    return jnp.diag(exponentiated)

18 

19 

20 

def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    The unitary is ``displace(N, alpha)``. When ``ts`` is given, the gate also
    carries a constant Hamiltonian ``H = conj(amp) a + amp a^dag`` with
    ``amp = 1j * alpha / (ts[-1] - ts[0])``, chosen so that
    ``exp(-1j * H * (ts[-1] - ts[0])) = exp(alpha a^dag - conj(alpha) a)``.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. Only its first
            and last entries are used (to set the drive amplitude).
        c_ops: Optional collapse operators. When omitted, an empty Qarray
            list is supplied.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # Read alpha from params, exactly as gen_U does, so the
            # Hamiltonian always agrees with the gate's stored parameters
            # instead of the value captured when D() was called.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

        gen_Ht = _gen_Ht

    def gen_c_ops(params):
        return Qarray.from_list([]) if c_ops is None else c_ops

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=gen_c_ops,
        num_modes=1,
    )

50 

51 

def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Unitary on qubit (dim 2) x oscillator (dim N):
    ``|g><g| (x) D(beta/2) + |e><e| (x) D(-beta/2)``.

    When ``ts`` is given, the gate also carries a constant Hamiltonian that is
    block diagonal in the qubit basis. Each block drives the oscillator with
    amplitude ``+-1j * beta / (2 * (ts[-1] - ts[0]))``.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. Only its first
            and last entries are used.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # Use params["beta"] (as gen_U does) rather than the closure.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Each tensor product must be parenthesized: ``^`` binds more
            # loosely than ``+`` in Python, so ``gg ^ X + ee ^ Y`` would parse
            # as ``gg ^ (X + ee) ^ Y`` and add a 2x2 operator to an NxN one.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

        gen_Ht = _gen_Ht

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

90 

91 

def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    Unitary on qubit (dim 2) x oscillator (dim N):
    ``|e><g| (x) D(beta/2) + |g><e| (x) D(-beta/2)``.

    NOTE: no Hamiltonian generator is provided (``gen_Ht`` is always None).
    ``ts`` is only forwarded to ``Gate.create``.

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence, passed through to the gate.

    Returns:
        Echoed conditional displacement gate.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)
    raise_qubit = ket_e @ ket_g.dag()  # |e><g|
    lower_qubit = ket_g @ ket_e.dag()  # |g><e|

    def gen_U(params):
        half_beta = params["beta"] / 2
        forward = raise_qubit ^ displace(N, half_beta)
        backward = lower_qubit ^ displace(N, -half_beta)
        return forward + backward

    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=gen_U,
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )

130 

def CR(N, theta):
    """Conditional rotation gate.

    Unitary on qubit (dim 2) x oscillator (dim N):
    ``|g><g| (x) exp(-1j*theta/2 * n) + |e><e| (x) exp(+1j*theta/2 * n)``
    with ``n = a^dag a``.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()
    n_op = create(N) @ destroy(N)

    def gen_U(params):
        # Read theta from params (as D/CD read their amplitudes) rather than
        # the closure, so the unitary always matches the stored parameters.
        half_theta = params["theta"] / 2
        return (gg ^ (-1.0j * half_theta * n_op).expm()) + (
            ee ^ (1.0j * half_theta * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )

156 

157 

158# --- 2. Optimized Kernels (Using diag_expm) --- 

159 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Kraus operators of the Fock-truncated amplitude-damping channel.

    K_l = sqrt(err_prob**l / l!) * (1 - err_prob)**(n/2) @ a**l,
    for l = 0 .. max_l.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Damping probability.
        max_l: Highest power of ``a`` kept (static).

    Returns:
        The max_l + 1 operators K_0 .. K_max_l stacked along a leading axis.
    """
    a_op = destroy(N).data

    # Powers are built eagerly: jnp.linalg.matrix_power needs a concrete
    # exponent, so they are indexed (not computed) inside vmap below.
    a_powers = jnp.stack(
        [jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)]
    )

    # (1 - err_prob)**(n/2) is diagonal with entries (1 - err_prob)**(k/2),
    # k = 0 .. N-1. A direct power keeps err_prob == 1 finite. The previous
    # exp(n * log(sqrt(1 - err_prob))) form gave 0 * -inf = NaN on the vacuum.
    middle_op = jnp.diag(jnp.power(1.0 - err_prob, jnp.arange(N) / 2))

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return prefactor * (middle_op @ a_powers[l])

    return jax.vmap(compute_op)(jnp.arange(max_l + 1))

179 

def Amp_Damp(N, err_prob, max_l):
    """Single-mode amplitude-damping channel as a Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Damping probability.
        max_l: Highest number of lost excitations kept; the map has
            max_l + 1 Kraus operators.

    Returns:
        Gate named "Amp_Damp" whose ``gen_KM`` builds the Kraus operators.
    """

    def kmap(params):
        kraus_ops = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops, dims=[[N], [N]], bdims=(params["max_l"] + 1,)
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

193 

194 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Kraus operators of the Fock-truncated amplification (gain) channel.

    K_l = sqrt(1 - err_prob) * sqrt(err_prob**l / l!)
          * (a^dag)**l @ (1 - err_prob)**(n/2),   l = 0 .. max_l.

    The overall sqrt(1 - err_prob) factor makes sum_l K_l^dag K_l equal the
    identity before truncation. On |n>, the sum without it is
    (1 - err_prob)**n * sum_l C(n + l, l) err_prob**l = 1 / (1 - err_prob).
    Fock truncation at N still loses weight near the top level.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Gain parameter in [0, 1).
        max_l: Highest power of ``a^dag`` kept (static).

    Returns:
        The max_l + 1 operators K_0 .. K_max_l stacked along a leading axis.
    """
    adag_op = create(N).data

    # jnp.linalg.matrix_power requires a concrete integer exponent, so calling
    # it with the traced ``l`` inside vmap fails. Precompute all powers and
    # index them instead, as the other Kraus kernels do.
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)]
    )

    # Diagonal (1 - err_prob)**(k/2), k = 0 .. N-1; direct power, no log.
    middle_op = jnp.diag(jnp.power(1.0 - err_prob, jnp.arange(N) / 2))
    norm = jnp.sqrt(1.0 - err_prob)

    def compute_op(l):
        prefactor = jnp.sqrt(jnp.power(err_prob, l) / jnp.exp(gammaln(l + 1)))
        return norm * prefactor * (adag_powers[l] @ middle_op)

    return jax.vmap(compute_op)(jnp.arange(max_l + 1))

211 

def Amp_Gain(N, err_prob, max_l):
    """Single-mode amplification (gain) channel as a Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Gain parameter forwarded to the Kraus kernel.
        max_l: Highest power of ``a^dag`` kept; the map has max_l + 1 Kraus
            operators.

    Returns:
        Gate named "Amp_Gain" whose ``gen_KM`` builds the Kraus operators.
    """

    def kmap(params):
        kraus_ops = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops, dims=[[N], [N]], bdims=(params["max_l"] + 1,)
        )

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

225 

226 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Kraus operators for the thermal channel, indexed by a pair (l, k).

    K_{l,k} = sqrt((err_prob*(1 + n_bar))**k * (err_prob*n_bar)**l / (k! l!))
              * (1 - err_prob)**(n/2) @ a**k @ (a^dag)**l

    for l, k = 0 .. max_l, flattened as ``idx = l * (max_l + 1) + k``.
    With n_bar == 0 only the l == 0 terms are non-zero, and these match the
    amplitude-damping operators.

    Args:
        N: Hilbert space dimension (static).
        err_prob: Loss probability.
        n_bar: Thermal occupation.
        max_l: Highest power of ``a`` / ``a^dag`` kept (static).

    Returns:
        The (max_l + 1)**2 operators stacked along a leading axis.
    """
    a_op = destroy(N).data
    adag_op = create(N).data

    # Precompute powers eagerly; matrix_power needs a concrete exponent.
    exponents = range(max_l + 1)
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, i) for i in exponents])
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, i) for i in exponents]
    )

    # Diagonal (1 - err_prob)**(k/2), k = 0 .. N-1. A direct power keeps
    # err_prob == 1 finite, whereas exp(n * log(...)) gives NaN on the vacuum.
    middle_op = jnp.diag(jnp.power(1.0 - err_prob, jnp.arange(N) / 2))

    def compute_single_op(idx):
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        fact_l = jnp.exp(gammaln(l + 1))
        fact_k = jnp.exp(gammaln(k + 1))

        term_k = jnp.power(err_prob * (1.0 + n_bar), k)
        term_l = jnp.power(err_prob * n_bar, l)

        prefactor = jnp.sqrt((term_k * term_l) / (fact_k * fact_l))
        return prefactor * (middle_op @ a_powers[k] @ adag_powers[l])

    indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(compute_single_op)(indices)

258 

def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Single-mode thermal channel as a Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Loss probability.
        n_bar: Thermal occupation.
        max_l: Highest power kept; the map has (max_l + 1)**2 Kraus operators.

    Returns:
        Gate named "Thermal_Ch" whose ``gen_KM`` builds the Kraus operators.
    """

    def kmap(params):
        kraus_ops = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        n_ops = (params["max_l"] + 1) ** 2
        return Qarray.create(kraus_ops, dims=[[N], [N]], bdims=(n_ops,))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kmap,
        num_modes=1,
    )

272 

273 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Kraus operators sqrt(w_j) * exp(1j * phi_j * n), one per (w_j, phi_j).

    Args:
        N: Hilbert space dimension (static).
        ws: Quadrature weights.
        phis: Phases paired elementwise with ``ws``.
        max_l: Not used in the body; kept as a static argument.

    Returns:
        One operator per (w, phi) pair, stacked along a leading axis.
    """
    number_op = num(N).data

    def kraus_for(weight, phase):
        rotation = diag_expm(1.0j * phase * number_op)
        return jnp.sqrt(weight) * rotation

    return jax.vmap(kraus_for)(ws, phis)

282 

def Dephasing_Ch(N, err_prob, max_l):
    """Random-phase (dephasing) channel on one mode as a Kraus-map gate.

    The phase is averaged with max_l-point Gauss-Hermite quadrature: nodes
    x_j give phases sqrt(2 * err_prob) * x_j, and weights are divided by
    sqrt(pi). NOTE(review): this assumes ``hermgauss`` returns the
    physicists' nodes/weights (weight function exp(-x**2)), as numpy's does.
    Under that assumption, the phase is normally distributed with variance
    err_prob.

    Args:
        N: Hilbert space dimension.
        err_prob: Phase variance parameter.
        max_l: Number of quadrature points (= number of Kraus operators).

    Returns:
        Gate named "Dephasing_Ch" whose ``gen_KM`` builds the Kraus operators.
    """
    xs, ws_raw = hermgauss(max_l)
    ws = 1 / jnp.sqrt(jnp.pi) * ws_raw

    def kmap(params):
        # Scale the nodes with params["err_prob"] (as the other channels read
        # their parameters) instead of the value captured at construction.
        phis = jnp.sqrt(2 * params["err_prob"]) * xs
        return Qarray.create(
            _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"]),
            dims=[[N], [N]],
            bdims=(params["max_l"],),
        )

    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l, "N": N},
        gen_KM=kmap,
        num_modes=1,
    )

300 

301 

def selfKerr(N, K):
    """Self-Kerr evolution ``exp(-1j * K/2 * a^dag a^dag a a)`` on one mode.

    Args:
        N: Hilbert space dimension.
        K: Kerr strength, stored as params["Kerr"].

    Returns:
        Single-mode gate named "selfKerr".
    """
    a = destroy(N)
    kerr_op = a.dag() @ a.dag() @ a @ a

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read the strength from params (not the closure) so the unitary
        # tracks the gate's stored parameters, like the other gates here.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )

311 

312 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Kraus operators for a qubit reset that dephases a coupled oscillator.

    Operators act on qubit (dim 2) x oscillator (dim N):

    - K_0 = |g><g| (x) I
    - K_1 = sqrt(p) |e><e| (x) exp(-1j * chi * t_rst * n)
    - K_{2+j} = sqrt(q_j) |g><e| (x) exp(-1j * chi * t_rst * s_j * n),
      with s_j = j / (max_l - 1) for j = 0 .. max_l - 2,
      and q_j = (1 - p) * p**s_j / sum_i p**s_i.

    The q_j are written as a normalized ratio of powers. An equivalent
    -log(p) * p**s_j / (max_l - 1) form has a log factor that cancels on
    normalization but produces inf/NaN at p == 0 and p == 1. The power form
    stays finite there: at p == 1 all q_j are 0, and at p == 0 all weight
    goes to s_0 = 0.

    Args:
        N: Oscillator Hilbert space dimension (static).
        p: Weight of the K_1 branch; the reset branches share 1 - p.
        t_rst: Reset duration.
        chi: Dispersive coupling rate.
        max_l: Total number of operators is max_l + 1 (static).

    Returns:
        The max_l + 1 operators K_0 .. K_max_l stacked along a leading axis.
    """
    g = basis(2, 0).data
    e = basis(2, 1).data
    gg = g @ jnp.conj(g.T)
    ee = e @ jnp.conj(e.T)
    ge = g @ jnp.conj(e.T)

    n_op = num(N).data
    I_N = jnp.eye(N)

    k_idle = jnp.kron(gg, I_N)
    k_stay = jnp.sqrt(p) * jnp.kron(ee, diag_expm(-1.0j * chi * t_rst * n_op))
    fixed_ops = jnp.stack([k_idle, k_stay])

    if max_l < 2:
        # No room for reset operators. Return the first max_l + 1 fixed
        # operators, which are the non-NaN entries the cond-based version
        # produced. NOTE: such a map does not include the reset branch.
        return fixed_ops[: max_l + 1]

    # Fractions of t_rst at which the reset happens, one per reset operator.
    s = jnp.arange(max_l - 1) / (max_l - 1)
    weights = jnp.power(p, s)
    reset_probs = (1.0 - p) * weights / jnp.sum(weights)

    def reset_op(prob, frac):
        op_osc = diag_expm(-1.0j * chi * t_rst * frac * n_op)
        return jnp.sqrt(prob) * jnp.kron(ge, op_osc)

    reset_ops = jax.vmap(reset_op)(reset_probs, s)
    return jnp.concatenate([fixed_ops, reset_ops], axis=0)

355 

def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Qubit-reset channel with oscillator dephasing, as a two-mode gate.

    Args:
        N: Oscillator Hilbert space dimension.
        p: Weight of the no-reset branch.
        t_rst: Reset duration.
        chi: Dispersive coupling rate.
        max_l: The map has max_l + 1 Kraus operators.

    Returns:
        Gate named "Dephasing_Reset" on dims [2, N] whose ``gen_KM`` builds
        the Kraus operators.
    """

    def kmap(params):
        kraus_ops = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"], params["p"], params["t_rst"], params["chi"], params["max_l"]
        )
        return Qarray.create(
            kraus_ops,
            dims=[[2, N], [2, N]],
            bdims=(params["max_l"] + 1,),
        )

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kmap,
        num_modes=2,
    )