Coverage for jaxquantum/core/qp_distributions.py: 95%

98 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1import jax.numpy as jnp 

2from jax import vmap 

3from jax.scipy.special import factorial 

4import jax 

5 

def wigner(psi, xvec, yvec, method="clenshaw", g=2):
    """Wigner function for a state vector or density matrix at points
    `xvec + i * yvec`.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix. It is converted to a density
        matrix, and every leading (batch) dimension of the underlying data
        is mapped over with `vmap`.

    xvec : array_like
        x-coordinates at which to calculate the Wigner function.

    yvec : array_like
        y-coordinates at which to calculate the Wigner function.

    method : string {'clenshaw', 'iterative', 'laguerre', 'fft'}, default: 'clenshaw'
        Only 'clenshaw' is currently supported. The other three names are
        recognised but raise `NotImplementedError`.

    g : float, default: 2
        Scaling factor for `a = 0.5 * g * (x + iy)`, default `g = 2`.
        The value of `g` is related to the value of `hbar` in the commutation
        relation `[x, y] = i * hbar` via `hbar=2/g^2`.

    Returns
    -------

    W : array
        Values representing the Wigner function calculated over the specified
        range [xvec,yvec], with one leading axis per batch dimension of the
        input state.

    Raises
    ------

    TypeError
        If `psi` is neither a state vector nor a density matrix, or if
        `method` is not one of the recognised names.

    NotImplementedError
        If `method` is 'iterative', 'laguerre' or 'fft'.

    References
    ----------

    Ulf Leonhardt,
    Measuring the Quantum State of Light, (Cambridge University Press, 1997)

    """
    if not (psi.is_vec() or psi.is_dm()):
        raise TypeError("Input state is not a valid operator.")

    if method in ("fft", "iterative", "laguerre"):
        raise NotImplementedError("Only the 'clenshaw' method is implemented.")

    if method != "clenshaw":
        # TypeError is kept (rather than ValueError) so existing callers that
        # catch it keep working; only the message is corrected.
        raise TypeError(
            "method must be one of 'clenshaw', 'iterative', 'laguerre', or 'fft'."
        )

    rho = psi.to_dm().data

    # Wrap the single-matrix kernel in one vmap per leading batch dimension;
    # the grid and scaling arguments are shared across the batch.
    wigner_fn = _wigner_clenshaw
    for _ in rho.shape[:-2]:
        wigner_fn = vmap(
            wigner_fn,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return wigner_fn(rho, xvec, yvec, g)

86 

87 

def _wigner_clenshaw(rho, xvec, yvec, g=jnp.sqrt(2)):
    r"""
    Wigner function of a single (unbatched) density matrix.

    Using Clenshaw summation - numerically stable and efficient
    iterative algorithm to evaluate polynomial series.

    The Wigner function is calculated as
    :math:`W = e^(-0.5*x^2)/pi * \sum_{L} c_L (2x)^L / \sqrt(L!)` where
    :math:`c_L = \sum_n \rho_{n,L+n} LL_n^L` where
    :math:`LL_n^L = (-1)^n \sqrt(L!n!/(L+n)!) LaguerreL[n,L,x]`
    Heavily inspired by Qutip and Dynamiqs
    https://github.com/dynamiqs/dynamiqs
    https://github.com/qutip/qutip

    Parameters
    ----------
    rho : array
        Square density matrix; its first dimension is used as the size.
    xvec, yvec : array_like
        Coordinates combined with `jnp.meshgrid` to form the evaluation grid.
    g : float
        Scaling factor for `a = 0.5 * g * (x + iy)`.

    Returns
    -------
    array
        Real array with the shape of the meshgrid built from `xvec`, `yvec`.
    """
    # Keep the dimension as a static Python int: it is used as an array shape
    # and as loop bounds, which must not be traced values.
    M = rho.shape[0]
    X, Y = jnp.meshgrid(xvec, yvec)
    A = 0.5 * g * (X + 1.0j * Y)
    B = jnp.abs(2 * A)
    B = B * B  # |2A|^2

    # Seed of the Horner recursion; read before `rho` is rescaled below.
    w0 = (2 * rho[0, -1]) * jnp.ones_like(A)

    # calculation of \sum_{L} c_L (2x)^L / \sqrt(L!)
    # using Horner's method

    # Double every off-diagonal element and leave the diagonal unchanged.
    rho = rho * (2 * jnp.ones((M, M)) - jnp.diag(jnp.ones(M)))

    def loop(i: int, w: jax.Array) -> jax.Array:
        # Walk the diagonals from L = M - 2 down to L = 0.
        i = M - 2 - i
        w = w * (2 * A * (i + 1) ** (-0.5))
        return w + _wig_laguerre_val(i, B, rho, M)

    w = jax.lax.fori_loop(0, M - 1, loop, w0)

    return w.real * jnp.exp(-B * 0.5) * (g * g * 0.5 / jnp.pi)

123 

124def _extract_diag_element(rho: jnp.array, L: int, n:int): 

125 """" 

126 Extract element at index n from diagonal L of matrix rho. 

127 Heavily inspired from https://github.com/dynamiqs/dynamiqs 

128 """ 

129 N = rho.shape[0] 

130 n = jax.lax.select(n < 0, N - jnp.abs(L) - jnp.abs(n), n) 

131 row = jnp.maximum(-L, 0) + n 

132 col = jnp.maximum(L, 0) + n 

133 return rho[row, col] 

134 

def _wig_laguerre_val(L, x, rho, N):
    r"""
    Evaluate the Laguerre-type series built from diagonal `L` of `rho` at the
    points `x`.

    Implementation in Jax from https://github.com/dynamiqs/dynamiqs

    The number of coefficients on the diagonal is `N - L`. One and two
    coefficients are handled in closed form; longer diagonals are summed with
    a backward two-term recurrence. All branches go through `jax.lax.cond` /
    `jax.lax.fori_loop`, so `L` may be a traced value.

    Parameters
    ----------
    L : int
        Diagonal offset whose elements are used as series coefficients.
    x : array
        Points at which the series is evaluated.
    rho : array
        Matrix the coefficients are read from.
    N : int
        Size of `rho`.

    Returns
    -------
    array
        Series value at every point of `x`.
    """
    unit = jnp.ones_like(x)

    def single_coefficient():
        return _extract_diag_element(rho, L, 0) * unit

    def two_coefficients():
        first = _extract_diag_element(rho, L, 0)
        second = _extract_diag_element(rho, L, 1)
        return (first - second * (L + 1 - x) * (L + 1) ** (-0.5)) * unit

    def many_coefficients():
        # Start the recurrence from the last two entries of the diagonal.
        acc0 = _extract_diag_element(rho, L, -2) * unit
        acc1 = _extract_diag_element(rho, L, -1) * unit

        def step(j: int, carry: tuple[jax.Array, jax.Array]) -> tuple[
                jax.Array, jax.Array]:
            def advance() -> tuple[jax.Array, jax.Array]:
                k = N + 1 - L - j
                prev0, prev1 = carry
                coeff = _extract_diag_element(rho, L, -j)
                next0 = coeff - prev1 * (
                    k * (L + k) / ((L + k + 1) * (k + 1))) ** 0.5
                next1 = prev0 - prev1 * (L + 2 * k - x + 1) * (
                    (L + k + 1) * (k + 1)) ** -0.5
                return next0, next1

            # The loop always runs up to N; iterations past the end of this
            # diagonal leave the carry untouched.
            return jax.lax.cond(j >= N + 1 - L, lambda: carry, advance)

        acc0, acc1 = jax.lax.fori_loop(3, N + 1, step, (acc0, acc1))

        return acc0 - acc1 * (L + 1 - x) * (L + 1) ** (-0.5)

    def two_or_more():
        return jax.lax.cond(N - L == 2, two_coefficients, many_coefficients)

    return jax.lax.cond(N - L == 1, single_coefficient, two_or_more)

179 

180 

def qfunc(psi, xvec, yvec, g=2):
    r"""
    Husimi-Q function of a given state vector or density matrix at phase-space
    points ``0.5 * g * (xvec + i*yvec)``.

    A state vector is converted to a ket and evaluated directly. Any other
    input is treated as a density matrix: it is eigen-decomposed and the
    Q functions of its eigenvectors are summed, weighted by the eigenvalues.
    Leading (batch) dimensions of the underlying data are mapped over with
    ``vmap``.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix. This cannot have tensor-product
        structure.

    xvec, yvec : array_like
        x- and y-coordinates at which to calculate the Husimi-Q function.

    g : float, default: 2
        Scaling factor for ``a = 0.5 * g * (x + iy)``.  The value of `g` is
        related to the value of :math:`\hbar` in the commutation relation
        :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.

    Returns
    -------
    jnp.ndarray
        Values representing the Husimi-Q function calculated over the specified
        range ``[xvec, yvec]``.

    """
    alpha_grid, prefactor = _qfunc_coherent_grid(xvec, yvec, g)

    def _ket_qfunc(ket, alpha_grid, prefactor, g):
        single = _qfunc_iterative_single(ket, alpha_grid, prefactor, g)
        return single / jnp.pi

    def _dm_qfunc(dm, alpha_grid, prefactor, g):
        weights, columns = jnp.linalg.eigh(dm)
        # eigh returns eigenvectors as columns; transpose to index them by row.
        eigvecs = columns.T
        total = weights[0] * _qfunc_iterative_single(
            eigvecs[0], alpha_grid, prefactor, g
        )
        for k in range(1, eigvecs.shape[0]):
            total = total + weights[k] * _qfunc_iterative_single(
                eigvecs[k], alpha_grid, prefactor, g
            )
        return total / jnp.pi

    if psi.is_vec():
        psi = psi.to_ket()
        compute = _ket_qfunc
    else:
        compute = _dm_qfunc

    data = psi.data

    # One vmap layer per leading batch dimension; the grid is shared.
    for _ in data.shape[:-2]:
        compute = vmap(
            compute,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return compute(data, alpha_grid, prefactor, g)

244 

245 

246def _qfunc_iterative_single( 

247 vector, 

248 grid, 

249 prefactor, 

250 g, 

251): 

252 r""" 

253 Get the Q function (without the :math:`\pi` scaling factor) of a single 

254 state vector, using the iterative algorithm which recomputes the powers of 

255 the coherent-state matrix. 

256 """ 

257 vector = vector.squeeze() 

258 ns = jnp.arange(vector.shape[-1]) 

259 out = jnp.polyval( 

260 (vector / jnp.sqrt(factorial(ns)))[::-1], 

261 grid, 

262 ) 

263 out *= prefactor 

264 return jnp.abs(out) ** 2 

265 

266 

267def _qfunc_coherent_grid(xvec, yvec, g): 

268 x, y = jnp.meshgrid(0.5 * g * xvec, 0.5 * g * yvec) 

269 grid = jnp.empty(x.shape, dtype=jnp.complex128) 

270 grid += x 

271 # We produce the adjoint of the coherent states to save an operation 

272 # later when computing dot products, hence the negative imaginary part. 

273 grid += -y * 1.0j 

274 prefactor = jnp.exp(-0.5 * (x * x + y * y)).astype(jnp.complex128) 

275 return grid, prefactor