Coverage for jaxquantum / circuits / library / oscillator.py: 31%

144 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 20:38 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import (displace, basis, destroy, create, num, 

4 identity) 

5from jaxquantum.circuits.gates import Gate 

6from jax.scipy.special import factorial, gammaln 

7import jax.numpy as jnp 

8from jaxquantum import Qarray 

9from jaxquantum.utils import hermgauss 

10from functools import partial 

11from jax import jit 

12import jax 

13 

def diag_expm(diag_matrix):
    """Matrix exponential of a diagonal matrix, evaluated in O(N).

    Only the main diagonal of ``diag_matrix`` is read. Each entry is
    exponentiated element-wise and placed on the diagonal of a fresh matrix.
    Off-diagonal entries are ignored, so the result equals a full ``expm``
    only when the input really is diagonal.

    Args:
        diag_matrix: Square matrix whose diagonal is exponentiated.

    Returns:
        ``diag(exp(d_0), ..., exp(d_{N-1}))``.
    """
    exponentiated = jnp.exp(jnp.diagonal(diag_matrix))
    return jnp.diag(exponentiated)

18 

19 

20 

def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. Only
            ``ts[-1] - ts[0]`` is used, as the duration of a constant drive.
        c_ops: Optional collapse operators.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _drive_hamiltonian(params):
            # Constant drive H = conj(eps) a + eps a^dag with
            # eps = i * alpha / delta_t. Assuming U = exp(-i H delta_t), this
            # gives exp(alpha a^dag - conj(alpha) a) = D(alpha).
            # alpha is read from `params` (like gen_U below) rather than the
            # closed-over argument, so both generators stay consistent.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

        gen_Ht = _drive_hamiltonian

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )

50 

51 

def CD(N, beta, ts=None):
    """Conditional displacement gate.

    U = |g><g| ⊗ D(beta/2) + |e><e| ⊗ D(-beta/2).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. Only
            ``ts[-1] - ts[0]`` is used, as the drive duration.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _drive_hamiltonian(params):
            # Per-branch drive with eps = i * beta / (2 * delta_t). Assuming
            # U = exp(-i H delta_t), this yields D(+beta/2) when the qubit is
            # in |g> and D(-beta/2) when it is in |e>, matching gen_U.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Parenthesize each tensor product: Python's `^` binds more
            # loosely than `+`, so `gg ^ X + ee ^ Y` would parse as
            # `gg ^ (X + ee) ^ Y`.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

        gen_Ht = _drive_hamiltonian

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

90 

91 

def ECD(N, beta, ts=None):
    """Echoed conditional displacement gate.

    U = |e><g| ⊗ D(beta/2) + |g><e| ⊗ D(-beta/2). The qubit state is flipped
    while the oscillator is displaced by +beta/2 (qubit initially in |g>) or
    -beta/2 (qubit initially in |e>).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence. It is forwarded to ``Gate.create``, but no
            Hamiltonian is attached (see the note in the body).

    Returns:
        Echoed conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    eg = e @ g.dag()
    ge = g @ e.dag()

    # NOTE: gen_Ht is intentionally None. A previously commented-out draft
    # used H = eg ⊗ X + ge ⊗ (-X) with Hermitian X. That operator is not
    # Hermitian, because (eg ⊗ X)^dag = ge ⊗ X. The draft was also
    # mis-parenthesized, since Python's `^` binds more loosely than `+`. It was
    # removed rather than kept as dead code. Time-domain simulation of ECD
    # presumably needs a pulse sequence (conditional displacement plus a
    # qubit pi pulse) -- TODO if required.
    return Gate.create(
        [2, N],
        name="ECD",
        params={"beta": beta},
        gen_U=lambda params: (eg ^ displace(N, params["beta"] / 2))
        + (ge ^ displace(N, -params["beta"] / 2)),
        gen_Ht=None,
        ts=ts,
        num_modes=2,
    )

130 

def CR(N, theta):
    """Conditional rotation gate.

    U = |g><g| ⊗ exp(-i theta/2 a^dag a) + |e><e| ⊗ exp(+i theta/2 a^dag a).

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # Number operator a^dag a, built once instead of on every gen_U call.
    n_op = create(N) @ destroy(N)

    def gen_U(params):
        # Read theta from `params` (as the other gates do) so the unitary
        # follows the gate's stored parameters, not the construction-time value.
        half_theta = params["theta"] / 2
        return (gg ^ (-1.0j * half_theta * n_op).expm()) + (
            ee ^ (1.0j * half_theta * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )

156 

157 

# --- Noise channels: JIT-compiled Kraus-map kernels (diagonal exponentials computed in O(N) where applicable) and their Gate wrappers ---

159 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Damp_Kraus_Map_JIT(N, err_prob, max_l):
    """Build the amplitude-damping Kraus operators.

    K_l = sqrt(err_prob**l / l!) * (1 - err_prob)**(n/2) @ a**l for
    l = 0..max_l, where n is the number operator and a the annihilation
    operator.

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Damping probability.
        max_l: Largest power of ``a`` kept (static under jit).

    Returns:
        Array of shape (max_l + 1, N, N) stacking K_0..K_max_l.
    """
    a_op = destroy(N).data

    # (1 - err_prob)**(n/2) is diagonal in the Fock basis. Evaluating it as
    # sqrt(1 - err_prob)**n keeps the n = 0 entry equal to 1 at err_prob = 1.
    # The former exp(n * log(sqrt(1 - err_prob))) produced 0 * -inf = nan there.
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), jnp.arange(N)))

    # jnp.linalg.matrix_power needs a concrete integer exponent, so a traced
    # exponent inside vmap is not allowed. Build a**l with a Python loop over
    # the static max_l instead.
    a_powers = jnp.stack(
        [jnp.linalg.matrix_power(a_op, power) for power in range(max_l + 1)]
    )

    ls = jnp.arange(max_l + 1)
    prefactors = jnp.sqrt(jnp.power(err_prob, ls) / jnp.exp(gammaln(ls + 1)))
    return prefactors[:, None, None] * (middle_op @ a_powers)

176 

def Amp_Damp(N, err_prob, max_l):
    """Amplitude-damping channel packaged as a single-mode Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Damping probability forwarded to the Kraus kernel.
        max_l: Largest loss index kept; ``max_l + 1`` Kraus operators result.

    Returns:
        Gate whose Kraus map is rebuilt from its ``params`` on demand.
    """
    def build_kraus(params):
        operators = _Amp_Damp_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(operators, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=build_kraus,
        num_modes=1,
    )

190 

191 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Amp_Gain_Kraus_Map_JIT(N, err_prob, max_l):
    """Build Kraus operators for the phase-insensitive amplification channel.

    K_l = sqrt(err_prob**l / l!) * (a^dag)**l @ (1 - err_prob)**((n + 1)/2)
    for l = 0..max_l, where n is the number operator. Setting
    err_prob = 1 - 1/G gives the quantum-limited amplifier with gain G.

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Gain parameter in [0, 1].
        max_l: Largest power of ``a^dag`` kept (static under jit).

    Returns:
        Array of shape (max_l + 1, N, N) stacking K_0..K_max_l.
    """
    adag_op = create(N).data

    # Diagonal factor (1 - err_prob)**((n + 1)/2). The "+1" makes
    # sum_l K_l^dag K_l = I in the untruncated Fock space, because
    # sum_l C(n+l, l) err_prob**l = (1 - err_prob)**-(n+1). With only
    # (1 - err_prob)**(n/2), the sum is I / (1 - err_prob); for example, the
    # vacuum would end up with total probability 1 / (1 - err_prob) > 1.
    # Computing a power rather than exp(n * log(...)) avoids 0 * -inf = nan
    # at err_prob = 1.
    middle_op = jnp.diag(
        jnp.power(jnp.sqrt(1.0 - err_prob), jnp.arange(N) + 1)
    )

    # jnp.linalg.matrix_power needs a concrete integer exponent, so build the
    # powers with a Python loop over the static max_l instead of inside vmap.
    adag_powers = jnp.stack(
        [jnp.linalg.matrix_power(adag_op, power) for power in range(max_l + 1)]
    )

    ls = jnp.arange(max_l + 1)
    prefactors = jnp.sqrt(jnp.power(err_prob, ls) / jnp.exp(gammaln(ls + 1)))
    return prefactors[:, None, None] * (adag_powers @ middle_op)

208 

def Amp_Gain(N, err_prob, max_l):
    """Amplification (gain) channel packaged as a single-mode Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Gain parameter forwarded to the Kraus kernel.
        max_l: Largest gain index kept; ``max_l + 1`` Kraus operators result.

    Returns:
        Gate whose Kraus map is rebuilt from its ``params`` on demand.
    """
    def build_kraus(params):
        operators = _Amp_Gain_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(operators, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=build_kraus,
        num_modes=1,
    )

222 

223 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Thermal_Ch_Kraus_Map_JIT(N, err_prob, n_bar, max_l):
    """Build Kraus operators for the thermal-noise channel.

    For k, l = 0..max_l, the operator at flat index ``l * (max_l + 1) + k`` is

        sqrt((err_prob*(1+n_bar))**k * (err_prob*n_bar)**l / (k! * l!))
            * (1 - err_prob)**(n/2) @ a**k @ (a^dag)**l

    NOTE(review): this product form, including where the diagonal factor
    sits, is kept as in the original implementation. Confirm it against the
    intended thermal-channel model.

    Args:
        N: Hilbert space dimension (static under jit).
        err_prob: Loss probability.
        n_bar: Thermal occupation of the bath.
        max_l: Largest power of ``a`` / ``a^dag`` kept (static under jit).

    Returns:
        Array of shape ((max_l + 1)**2, N, N).
    """
    a_op = destroy(N).data
    adag_op = create(N).data

    # matrix_power needs concrete integer exponents; max_l is static.
    a_powers = jnp.stack([jnp.linalg.matrix_power(a_op, i) for i in range(max_l + 1)])
    adag_powers = jnp.stack([jnp.linalg.matrix_power(adag_op, i) for i in range(max_l + 1)])

    # Diagonal (1 - err_prob)**(n/2), computed as a power so that err_prob = 1
    # yields |0><0| instead of nan. The former exp(n * log(0)) evaluated
    # 0 * -inf at n = 0.
    middle_op = jnp.diag(jnp.power(jnp.sqrt(1.0 - err_prob), jnp.arange(N)))

    def compute_single_op(idx):
        # Flat index -> (l, k): l is the slow index, k the fast one.
        l = idx // (max_l + 1)
        k = idx % (max_l + 1)

        fact_l = jnp.exp(gammaln(l + 1))
        fact_k = jnp.exp(gammaln(k + 1))

        term_k = jnp.power(err_prob * (1.0 + n_bar), k)
        term_l = jnp.power(err_prob * n_bar, l)

        prefactor = jnp.sqrt((term_k * term_l) / (fact_k * fact_l))
        return prefactor * (middle_op @ a_powers[k] @ adag_powers[l])

    indices = jnp.arange((max_l + 1) ** 2)
    return jax.vmap(compute_single_op)(indices)

255 

def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal-noise channel packaged as a single-mode Kraus-map gate.

    Args:
        N: Hilbert space dimension.
        err_prob: Loss probability forwarded to the Kraus kernel.
        n_bar: Bath thermal occupation forwarded to the Kraus kernel.
        max_l: Per-index cutoff; ``(max_l + 1)**2`` Kraus operators result.

    Returns:
        Gate whose Kraus map is rebuilt from its ``params`` on demand.
    """
    def build_kraus(params):
        operators = _Thermal_Ch_Kraus_Map_JIT(
            params["N"], params["err_prob"], params["n_bar"], params["max_l"]
        )
        n_kraus = (params["max_l"] + 1) ** 2
        return Qarray.create(operators, dims=[[N], [N]], bdims=(n_kraus,))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l, "N": N}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=build_kraus,
        num_modes=1,
    )

269 

270 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Ch_Kraus_Map_JIT(N, ws, phis, max_l):
    """Stack the Kraus operators sqrt(w_i) * exp(i * phi_i * n).

    One operator is produced per (weight, phase) pair. ``max_l`` is only a
    static jit argument; the body does not otherwise use it.

    Args:
        N: Hilbert space dimension (static under jit).
        ws: Weights, one per operator (presumably 1-D -- confirm).
        phis: Phases aligned with ``ws``.
        max_l: Static jit argument, unused by the computation.

    Returns:
        Stack of operators, one per (w_i, phi_i) pair, along axis 0.
    """
    number_op = num(N).data

    def weighted_rotation(weight, phase):
        scale = jnp.sqrt(weight)
        return scale * diag_expm(1.0j * phase * number_op)

    return jax.vmap(weighted_rotation, in_axes=(0, 0))(ws, phis)

279 

def Dephasing_Ch(N, err_prob, max_l):
    """Gaussian dephasing channel packaged as a single-mode Kraus-map gate.

    The oscillator receives a random rotation exp(i phi n) with
    phi ~ Normal(0, err_prob). The average over phi is discretized with a
    ``max_l``-point Gauss-Hermite rule (nodes x_i, weights w_i), which gives
    Kraus operators sqrt(w_i / sqrt(pi)) * exp(i sqrt(2 err_prob) x_i n).
    This assumes ``hermgauss`` follows the physicists' convention of
    ``numpy.polynomial.hermite.hermgauss`` -- confirm.

    Args:
        N: Hilbert space dimension.
        err_prob: Variance of the random phase.
        max_l: Number of quadrature nodes, which is also the number of Kraus
            operators.

    Returns:
        Gate whose Kraus map is rebuilt from its ``params`` on demand.
    """
    xs, ws_raw = hermgauss(max_l)
    ws = 1 / jnp.sqrt(jnp.pi) * ws_raw

    def kmap(params):
        # The phases come from params["err_prob"], as in the other channels,
        # so the Kraus map follows the gate's stored parameters instead of
        # the value captured at construction time.
        phis = jnp.sqrt(2 * params["err_prob"]) * xs
        return Qarray.create(
            _Dephasing_Ch_Kraus_Map_JIT(params["N"], ws, phis, params["max_l"]),
            dims=[[N], [N]],
            bdims=(params["max_l"],),
        )

    return Gate.create(
        N,
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l, "N": N},
        gen_KM=kmap,
        num_modes=1,
    )

297 

298 

def selfKerr(N, K):
    """Self-Kerr unitary exp(-i * K/2 * a^dag a^dag a a).

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient, stored as ``params["Kerr"]``. No time factor
            appears in the exponent, so K presumably already includes the
            evolution time -- confirm against callers.

    Returns:
        Single-mode self-Kerr gate.
    """
    a = destroy(N)
    kerr_op = a.dag() @ a.dag() @ a @ a

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (like the other gates) so the unitary follows
        # the gate's stored parameters rather than the construction value.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )

308 

309 

@partial(jax.jit, static_argnames=["N", "max_l"])
def _Dephasing_Reset_Kraus_Map_JIT(N, p, t_rst, chi, max_l):
    """Kraus operators for a qubit reset that dephases a coupled oscillator.

    Acts on qubit ⊗ oscillator (dimension 2N). Returns ``max_l + 1``
    operators stacked along axis 0:

    * index 0:       |g><g| ⊗ I
    * index 1:       sqrt(p) |e><e| ⊗ exp(-i chi t_rst n)
    * index l >= 2:  c_l |g><e| ⊗ exp(-i chi t_rst (l-2)/(max_l-1) n)

    with c_l**2 proportional to -log(p) p**((l-2)/(max_l-1)) / (max_l-1),
    rescaled so that sum_{l>=2} c_l**2 = 1 - p. The operators therefore sum
    to sum K^dag K = |g><g| ⊗ I + p |e><e| ⊗ I + (1-p) |e><e| ⊗ I = I.

    NOTE(review): (max_l - 1) appears in denominators, so this assumes
    max_l >= 2 -- confirm callers respect that.
    """
    ket_g = basis(2, 0).data
    ket_e = basis(2, 1).data
    proj_g = ket_g @ jnp.conj(ket_g.T)
    proj_e = ket_e @ jnp.conj(ket_e.T)
    lower_ge = ket_g @ jnp.conj(ket_e.T)  # |g><e|: excited -> ground jump

    n_op = num(N).data
    eye_osc = jnp.eye(N)
    steps = max_l - 1

    def jump_weight(l):
        # Unnormalized weight of the reset operator with index l >= 2.
        return -(jnp.log(p) * jnp.power(p, (l - 2) / steps)) / steps

    normalization_factor = (1 - p) / jnp.sum(jump_weight(jnp.arange(2, max_l + 1)))

    def keep_ground(l):
        return jnp.kron(proj_g, eye_osc)

    def excited_no_jump(l):
        phase_op = diag_expm(-1.0j * chi * t_rst * n_op)
        return jnp.sqrt(p) * jnp.kron(proj_e, phase_op)

    def reset_jump(l):
        prefactor = jnp.sqrt(jump_weight(l)) * jnp.sqrt(normalization_factor)
        exponent = -1.0j * chi * t_rst * (l - 2) / steps
        phase_op = diag_expm(exponent * n_op)
        return prefactor * jnp.kron(lower_ge, phase_op)

    def select_op(l):
        # Index 0 -> keep_ground, 1 -> excited_no_jump, >= 2 -> reset_jump.
        branch = jnp.minimum(l, 2)
        return jax.lax.switch(branch, (keep_ground, excited_no_jump, reset_jump), l)

    return jax.vmap(select_op)(jnp.arange(max_l + 1))

352 

def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Qubit reset with oscillator dephasing, packaged as a Kraus-map gate.

    Acts on qubit (dimension 2) ⊗ oscillator (dimension N) and yields
    ``max_l + 1`` Kraus operators built by ``_Dephasing_Reset_Kraus_Map_JIT``.

    Args:
        N: Oscillator Hilbert space dimension.
        p: Weight of the no-reset branch for an excited qubit; the reset
            branches share total weight 1 - p.
        t_rst: Reset duration.
        chi: Dispersive shift used in the oscillator phase.
        max_l: Highest Kraus index; there are ``max_l - 1`` reset branches.

    Returns:
        Two-mode gate whose Kraus map is rebuilt from its ``params``.
    """
    def build_kraus(params):
        operators = _Dephasing_Reset_Kraus_Map_JIT(
            params["N"], params["p"], params["t_rst"], params["chi"], params["max_l"]
        )
        n_kraus = params["max_l"] + 1
        return Qarray.create(operators, dims=[[2, N], [2, N]], bdims=(n_kraus,))

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l, "N": N}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=build_kraus,
        num_modes=2,
    )