Coverage for jaxquantum/core/qp_distributions.py: 95%
98 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1import jax.numpy as jnp
2from jax import vmap
3from jax.scipy.special import factorial
4import jax
def wigner(psi, xvec, yvec, method="clenshaw", g=2):
    """Wigner function for a state vector or density matrix at points
    `xvec + i * yvec`.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix. It must answer ``is_vec()`` /
        ``is_dm()`` and provide ``to_dm()``; the array stored in ``.data`` of
        the resulting density matrix is what gets evaluated. Every axis of
        that array in front of the last two is treated as a batch axis and
        mapped over with ``vmap``.

    xvec : array_like
        x-coordinates at which to calculate the Wigner function.

    yvec : array_like
        y-coordinates at which to calculate the Wigner function.

    method : string {'clenshaw', 'iterative', 'laguerre', 'fft'}, default: 'clenshaw'
        Only 'clenshaw' is currently supported. The names 'iterative',
        'laguerre' and 'fft' are recognised but raise
        ``NotImplementedError``.

    g : float, default: 2
        Scaling factor for `a = 0.5 * g * (x + iy)`, default `g = 2`.
        The value of `g` is related to the value of `hbar` in the commutation
        relation `[x, y] = i * hbar` via `hbar=2/g^2`.

    Returns
    -------

    W : array
        Values representing the Wigner function calculated over the specified
        range [xvec,yvec], with one leading axis per batch axis of the input.

    Raises
    ------

    TypeError
        If `psi` is neither a state vector nor a density matrix, or if
        `method` is not one of the recognised names.

    NotImplementedError
        If `method` is 'iterative', 'laguerre' or 'fft'.

    References
    ----------

    Ulf Leonhardt,
    Measuring the Quantum State of Light, (Cambridge University Press, 1997)

    """
    if not (psi.is_vec() or psi.is_dm()):
        raise TypeError("Input state is not a valid operator.")

    # Recognised method names that have no implementation here.
    if method in ("fft", "iterative", "laguerre"):
        raise NotImplementedError("Only the 'clenshaw' method is implemented.")

    if method != "clenshaw":
        # Same exception type as before; the message now names the method
        # that actually works instead of only the unimplemented ones.
        raise TypeError(
            "method must be one of 'clenshaw', 'iterative', 'laguerre', or "
            f"'fft' (only 'clenshaw' is implemented); got {method!r}."
        )

    rho = psi.to_dm().data

    # Wrap the single-matrix kernel in one vmap per leading batch axis; only
    # the density matrix is batched, the grid and g are shared.
    batched_wigner = _wigner_clenshaw
    for _ in rho.shape[:-2]:
        batched_wigner = vmap(
            batched_wigner,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return batched_wigner(rho, xvec, yvec, g)
def _wigner_clenshaw(rho, xvec, yvec, g=jnp.sqrt(2)):
    r"""
    Wigner function of a single density matrix via Clenshaw summation, a
    numerically stable and efficient iterative algorithm for evaluating
    polynomial series.

    The Wigner function is calculated as
    :math:`W = e^(-0.5*x^2)/pi * \sum_{L} c_L (2x)^L / \sqrt(L!)` where
    :math:`c_L = \sum_n \rho_{n,L+n} LL_n^L` where
    :math:`LL_n^L = (-1)^n \sqrt(L!n!/(L+n)!) LaguerreL[n,L,x]`

    Parameters
    ----------
    rho : array
        Square 2-D density matrix (indexed as ``rho[row, col]``).
    xvec, yvec : array_like
        Coordinates; the function is evaluated on ``jnp.meshgrid(xvec, yvec)``.
    g : float
        Scaling factor in ``A = 0.5 * g * (X + iY)``.

    Returns
    -------
    array
        Real-valued array with the shape of the meshgrid.

    Heavily inspired by Qutip and Dynamiqs
    https://github.com/dynamiqs/dynamiqs
    https://github.com/qutip/qutip
    """
    # Keep the dimension as a plain Python int. Wrapping it in a jnp reduction
    # turns it into a JAX value, which becomes a tracer under jit and can then
    # no longer be used as an array shape or as a static loop bound.
    M = rho.shape[0]
    X, Y = jnp.meshgrid(xvec, yvec)
    A = 0.5 * g * (X + 1.0j * Y)

    # B = |2A|^2
    B = jnp.abs(2 * A)
    B *= B

    # Seed of the recursion: the top-right corner element, read before the
    # off-diagonal doubling below.
    w0 = (2 * rho[0, -1]) * jnp.ones_like(A)

    # calculation of \sum_{L} c_L (2x)^L / \sqrt(L!)
    # using Horner's method

    # Weight off-diagonal entries by 2 and diagonal entries by 1.
    rho = rho * (2 * jnp.ones((M, M)) - jnp.diag(jnp.ones(M)))

    def loop(i: int, w: jax.Array) -> jax.Array:
        # fori_loop counts upwards; walk the diagonals from M - 2 down to 0.
        i = M - 2 - i
        w = w * (2 * A * (i + 1) ** (-0.5))
        return w + _wig_laguerre_val(i, B, rho, M)

    w = jax.lax.fori_loop(0, M - 1, loop, w0)

    return w.real * jnp.exp(-B * 0.5) * (g * g * 0.5 / jnp.pi)
def _extract_diag_element(rho: jnp.array, L: int, n: int):
    """
    Return the ``n``-th entry of the ``L``-th diagonal of ``rho``.

    ``L > 0`` selects a diagonal above the main one and ``L < 0`` one below
    it. A negative ``n`` counts from the end of that diagonal, which holds
    ``rho.shape[0] - |L|`` entries.
    Heavily inspired from https://github.com/dynamiqs/dynamiqs
    """
    dim = rho.shape[0]

    # Position that a negative n refers to, counted from the diagonal's end.
    from_end = dim - jnp.abs(L) - jnp.abs(n)
    n = jax.lax.select(n < 0, from_end, n)

    # The diagonal starts at row -L (below the main diagonal) or at
    # column L (above it).
    first_row = jnp.maximum(-L, 0)
    first_col = jnp.maximum(L, 0)
    return rho[first_row + n, first_col + n]
def _wig_laguerre_val(L, x, rho, N):
    r"""
    Evaluate the Laguerre series whose coefficients are the entries of the
    ``L``-th diagonal of ``rho``, at the points ``x``.

    The branch is chosen from ``N - L``: a single term, two terms, or the
    general backward recurrence for anything longer. All branches return an
    array shaped like ``x``.
    Implementation in Jax from https://github.com/dynamiqs/dynamiqs
    """

    def _single_term():
        return _extract_diag_element(rho, L, 0) * jnp.ones_like(x)

    def _two_terms():
        first = _extract_diag_element(rho, L, 0)
        second = _extract_diag_element(rho, L, 1)
        return (first - second * (L + 1 - x) * (L + 1) ** (-0.5)) * jnp.ones_like(x)

    def _general():
        # Start from the last two entries of the diagonal.
        acc0 = _extract_diag_element(rho, L, -2) * jnp.ones_like(x)
        acc1 = _extract_diag_element(rho, L, -1) * jnp.ones_like(x)

        def _step(j: int, carry: tuple[jax.Array, jax.Array]) -> tuple[
                jax.Array, jax.Array]:
            def _advance() -> tuple[jax.Array, jax.Array]:
                k = N + 1 - L - j
                prev0, prev1 = carry
                coeff = _extract_diag_element(rho, L, -j)
                next0 = coeff - prev1 * (k * (L + k) / ((L + k + 1) * (k + 1))) ** 0.5
                next1 = prev0 - prev1 * (L + 2 * k - x + 1) * (
                    (L + k + 1) * (k + 1)) ** -0.5
                return next0, next1

            def _unchanged() -> tuple[jax.Array, jax.Array]:
                return carry

            # The loop always runs to N + 1; iterations past the end of this
            # diagonal leave the carry untouched.
            return jax.lax.cond(j >= N + 1 - L, _unchanged, _advance)

        acc0, acc1 = jax.lax.fori_loop(3, N + 1, _step, (acc0, acc1))

        return acc0 - acc1 * (L + 1 - x) * (L + 1) ** (-0.5)

    def _two_or_more():
        return jax.lax.cond(N - L == 2, _two_terms, _general)

    return jax.lax.cond(N - L == 1, _single_term, _two_or_more)
def qfunc(psi, xvec, yvec, g=2):
    r"""
    Husimi-Q function of a given state vector or density matrix at phase-space
    points ``0.5 * g * (xvec + i*yvec)``.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix. This cannot have tensor-product
        structure. Axes of ``psi.data`` in front of the last two are treated
        as batch axes and mapped over with ``vmap``.

    xvec, yvec : array_like
        x- and y-coordinates at which to calculate the Husimi-Q function.

    g : float, default: 2
        Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
        related to the value of :math:`\hbar` in the commutation relation
        :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.

    Returns
    -------
    jnp.ndarray
        Values representing the Husimi-Q function calculated over the specified
        range ``[xvec, yvec]``.

    """
    alpha_grid, prefactor = _qfunc_coherent_grid(xvec, yvec, g)

    def _q_from_ket(state, alpha_grid, prefactor, g):
        # A pure state needs a single evaluation.
        return _qfunc_iterative_single(state, alpha_grid, prefactor, g) / jnp.pi

    def _q_from_dm(state, alpha_grid, prefactor, g):
        # Weighted sum of the Q functions of the eigenvectors, with the
        # eigenvalues as weights.
        weights, modes = jnp.linalg.eigh(state)
        modes = modes.T  # eigh returns eigenvectors as columns
        total = weights[0] * _qfunc_iterative_single(
            modes[0], alpha_grid, prefactor, g
        )
        for weight, mode in zip(weights[1:], modes[1:]):
            total = total + weight * _qfunc_iterative_single(
                mode, alpha_grid, prefactor, g
            )
        return total / jnp.pi

    if psi.is_vec():
        psi = psi.to_ket()
        compute = _q_from_ket
    else:
        compute = _q_from_dm

    data = psi.data

    # One vmap per leading batch axis; only the state is batched.
    for _ in data.shape[:-2]:
        compute = vmap(
            compute,
            in_axes=(0, None, None, None),
            out_axes=0,
        )
    return compute(data, alpha_grid, prefactor, g)
def _qfunc_iterative_single(
    vector,
    grid,
    prefactor,
    g,
):
    r"""
    Get the Q function (without the :math:`\pi` scaling factor) of a single
    state vector, using the iterative algorithm which recomputes the powers of
    the coherent-state matrix.

    Parameters
    ----------
    vector : array
        Amplitudes of one state vector; it is flattened to 1-D, so column,
        row and flat layouts are all accepted.
    grid : array
        Points at which the polynomial in the amplitudes is evaluated.
    prefactor : array
        Factor multiplied onto the polynomial values before taking the
        squared modulus.
    g : float
        Not used here; kept so the signature stays unchanged for callers.

    Returns
    -------
    array
        ``|prefactor * p(grid)|**2`` where ``p`` has coefficients
        ``vector[n] / sqrt(n!)`` for power ``n``.
    """
    # Flatten rather than squeeze: squeezing a length-1 state gives a 0-d
    # array, for which vector.shape[-1] below would raise an IndexError.
    vector = vector.reshape(-1)
    ns = jnp.arange(vector.shape[-1])
    out = jnp.polyval(
        # polyval expects the highest power first, hence the reversal.
        (vector / jnp.sqrt(factorial(ns)))[::-1],
        grid,
    )
    out *= prefactor
    return jnp.abs(out) ** 2
def _qfunc_coherent_grid(xvec, yvec, g):
    """
    Build the complex evaluation grid and the Gaussian prefactor shared by
    the Q-function evaluations.

    Parameters
    ----------
    xvec, yvec : array_like
        Coordinates; both are scaled by ``0.5 * g`` and expanded with
        ``jnp.meshgrid``.
    g : float
        Scaling factor applied to the coordinates.

    Returns
    -------
    grid : array
        Complex array ``x - i*y`` on the scaled meshgrid.
    prefactor : array
        ``exp(-0.5 * (x**2 + y**2))`` on the same meshgrid, in the same
        complex dtype as ``grid``.
    """
    x, y = jnp.meshgrid(0.5 * g * xvec, 0.5 * g * yvec)

    # Widest complex dtype available in the current precision mode: complex128
    # when x64 is enabled, complex64 otherwise. Asking for complex128
    # unconditionally makes JAX warn about truncation on every call when x64
    # is disabled, and yields the same complex64 result anyway.
    complex_dtype = jnp.result_type(jnp.complex128)

    # We produce the adjoint of the coherent states to save an operation
    # later when computing dot products, hence the negative imaginary part.
    grid = (x - 1.0j * y).astype(complex_dtype)
    prefactor = jnp.exp(-0.5 * (x * x + y * y)).astype(complex_dtype)
    return grid, prefactor