Coverage for jaxquantum / core / operators.py: 92%

102 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

"""Constructors for common quantum operators and states.

Operators: Pauli matrices, Hadamard, single-qubit rotation, ladder/number/
identity operators, displacement and squeezing. States: Fock (basis) kets,
multi-mode basis sets, coherent states and thermal density matrices. All are
returned as ``Qarray`` objects.
"""

2 

3from typing import List 

4from jax import config 

5from math import prod 

6 

7import jax.numpy as jnp 

8from jax.nn import one_hot 

9 

10from jaxquantum.core.qarray import Qarray, tensor, QarrayImplType 

11 

# Enable 64-bit floats/complex numbers in JAX. Constructors below request
# ``dtype=jnp.float64`` explicitly. Note this is a process-wide JAX setting that
# is changed as a side effect of importing this module.
config.update("jax_enable_x64", True)

13 

14 

def _make_sparsedia(offsets: tuple, diags: "jnp.ndarray", dims=None) -> Qarray:
    """Wrap pre-padded diagonals in a ``Qarray`` using the SPARSE_DIA backend.

    The diagonals are handed straight to ``SparseDiaImpl.from_diags`` instead of
    first being assembled into a dense matrix with ``jnp.diag``.

    ``diags`` must already follow Convention A padding:
    - the row for a non-negative offset ``k`` has leading zeros at ``[0:k]``;
    - the row for a negative offset ``k`` has trailing zeros at ``[n+k:]``.

    Args:
        offsets: Sorted tuple of integer diagonal offsets, one per row of
            ``diags``.
        diags: JAX array of shape ``(n_diags, n)`` with the padded diagonal
            values.
        dims: Optional quantum dims, forwarded unchanged to ``Qarray.create``.

    Returns:
        A ``Qarray`` created with ``implementation=QarrayImplType.SPARSE_DIA``.
    """
    # Local import. NOTE(review): presumably this avoids a circular import with
    # the sparse_dia module; confirm.
    from jaxquantum.core.sparse_dia import SparseDiaImpl

    sparse_impl = SparseDiaImpl.from_diags(offsets=offsets, diags=diags)
    payload = sparse_impl.get_data()
    return Qarray.create(
        payload,
        dims=dims,
        implementation=QarrayImplType.SPARSE_DIA,
    )

34 

35 

def sigmax(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli X operator, σx = [[0, 1], [1, 0]].

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        σx Pauli operator as a 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[0.0, 1.0], [1.0, 0.0]])
        return Qarray.create(matrix, implementation=implementation)

    # Offset -1 row: A[1,0] = 1, then one trailing pad zero.
    # Offset +1 row: one leading pad zero, then A[0,1] = 1.
    padded = jnp.array([[1.0, 0.0], [0.0, 1.0]])
    return _make_sparsedia(offsets=(-1, 1), diags=padded)

51 

52 

def sigmay(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli Y operator, σy = [[0, -i], [i, 0]].

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        σy Pauli operator as a complex 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])
        return Qarray.create(matrix, implementation=implementation)

    # Offset -1 row: A[1,0] = i, then one trailing pad zero.
    # Offset +1 row: one leading pad zero, then A[0,1] = -i.
    padded = jnp.array([[1.0j, 0.0], [0.0, -1.0j]])
    return _make_sparsedia(offsets=(-1, 1), diags=padded)

63 

64 

def sigmaz(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli Z operator, σz = diag(1, -1).

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        σz Pauli operator as a 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[1.0, 0.0], [0.0, -1.0]])
        return Qarray.create(matrix, implementation=implementation)

    # Only the main diagonal (offset 0) is non-zero, so no padding is needed.
    main_diag = jnp.array([[1.0, -1.0]])
    return _make_sparsedia(offsets=(0,), diags=main_diag)

75 

76 

def hadamard(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Hadamard gate, H = [[1, 1], [1, -1]] / √2.

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        Hadamard gate as a 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[1, 1], [1, -1]]) / jnp.sqrt(2)
        return Qarray.create(matrix, implementation=implementation)

    inv_sqrt2 = 1.0 / jnp.sqrt(2.0)
    # Rows in offset order (-1, 0, +1):
    #   -1: A[1,0], then a trailing pad zero
    #    0: A[0,0], A[1,1]
    #   +1: a leading pad zero, then A[0,1]
    padded = jnp.array(
        [[inv_sqrt2, 0.0], [inv_sqrt2, -inv_sqrt2], [0.0, inv_sqrt2]]
    )
    return _make_sparsedia(offsets=(-1, 0, 1), diags=padded)

91 

92 

def sigmam(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """σ- operator, [[0, 0], [1, 0]].

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        σ- operator as a 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[0.0, 0.0], [1.0, 0.0]])
        return Qarray.create(matrix, implementation=implementation)

    # Single offset -1 row: A[1,0] = 1, then one trailing pad zero.
    padded = jnp.array([[1.0, 0.0]])
    return _make_sparsedia(offsets=(-1,), diags=padded)

103 

104 

def sigmap(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """σ+ operator, [[0, 1], [0, 0]].

    Args:
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        σ+ operator as a 2x2 ``Qarray``.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        matrix = jnp.array([[0.0, 1.0], [0.0, 0.0]])
        return Qarray.create(matrix, implementation=implementation)

    # Single offset +1 row: one leading pad zero, then A[0,1] = 1.
    padded = jnp.array([[0.0, 1.0]])
    return _make_sparsedia(offsets=(1,), diags=padded)

115 

116 

def qubit_rotation(theta: float, nx, ny, nz) -> Qarray:
    """Single-qubit rotation by ``theta`` about the axis (nx, ny, nz).

    Computes cos(θ/2)·I - i·sin(θ/2)·(nx σx + ny σy + nz σz), using dense
    operators. The axis components are used as given; they are not
    normalized here.

    Args:
        theta: Rotation angle.
        nx: Rotation axis x component.
        ny: Rotation axis y component.
        nz: Rotation axis z component.

    Returns:
        Single-qubit rotation operator.
    """
    half_angle = theta / 2
    axis_generator = nx * sigmax() + ny * sigmay() + nz * sigmaz()
    return (
        jnp.cos(half_angle) * identity(2)
        - 1j * jnp.sin(half_angle) * axis_generator
    )

132 

133 

def destroy(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Annihilation operator a, with entries A[n-1, n] = √n.

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        Annihilation operator in a Hilbert space of size N.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        dense = jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=1)
        return Qarray.create(dense, implementation=implementation)

    # Superdiagonal (offset +1). Convention A: index 0 is the leading pad zero,
    # and indices 1..N-1 hold √1 .. √(N-1).
    amplitudes = jnp.sqrt(jnp.arange(1, N, dtype=jnp.float64))
    padded = jnp.zeros((1, N), dtype=jnp.float64).at[0, 1:].set(amplitudes)
    return _make_sparsedia(offsets=(1,), diags=padded)

150 

151 

def create(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Creation operator a†, with entries A[n, n-1] = √n.

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        Creation operator in a Hilbert space of size N.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        dense = jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=-1)
        return Qarray.create(dense, implementation=implementation)

    # Subdiagonal (offset -1). Convention A: indices 0..N-2 hold √1 .. √(N-1),
    # and the last index is the trailing pad zero.
    amplitudes = jnp.sqrt(jnp.arange(1, N, dtype=jnp.float64))
    padded = jnp.zeros((1, N), dtype=jnp.float64).at[0, : N - 1].set(amplitudes)
    return _make_sparsedia(offsets=(-1,), diags=padded)

168 

169 

def num(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Number operator, diag(0, 1, ..., N-1).

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        float64 number operator in a Hilbert space of size N.
    """
    is_sparse = QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA
    # Build the occupations as float64 in both branches. A bare
    # ``jnp.arange(N)`` is an integer array, which would make the dense
    # operator integer-typed while the SPARSE_DIA one (and destroy/create)
    # are float.
    occupations = jnp.arange(N, dtype=jnp.float64)
    if is_sparse:
        # Main diagonal only; offset 0 needs no leading/trailing padding.
        return _make_sparsedia(offsets=(0,), diags=occupations.reshape(1, N))
    return Qarray.create(jnp.diag(occupations), implementation=implementation)

185 

186 

def identity(*args, implementation: QarrayImplType = QarrayImplType.DENSE, **kwargs) -> Qarray:
    """Identity matrix.

    All positional and keyword arguments except ``implementation`` are
    forwarded to ``jnp.eye``.

    For ``QarrayImplType.SPARSE_DIA``, the plain square case ``identity(N)``
    or ``identity(N=N)`` is built directly from one all-ones main diagonal.
    Any other ``jnp.eye`` arguments (``M``, ``k``, ``dtype``, ...) fall back
    to building the dense ``jnp.eye`` matrix and passing it to
    ``Qarray.create`` with the requested implementation.

    Args:
        *args: Positional arguments for ``jnp.eye`` (typically ``N``).
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.
        **kwargs: Keyword arguments for ``jnp.eye``.

    Returns:
        Identity matrix as a ``Qarray``.
    """
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Fast path only for a bare size. ``N`` is the name of ``jnp.eye``'s
        # first parameter, so accept it as a keyword as well as positionally.
        # Any other argument shape falls through to the dense constructor.
        if len(args) == 1 and not kwargs:
            n = args[0]
        elif not args and set(kwargs) == {"N"}:
            n = kwargs["N"]
        else:
            n = None
        if n is not None:
            diags = jnp.ones((1, int(n)), dtype=jnp.float64)
            return _make_sparsedia(offsets=(0,), diags=diags)
    return Qarray.create(jnp.eye(*args, **kwargs), implementation=implementation)

203 

204 

# Alias: ``qeye`` is the same function object as ``identity``.
qeye = identity

206 

def identity_like(A, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Identity operator on the same Hilbert space as ``A``.

    The result is a square matrix of side ``prod(A.space_dims)`` with dims
    ``[A.space_dims, A.space_dims]``. It is not necessarily the same array
    shape as ``A`` itself.

    Args:
        A: State or operator whose ``space_dims`` define the space.
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        Identity operator on A's space.
    """
    dims_of_space = A.space_dims
    side = prod(dims_of_space)
    eye = jnp.eye(side, side)
    return Qarray.create(
        eye,
        dims=[dims_of_space, dims_of_space],
        implementation=implementation,
    )

220 

221 

def displace(N, α) -> Qarray:
    """Displacement operator D(α) = exp(α a† - α* a), truncated to N levels.

    Args:
        N: Hilbert space size.
        α: Phase-space displacement.

    Returns:
        Displacement operator D(α).
    """
    annihilation = destroy(N)
    exponent = α * annihilation.dag() - jnp.conj(α) * annihilation
    return exponent.expm()

234 

def squeeze(N, z):
    """Single-mode squeezing operator S(z) = exp(½ (z* a² - z a†²)).

    Args:
        N: Hilbert space size.
        z: Squeezing parameter.

    Returns:
        Squeezing operator S(z), truncated to N levels.
    """
    annihilation = destroy(N)
    creation = annihilation.dag()
    exponent = (1 / 2.0) * jnp.conj(z) * (annihilation @ annihilation) - (
        1 / 2.0
    ) * z * (creation @ creation)
    return exponent.expm()

250 

251 

def squeezing_linear_to_dB(z):
    """Convert a squeezing parameter to decibels.

    Only the magnitude of ``z`` is used: the result is 20·log10(e^|z|).

    Args:
        z: Squeezing parameter (real or complex).

    Returns:
        Squeezing level in dB.
    """
    amplitude_gain = jnp.exp(jnp.abs(z))
    return 20 * jnp.log10(amplitude_gain)

254 

def squeezing_dB_to_linear(z_dB):
    """Convert a squeezing level in decibels to a squeezing magnitude.

    Computes ln(10^(z_dB / 20)). For real, non-negative magnitudes this
    inverts ``squeezing_linear_to_dB``.

    Args:
        z_dB: Squeezing level in dB.

    Returns:
        Squeezing magnitude |z|.
    """
    amplitude_ratio = 10 ** (z_dB / 20)
    return jnp.log(amplitude_ratio)

257 

258# States --------------------------------------------------------------------- 

259 

260 

def basis(N: int, k: int, implementation: QarrayImplType = QarrayImplType.DENSE):
    """Fock-state ket |k⟩ in an N-dimensional Hilbert space.

    The data is an ``(N, 1)`` column vector with a single 1 at row ``k``.
    NOTE(review): ``jax.nn.one_hot`` yields an all-zero vector when ``k`` is
    outside ``[0, N)``. No error is raised here.

    Args:
        N: Hilbert space dimension.
        k: Fock number.
        implementation: Qarray implementation type, e.g.
            ``QarrayImplType.DENSE`` or ``QarrayImplType.SPARSE_DIA``.

    Returns:
        Fock state |k⟩.
    """
    column = one_hot(k, N).reshape(N, 1)
    return Qarray.create(column, implementation=implementation)

273 

def multi_mode_basis_set(Ns: List[int]) -> Qarray:
    """Batched ``Qarray`` of computational-basis vectors for several modes.

    Built from the ``prod(Ns)`` x ``prod(Ns)`` identity matrix, with ket dims
    ``(tuple(Ns), (1, ..., 1))`` and batch dims ``(prod(Ns),)``. Presumably
    each batch entry is one basis ket; confirm against ``Qarray.create``.

    Args:
        Ns: Hilbert space dimension of each mode.

    Returns:
        Multi-mode basis set.
    """
    total = prod(Ns)
    ket_dims = (tuple(Ns), tuple(1 for _ in Ns))
    return Qarray.create(jnp.eye(total), dims=ket_dims, bdims=(total,))

286 

287 

def coherent(N: int, α: complex) -> Qarray:
    """Coherent state |α⟩ = D(α)|0⟩, truncated to N levels.

    Args:
        N: Hilbert space size.
        α: Coherent-state amplitude.

    Returns:
        Coherent state |α⟩.
    """
    displacement = displace(N, α)
    return displacement @ basis(N, 0)

299 

300 

def thermal_dm(N: int, n: float) -> Qarray:
    """Thermal density matrix with mean occupation ``n``, truncated to N levels.

    Populations are p_k ∝ exp(-β k) with β = ln(1 + 1/n). The result is then
    passed through ``.unit()``, presumably to normalize to unit trace. For
    ``n == 0`` (β = +inf) the vacuum |0⟩⟨0| is returned.

    Args:
        N: Hilbert space size.
        n: Average photon number (a Python number or JAX scalar).

    Returns:
        Thermal state.
    """
    # Convert to a JAX array before dividing. With a plain Python 0 or 0.0,
    # ``1 / n`` raises ZeroDivisionError before the vacuum branch below can be
    # selected. JAX division yields +inf instead, so β = +inf picks the vacuum.
    n_mean = jnp.asarray(n)
    beta = jnp.log(1 + 1.0 / n_mean)

    populations = jnp.exp(-beta * jnp.linspace(0, N - 1, N))
    return Qarray.create(
        jnp.where(
            jnp.isposinf(beta),
            basis(N, 0).to_dm().data,
            jnp.diag(populations),
        )
    ).unit()

321 

322 

def basis_like(A: Qarray, ks: List[int]) -> Qarray:
    """Product Fock-state ket |k_0, k_1, ...⟩ with the same space dims as ``A``.

    Args:
        A: State or operator whose ``space_dims`` define the subsystems.
        ks: Fock number for each subsystem, one entry per element of
            ``A.space_dims``.

    Returns:
        ``tensor`` product of ``basis(A.space_dims[j], ks[j])`` over all
        subsystems.

    Raises:
        AssertionError: If ``len(ks) != len(A.space_dims)``. The check is an
            ``assert``, so it is skipped when Python runs with ``-O``.
    """
    dims_per_mode = A.space_dims
    assert len(dims_per_mode) == len(ks), "len(ks) must be equal to len(space_dims)"

    mode_kets = [basis(dims_per_mode[j], k) for j, k in enumerate(ks)]
    return tensor(*mode_kets)