Coverage for jaxquantum/circuits/library/oscillator.py: 45%

76 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import (displace, basis, destroy, create, num, 

4 identity) 

5from jaxquantum.circuits.gates import Gate 

6from jax.scipy.special import factorial 

7import jax.numpy as jnp 

8from jaxquantum import Qarray 

9from jaxquantum.utils import hermgauss 

10 

11 

def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    The target unitary is ``displace(N, alpha)``. When ``ts`` is given, a
    constant-in-time Hamiltonian is also provided for hamiltonian simulation.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation. Only its first and
            last entries are used, to set the total drive duration.
        c_ops: Optional collapse operators. When omitted, an empty operator
            list is supplied to the gate.

    Returns:
        Displacement gate.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # H = conj(amp) a + amp a^dag with amp = 1j * alpha / delta_t.
            # Assuming evolution exp(-i H delta_t), this gives
            # exp(alpha a^dag - conj(alpha) a) = D(alpha).
            # Read alpha from params so the Hamiltonian tracks the same
            # parameter as gen_U below.
            amp = 1j * params["alpha"] / delta_t
            return lambda t: jnp.conj(amp) * a + amp * a.dag()

        gen_Ht = _gen_Ht

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=lambda params: Qarray.from_list([]) if c_ops is None else c_ops,
        num_modes=1,
    )

41 

42 

def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Target unitary on a qubit (first subsystem, dimension 2) and an
    oscillator (dimension N):
        |g><g| (x) D(beta / 2) + |e><e| (x) D(-beta / 2)

    Args:
        N: Hilbert space dimension of the oscillator.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for hamiltonian simulation. Only its first
            and last entries are used, to set the total drive duration.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # Each qubit branch receives half the displacement, with opposite
            # signs, mirroring gen_U below.
            amp = 1j * params["beta"] / delta_t / 2
            drive_g = jnp.conj(amp) * a + amp * a.dag()
            drive_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # The parentheses are essential: `+` binds tighter than `^`
            # (tensor product), so `gg ^ X + ee ^ Y` would parse as
            # `gg ^ (X + ee) ^ Y`.
            return lambda t: (gg ^ drive_g) + (ee ^ drive_e)

        gen_Ht = _gen_Ht

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

81 

def CR(N, theta):
    """Conditional rotation gate.

    Target unitary on a qubit (first subsystem, dimension 2) and an
    oscillator (dimension N):
        |g><g| (x) exp(-1j * theta / 2 * n) + |e><e| (x) exp(+1j * theta / 2 * n)
    where n is the oscillator number operator.

    The generator is ``num(N)``. The product ``destroy(N) @ create(N)`` must
    not be used: in a truncated space a a^dag equals diag(1, 2, ..., N - 1, 0).
    The highest Fock level would then acquire no phase, and every other level
    would be shifted by one, adding a qubit-only Z rotation by theta.

    Args:
        N: Hilbert space dimension of the oscillator.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()
    n_op = num(N)

    def gen_U(params):
        # Read the angle from params (as D and CD do) so the unitary follows
        # the gate's stored parameters rather than the construction-time value.
        half_theta = params["theta"] / 2
        return (gg ^ (-1.0j * half_theta * n_op).expm()) + (
            ee ^ (1.0j * half_theta * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )

107 

108 

def _Ph_Loss_Kraus_Op(N, err_prob, l):
    """Build the Kraus operator describing the loss of exactly ``l`` photons.

    The operator is
        sqrt(err_prob**l / l!) * (1 - err_prob)**(n / 2) * a**l
    in a Hilbert space of size ``N``. The damping factor is obtained as
    exp(n * log(sqrt(1 - err_prob))).

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        l: Number of photons lost.

    Returns:
        Kraus operator for l-photon loss.
    """
    coeff = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    lowering = destroy(N).powm(l)
    # NOTE(review): relies on Qarray `*` acting as the operator product here.
    return coeff * damping * lowering

127 

128 

def Amp_Damp(N, err_prob, max_l):
    """Amplitude damping channel.

    The Kraus map contains one operator per photon-loss count 0 .. max_l.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons lost.

    Returns:
        Amplitude damping channel.
    """

    def kmap(params):
        kraus_ops = [
            _Ph_Loss_Kraus_Op(N, err_prob, n_lost) for n_lost in range(max_l + 1)
        ]
        return Qarray.from_list(kraus_ops)

    return Gate.create(
        N,
        name="Amp_Damp",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

150 

151 

def _Ph_Gain_Kraus_Op(N, err_prob, l):
    """Build the Kraus operator describing the gain of exactly ``l`` photons.

    The operator is
        sqrt(err_prob**l / l!) * (a^dag)**l * (1 - err_prob)**(n / 2)
    in a Hilbert space of size ``N``. The damping factor is obtained as
    exp(n * log(sqrt(1 - err_prob))).

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        l: Number of photons gained.

    Returns:
        Kraus operator for l-photon gain.
    """
    coeff = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    raising = create(N).powm(l)
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    # NOTE(review): relies on Qarray `*` acting as the operator product here.
    return coeff * raising * damping

170 

171 

def Amp_Gain(N, err_prob, max_l):
    """Amplitude gain channel.

    The Kraus map contains one operator per photon-gain count 0 .. max_l.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons gained.

    Returns:
        Amplitude gain channel.
    """

    def kmap(params):
        kraus_ops = [
            _Ph_Gain_Kraus_Op(N, err_prob, n_gained) for n_gained in range(max_l + 1)
        ]
        return Qarray.from_list(kraus_ops)

    return Gate.create(
        N,
        name="Amp_Gain",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

193 

194 

def _Thermal_Kraus_Op(N, err_prob, n_bar, l, k):
    """Build one Kraus operator of the thermal channel.

    The operator is
        sqrt((err_prob * (1 + n_bar))**k * (err_prob * n_bar)**l / (l! k!))
            * (1 - err_prob)**(n / 2) * a**k * (a^dag)**l
    in a Hilbert space of size ``N``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Average photon number.
        l: Number of photons gained.
        k: Number of photons lost.

    Returns:
        Kraus operator for thermal channel.
    """
    weight = (
        jnp.power(err_prob * (1 + n_bar), k)
        * jnp.power(err_prob * n_bar, l)
        / factorial(l)
        / factorial(k)
    )
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    lowering = destroy(N).powm(k)
    raising = create(N).powm(l)
    # NOTE(review): relies on Qarray `*` acting as the operator product here.
    return jnp.sqrt(weight) * damping * lowering * raising

221 

222 

def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal channel.

    The Kraus map contains one operator for every (gain, loss) pair with both
    counts in 0 .. max_l, ordered with the gain count as the outer index.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Average photon number.
        max_l: Maximum number of photons gained/lost.

    Returns:
        Thermal channel.
    """

    def kmap(params):
        kraus_ops = []
        for n_gain in range(max_l + 1):
            for n_loss in range(max_l + 1):
                kraus_ops.append(
                    _Thermal_Kraus_Op(N, err_prob, n_bar, n_gain, n_loss)
                )
        return Qarray.from_list(kraus_ops)

    return Gate.create(
        N,
        name="Thermal_Ch",
        params={"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

249 

250 

def _Dephasing_Kraus_Op(N, w, phi):
    """Return the dephasing Kraus operator sqrt(w) * exp(1j * phi * n).

    Args:
        N: Hilbert space dimension.
        w: Weight of this phase sample.
        phi: Phase applied per photon.

    Returns:
        Weighted phase-rotation operator.
    """
    phase_rotation = (1.0j * phi * num(N)).expm()
    return jnp.sqrt(w) * phase_rotation

257 

258 

def Dephasing_Ch(N, err_prob, max_l):
    """Dephasing channel.

    Approximates averaging the rotation exp(1j * phi * n) over a random phase
    phi with a Gauss-Hermite quadrature. Each node x becomes a phase
    sqrt(2 * err_prob) * x with weight w / sqrt(pi). Assuming ``hermgauss``
    follows the ``numpy.polynomial.hermite.hermgauss`` convention (nodes and
    weights for the weight function exp(-x**2)), this corresponds to a
    zero-mean Gaussian phase with variance ``err_prob``. TODO confirm.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability (phase variance under the assumption above).
        max_l: Number of quadrature points requested from ``hermgauss``, and
            hence (presumably) the number of Kraus operators.

    Returns:
        Dephasing channel.
    """
    xs, ws = hermgauss(max_l)
    phis = jnp.sqrt(2 * err_prob) * xs
    ws = 1 / jnp.sqrt(jnp.pi) * ws

    kmap = lambda params: Qarray.from_list(
        [_Dephasing_Kraus_Op(N, w, phi) for (w, phi) in zip(ws, phis)]
    )
    return Gate.create(
        N,
        # Previously mislabelled "Amp_Gain" (copy-paste from Amp_Gain).
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

285 

def selfKerr(N, K):
    """Self-Kerr interaction gate.

    Unitary exp(-1j * K / 2 * a^dag a^dag a a), where a^dag a^dag a a equals
    n (n - 1).

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    a_dag = a.dag()
    # Normally ordered Kerr operator, built once and shared by gen_U calls.
    kerr_op = a_dag @ a_dag @ a @ a

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        # Read K from params (as D and CD do) so the unitary follows the
        # gate's stored parameters rather than the construction-time value.
        gen_U=lambda params: (-1.0j * params["Kerr"] / 2 * kerr_op).expm(),
        num_modes=1,
    )

304 

305 

def _Reset_Deph_Kraus_Op(N, p, t_rst, chi, l, max_l):
    """Return the l-th Kraus operator for dephasing during reset.

    Operators on (qubit, oscillator), by index:
      * l == 0: |g><g| (x) identity(N)
      * l == 1: sqrt(p) |e><e| (x) exp(-1j * chi * t_rst * n)
      * l >= 2: c_l |g><e| (x) exp(-1j * chi * t_rst * (l - 2) / (max_l - 1) * n)
    The prefactors c_l are normalised so that the sum of c_l**2 over
    l = 2 .. max_l - 1 equals 1 - p.

    Args:
        N: Hilbert space dimension.
        p: Reset error probability.
        t_rst: Reset time.
        chi: cross-Kerr strength between qubit and resonator.
        l: Operator index.
        max_l: Maximum number of operators.

    Returns:
        Kraus operator for dephasing during reset.
    """
    if l == 0:
        return (basis(2, 0) @ basis(2, 0).dag()) ^ identity(N)

    if l == 1:
        excited_proj = basis(2, 1) @ basis(2, 1).dag()
        full_phase = (-1.j * chi * t_rst * num(N)).expm()
        return (jnp.sqrt(p) * excited_proj) ^ full_phase

    steps = max_l - 1

    def _weight(idx):
        # Unnormalised weight -log(p) * p**((idx - 2) / steps) / steps.
        return -(jnp.log(p) * p ** ((idx - 2) / steps)) / (steps)

    total_weight = jnp.sum(_weight(jnp.arange(2, max_l, 1)))
    normalization_factor = (1 - p) / total_weight
    prefactor = jnp.sqrt(_weight(l)) * jnp.sqrt(normalization_factor)

    g_from_e = basis(2, 0) @ basis(2, 1).dag()
    # Kept in the original operand order so floating-point results match.
    partial_phase = (
        -1.j * chi * t_rst * (l - 2) / (max_l - 1) * num(N)
    ).expm()
    return prefactor * (g_from_e ^ partial_phase)

344 

345 

def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Dephasing due to imperfect reset between a qubit and a resonator.

    The Kraus map holds ``max_l`` operators, indices 0 .. max_l - 1, acting on
    a (qubit, oscillator) space of dimensions [2, N].

    Args:
        N: Hilbert space dimension.
        p: Reset error probability.
        t_rst: Reset time.
        chi: Dephasing strength.
        max_l: Maximum number of operators.

    Returns:
        Dephasing due to reset channel.
    """

    def kmap(params):
        kraus_ops = [
            _Reset_Deph_Kraus_Op(N, p, t_rst, chi, idx, max_l)
            for idx in range(max_l)
        ]
        return Qarray.from_list(kraus_ops)

    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params={"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l},
        gen_KM=kmap,
        num_modes=2,
    )