Coverage for jaxquantum/core/qp_distributions.py: 95%

99 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1import jax.numpy as jnp 

2from jax import vmap, config 

3from jax.scipy.special import factorial 

4import jax 

5 

6config.update("jax_enable_x64", True) 

7 

def wigner(psi, xvec, yvec, method="clenshaw", g=2):
    """Wigner function for a state vector or density matrix at points
    `xvec + i * yvec`.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix. It must answer `is_vec()` /
        `is_dm()`, and provide `to_dm()` whose result exposes `.data`.
        Every axis of `.data` before the last two is treated as a batch
        axis and mapped over with `vmap`.

    xvec : array_like
        x-coordinates at which to calculate the Wigner function.

    yvec : array_like
        y-coordinates at which to calculate the Wigner function.

    method : string {'clenshaw', 'iterative', 'laguerre', 'fft'}, default: 'clenshaw'
        Only 'clenshaw' is currently supported. The names 'iterative',
        'laguerre' and 'fft' are recognised but raise `NotImplementedError`.
        'clenshaw' is a fast and numerically stable method, and is the
        preferred one for density matrices with a large number of
        excitations (>~50).

    g : float, default: 2
        Scaling factor for `a = 0.5 * g * (x + iy)`, default `g = 2`.
        The value of `g` is related to the value of `hbar` in the commutation
        relation `[x, y] = i * hbar` via `hbar=2/g^2`.

    Returns
    -------

    W : array
        Values representing the Wigner function calculated over the specified
        range [xvec,yvec], with one leading axis per batch axis of the input.

    Raises
    ------

    TypeError
        If `psi` is neither a vector nor a density matrix, or if `method`
        is not one of the recognised names.

    NotImplementedError
        If `method` is 'iterative', 'laguerre' or 'fft'.

    References
    ----------

    Ulf Leonhardt,
    Measuring the Quantum State of Light, (Cambridge University Press, 1997)

    """

    if not (psi.is_vec() or psi.is_dm()):
        raise TypeError("Input state is not a valid operator.")

    # Recognised method names that have no implementation yet.
    if method in ("fft", "iterative", "laguerre"):
        raise NotImplementedError("Only the 'clenshaw' method is implemented.")

    if method != "clenshaw":
        raise TypeError(
            "method must be one of 'clenshaw', 'iterative', 'laguerre', or 'fft'."
        )

    rho = psi.to_dm().data

    # Wrap the single-matrix kernel in one vmap per leading (batch) axis;
    # only the density matrix is mapped, the grid and g are shared.
    wigner_fn = _wigner_clenshaw
    for _ in rho.shape[:-2]:
        wigner_fn = vmap(
            wigner_fn,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return wigner_fn(rho, xvec, yvec, g)

88 

89 

def _wigner_clenshaw(rho, xvec, yvec, g):
    r"""
    Using Clenshaw summation - numerically stable and efficient
    iterative algorithm to evaluate polynomial series.

    The Wigner function is calculated as
    :math:`W = e^(-0.5*x^2)/pi * \sum_{L} c_L (2x)^L / \sqrt(L!)` where
    :math:`c_L = \sum_n \rho_{n,L+n} LL_n^L` where
    :math:`LL_n^L = (-1)^n \sqrt(L!n!/(L+n)!) LaguerreL[n,L,x]`
    Heavily inspired by Qutip and Dynamiqs
    https://github.com/dynamiqs/dynamiqs
    https://github.com/qutip/qutip

    Parameters
    ----------
    rho : array
        A single square density matrix (indexed as ``rho[row, col]``).
    xvec, yvec : array
        Grid coordinates; combined with ``jnp.meshgrid(xvec, yvec)``.
    g : float
        Scaling factor, used as ``A = 0.5 * g * (X + iY)`` and in the final
        normalisation ``g * g * 0.5 / pi``.

    Returns
    -------
    array
        Real array with the shape of the meshgrid.
    """

    # Keep the dimension as a static Python int: it is used as an array
    # shape below, which must not become a traced value under jit.
    M = rho.shape[0]
    X, Y = jnp.meshgrid(xvec, yvec)
    A = 0.5 * g * (X + 1.0j * Y)
    B = jnp.abs(2 * A)

    B = B * B

    # Start from the highest diagonal (L = M - 1), whose only entry is the
    # top-right element of rho.
    w0 = (2 * rho[0, -1]) * jnp.ones_like(A)

    # calculation of \sum_{L} c_L (2x)^L / \sqrt(L!)
    # using Horner's method

    # Off-diagonal entries are doubled and the diagonal is kept as is; only
    # the diagonals L >= 0 are visited below and the real part is taken at
    # the end.
    rho = rho * (2 * jnp.ones((M, M)) - jnp.diag(jnp.ones(M)))

    def loop(i, w):
        # Walk the diagonals downwards: L = M - 2, M - 3, ..., 0.
        i = M - 2 - i
        w = w * (2 * A * (i + 1) ** (-0.5))
        return w + _wig_laguerre_val(i, B, rho, M)

    w = jax.lax.fori_loop(0, M - 1, loop, w0)

    return w.real * jnp.exp(-B * 0.5) * (g * g * 0.5 / jnp.pi)

125 

126def _extract_diag_element(rho: jnp.array, L: int, n:int): 

127 """" 

128 Extract element at index n from diagonal L of matrix rho. 

129 Heavily inspired from https://github.com/dynamiqs/dynamiqs 

130 """ 

131 N = rho.shape[0] 

132 n = jax.lax.select(n < 0, N - jnp.abs(L) - jnp.abs(n), n) 

133 row = jnp.maximum(-L, 0) + n 

134 col = jnp.maximum(L, 0) + n 

135 return rho[row, col] 

136 

def _wig_laguerre_val(L, x, rho, N):
    r"""
    Evaluate Laguerre polynomials.
    Implementation in Jax from https://github.com/dynamiqs/dynamiqs

    The coefficients are read from diagonal ``L`` of ``rho`` through
    ``_extract_diag_element``. One of three branches is selected with
    ``jax.lax.cond`` depending on ``N - L`` (1, 2, or anything else).
    NOTE(review): ``N - L`` looks like the number of entries on diagonal
    ``L`` of an ``N x N`` matrix - confirm that callers always pass
    ``N == rho.shape[0]``.

    Every branch returns an array broadcast to the shape of ``x``.
    """

    def len_c_1():
        # Single coefficient: the result is that coefficient at every point.
        return _extract_diag_element(rho, L, 0) * jnp.ones_like(x)

    def len_c_2():
        # Two coefficients: one closing step, no recurrence loop.
        c0 = _extract_diag_element(rho, L, 0)
        c1 = _extract_diag_element(rho, L, 1)
        return (c0 - c1 * (L + 1 - x) * (L + 1) ** (-0.5)) * jnp.ones_like(x)

    def len_c_other():
        # Seed the recurrence with the last two entries of the diagonal.
        cm2 = _extract_diag_element(rho, L, -2)
        cm1 = _extract_diag_element(rho, L, -1)
        y0 = cm2 * jnp.ones_like(x)
        y1 = cm1 * jnp.ones_like(x)

        def loop(j: int, args: tuple[jax.Array, jax.Array]) -> tuple[
            jax.Array, jax.Array]:
            def body() -> tuple[jax.Array, jax.Array]:
                k = N + 1 - L - j
                y0, y1 = args
                # Entry j counted from the end of the diagonal.
                ckm1 = _extract_diag_element(rho, L, -j)
                y0, y1 = (
                    ckm1 - y1 * (k * (L + k) / ((L + k + 1) * (k + 1))) ** 0.5,
                    y0 - y1 * (L + 2 * k - x + 1) * (
                        (L + k + 1) * (k + 1)) ** -0.5,
                )

                return y0, y1

            # The loop bounds below do not depend on L; iterations with
            # j >= N + 1 - L pass the carry through unchanged.
            return jax.lax.cond(j >= N + 1 - L, lambda: args, body)

        y0, y1 = jax.lax.fori_loop(3, N + 1, loop, (y0, y1))

        # Closing step, same form as in the two-coefficient branch.
        return y0 - y1 * (L + 1 - x) * (L + 1) ** (-0.5)


    return jax.lax.cond(N - L == 1, len_c_1, lambda: jax.lax.cond(N - L == 2,
                                                                   len_c_2,
                                                                   len_c_other))

181 

182 

def qfunc(psi, xvec, yvec, g=2):
    r"""
    Husimi-Q function of a given state vector or density matrix at phase-space
    points ``0.5 * g * (xvec + i*yvec)``.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix. This cannot have tensor-product
        structure. When ``psi.is_vec()`` is true the state is converted with
        ``to_ket()``; otherwise its data is diagonalised with
        ``jnp.linalg.eigh`` and the Q functions of the eigenvectors are
        summed, weighted by the eigenvalues.

    xvec, yvec : array_like
        x- and y-coordinates at which to calculate the Husimi-Q function.

    g : float, default: 2
        Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
        related to the value of :math:`\hbar` in the commutation relation
        :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.

    Returns
    -------
    jnp.ndarray
        Values representing the Husimi-Q function calculated over the specified
        range ``[xvec, yvec]``, with one leading axis per batch axis of
        ``psi.data`` (every axis before the last two).

    """

    def _ket_qfunc(ket, alpha_grid, prefactor, g):
        # Pure state: one polynomial evaluation, then the 1/pi normalisation.
        return _qfunc_iterative_single(ket, alpha_grid, prefactor, g) / jnp.pi

    def _dm_qfunc(rho, alpha_grid, prefactor, g):
        # Mixed state: eigenvalue-weighted sum over the eigenvectors.
        eigvals, eigvecs = jnp.linalg.eigh(rho)
        # eigh returns eigenvectors as columns; transpose to iterate by row.
        eigvecs = eigvecs.T
        total = eigvals[0] * _qfunc_iterative_single(
            eigvecs[0], alpha_grid, prefactor, g
        )
        for weight, state in zip(eigvals[1:], eigvecs[1:]):
            total = total + weight * _qfunc_iterative_single(
                state, alpha_grid, prefactor, g
            )
        return total / jnp.pi

    alpha_grid, prefactor = _qfunc_coherent_grid(xvec, yvec, g)

    if psi.is_vec():
        psi = psi.to_ket()
        compute = _ket_qfunc
    else:
        compute = _dm_qfunc

    data = psi.data

    # One vmap per leading (batch) axis; only the state data is mapped.
    for _ in data.shape[:-2]:
        compute = vmap(
            compute,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return compute(data, alpha_grid, prefactor, g)

246 

247 

248def _qfunc_iterative_single( 

249 vector, 

250 grid, 

251 prefactor, 

252 g, 

253): 

254 r""" 

255 Get the Q function (without the :math:`\pi` scaling factor) of a single 

256 state vector, using the iterative algorithm which recomputes the powers of 

257 the coherent-state matrix. 

258 """ 

259 vector = vector.squeeze() 

260 ns = jnp.arange(vector.shape[-1]) 

261 out = jnp.polyval( 

262 (vector / jnp.sqrt(factorial(ns)))[::-1], 

263 grid, 

264 ) 

265 out *= prefactor 

266 return jnp.abs(out) ** 2 

267 

268 

269def _qfunc_coherent_grid(xvec, yvec, g): 

270 x, y = jnp.meshgrid(0.5 * g * xvec, 0.5 * g * yvec) 

271 grid = jnp.empty(x.shape, dtype=jnp.complex128) 

272 grid += x 

273 # We produce the adjoint of the coherent states to save an operation 

274 # later when computing dot products, hence the negative imaginary part. 

275 grid += -y * 1.0j 

276 prefactor = jnp.exp(-0.5 * (x * x + y * y)).astype(jnp.complex128) 

277 return grid, prefactor