Coverage for jaxquantum/core/qp_distributions.py: 95%
99 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1import jax.numpy as jnp
2from jax import vmap, config
3from jax.scipy.special import factorial
4import jax
6config.update("jax_enable_x64", True)
def wigner(psi, xvec, yvec, method="clenshaw", g=2):
    """Wigner function for a state vector or density matrix at points
    `xvec + i * yvec`.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix. It must answer ``is_vec()`` /
        ``is_dm()`` and provide ``to_dm()``; any leading (batch) dimensions
        of the resulting ``.data`` array are mapped over with ``vmap``.

    xvec : array_like
        x-coordinates at which to calculate the Wigner function.

    yvec : array_like
        y-coordinates at which to calculate the Wigner function.

    method : string {'clenshaw', 'iterative', 'laguerre', 'fft'}, default: 'clenshaw'
        Only 'clenshaw' is currently supported. The names 'iterative',
        'laguerre' and 'fft' are recognised but raise
        ``NotImplementedError``.

    g : float, default: 2
        Scaling factor for `a = 0.5 * g * (x + iy)`, default `g = 2`.
        The value of `g` is related to the value of `hbar` in the commutation
        relation `[x, y] = i * hbar` via `hbar=2/g^2`.

    Returns
    -------

    W : array
        Values representing the Wigner function calculated over the specified
        range [xvec,yvec], with one leading axis per batch dimension of the
        input state.

    Raises
    ------

    TypeError
        If `psi` is neither a state vector nor a density matrix, or if
        `method` is not one of the recognised names.

    NotImplementedError
        If `method` is 'iterative', 'laguerre' or 'fft'.

    References
    ----------

    Ulf Leonhardt,
    Measuring the Quantum State of Light, (Cambridge University Press, 1997)

    """
    if not (psi.is_vec() or psi.is_dm()):
        raise TypeError("Input state is not a valid operator.")

    if method in ("fft", "iterative", "laguerre"):
        raise NotImplementedError("Only the 'clenshaw' method is implemented.")

    if method != "clenshaw":
        # TypeError (rather than ValueError) is kept so existing callers that
        # catch it keep working; only the message is corrected.
        raise TypeError(
            "method must be one of 'clenshaw', 'iterative', 'laguerre', or 'fft'."
        )

    rho = psi.to_dm().data

    # Wrap the single-matrix kernel in one vmap per leading batch dimension;
    # the coordinate vectors and g are shared across the batch.
    compute = _wigner_clenshaw
    for _ in rho.shape[:-2]:
        compute = vmap(compute, in_axes=(0, None, None, None), out_axes=0)
    return compute(rho, xvec, yvec, g)
def _wigner_clenshaw(rho, xvec, yvec, g):
    r"""
    Evaluate the Wigner function of one square density matrix on a grid,
    using Clenshaw summation - a numerically stable and efficient
    iterative algorithm to evaluate polynomial series.

    The Wigner function is calculated as
    :math:`W = e^(-0.5*x^2)/pi * \sum_{L} c_L (2x)^L / \sqrt(L!)` where
    :math:`c_L = \sum_n \rho_{n,L+n} LL_n^L` where
    :math:`LL_n^L = (-1)^n \sqrt(L!n!/(L+n)!) LaguerreL[n,L,x]`

    Heavily inspired by Qutip and Dynamiqs
    https://github.com/dynamiqs/dynamiqs
    https://github.com/qutip/qutip

    Parameters
    ----------
    rho : array
        Two-dimensional square density matrix.
    xvec, yvec : array
        Coordinates passed to ``jnp.meshgrid`` to build the evaluation grid.
    g : float
        Scaling factor in ``A = 0.5 * g * (X + iY)``.

    Returns
    -------
    array
        Real array with the shape of the ``meshgrid`` of `xvec` and `yvec`.
    """
    # Take the dimension straight from the (static) shape so it stays a plain
    # Python int: it is used below as an array shape and as a loop bound,
    # which must remain concrete when this function is traced.
    M = rho.shape[0]
    X, Y = jnp.meshgrid(xvec, yvec)
    A = 0.5 * g * (X + 1.0j * Y)
    B = jnp.abs(2 * A)
    B = B * B

    # Seed the Horner recursion with the highest diagonal (single element,
    # top-right corner), doubled like every other off-diagonal term.
    w0 = (2 * rho[0, -1]) * jnp.ones_like(A)

    # calculation of \sum_{L} c_L (2x)^L / \sqrt(L!)
    # using Horner's method

    # Off-diagonal entries are weighted by 2, the main diagonal by 1.
    rho = rho * (2 * jnp.ones((M, M)) - jnp.eye(M))

    def loop(i: int, w: jax.Array) -> jax.Array:
        # Walk the diagonals from L = M - 2 down to L = 0.
        i = M - 2 - i
        w = w * (2 * A * (i + 1) ** (-0.5))
        return w + _wig_laguerre_val(i, B, rho, M)

    w = jax.lax.fori_loop(0, M - 1, loop, w0)

    return w.real * jnp.exp(-B * 0.5) * (g * g * 0.5 / jnp.pi)
126def _extract_diag_element(rho: jnp.array, L: int, n:int):
127 """"
128 Extract element at index n from diagonal L of matrix rho.
129 Heavily inspired from https://github.com/dynamiqs/dynamiqs
130 """
131 N = rho.shape[0]
132 n = jax.lax.select(n < 0, N - jnp.abs(L) - jnp.abs(n), n)
133 row = jnp.maximum(-L, 0) + n
134 col = jnp.maximum(L, 0) + n
135 return rho[row, col]
def _wig_laguerre_val(L, x, rho, N):
    r"""
    Evaluate the Laguerre series built from diagonal ``L`` of ``rho`` at ``x``.

    The coefficients are read from ``rho`` with ``_extract_diag_element``.
    Three cases are distinguished by ``N - L``: equal to 1 (one coefficient),
    equal to 2 (two coefficients), and anything else (backward recurrence
    over the remaining coefficients).

    Implementation in Jax from https://github.com/dynamiqs/dynamiqs
    """

    def _single_term():
        return _extract_diag_element(rho, L, 0) * jnp.ones_like(x)

    def _two_terms():
        first = _extract_diag_element(rho, L, 0)
        second = _extract_diag_element(rho, L, 1)
        return (first - second * (L + 1 - x) * (L + 1) ** (-0.5)) * jnp.ones_like(x)

    def _many_terms():
        # Start from the last two coefficients on the diagonal.
        second_last = _extract_diag_element(rho, L, -2)
        last = _extract_diag_element(rho, L, -1)
        start = (second_last * jnp.ones_like(x), last * jnp.ones_like(x))

        def step(j: int, carry: tuple[jax.Array, jax.Array]) -> tuple[
                jax.Array, jax.Array]:
            def advance() -> tuple[jax.Array, jax.Array]:
                k = N + 1 - L - j
                prev, curr = carry
                coeff = _extract_diag_element(rho, L, -j)
                # Both updates use the values from the previous iteration.
                next_prev = coeff - curr * (
                    k * (L + k) / ((L + k + 1) * (k + 1))) ** 0.5
                next_curr = prev - curr * (L + 2 * k - x + 1) * (
                    (L + k + 1) * (k + 1)) ** -0.5
                return next_prev, next_curr

            # The loop has a fixed trip count; iterations past the end of
            # this diagonal leave the carry untouched.
            return jax.lax.cond(j >= N + 1 - L, lambda: carry, advance)

        acc0, acc1 = jax.lax.fori_loop(3, N + 1, step, start)

        return acc0 - acc1 * (L + 1 - x) * (L + 1) ** (-0.5)

    def _two_or_more():
        return jax.lax.cond(N - L == 2, _two_terms, _many_terms)

    return jax.lax.cond(N - L == 1, _single_term, _two_or_more)
def qfunc(psi, xvec, yvec, g=2):
    r"""
    Husimi-Q function of a given state vector or density matrix at phase-space
    points ``0.5 * g * (xvec + i*yvec)``.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix. This cannot have tensor-product
        structure. Leading (batch) dimensions of its ``.data`` array are
        mapped over with ``vmap``.

    xvec, yvec : array_like
        x- and y-coordinates at which to calculate the Husimi-Q function.

    g : float, default: 2
        Scaling factor for ``a = 0.5 * g * (x + iy)``.  The value of `g` is
        related to the value of :math:`\hbar` in the commutation relation
        :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.

    Returns
    -------
    jnp.ndarray
        Values representing the Husimi-Q function calculated over the specified
        range ``[xvec, yvec]``.

    """
    alpha_grid, prefactor = _qfunc_coherent_grid(xvec, yvec, g)

    if psi.is_vec():
        psi = psi.to_ket()

        def _single_q(data, grid, envelope, scale):
            # Pure state: one evaluation, normalised by pi.
            return _qfunc_iterative_single(data, grid, envelope, scale) / jnp.pi

    else:

        def _single_q(data, grid, envelope, scale):
            # Mixed state: eigenvalue-weighted sum over the eigenvectors.
            eigvals, eigvecs = jnp.linalg.eigh(data)
            eigvecs = eigvecs.T  # rows are eigenvectors after the transpose
            total = eigvals[0] * _qfunc_iterative_single(
                eigvecs[0], grid, envelope, scale
            )
            for weight, state in zip(eigvals[1:], eigvecs[1:]):
                total = total + weight * _qfunc_iterative_single(
                    state, grid, envelope, scale
                )
            return total / jnp.pi

    psi = psi.data

    # One vmap layer per leading batch dimension; grid, prefactor and g are
    # shared across the batch.
    compute = _single_q
    for _ in psi.shape[:-2]:
        compute = vmap(compute, in_axes=(0, None, None, None), out_axes=0)
    return compute(psi, alpha_grid, prefactor, g)
248def _qfunc_iterative_single(
249 vector,
250 grid,
251 prefactor,
252 g,
253):
254 r"""
255 Get the Q function (without the :math:`\pi` scaling factor) of a single
256 state vector, using the iterative algorithm which recomputes the powers of
257 the coherent-state matrix.
258 """
259 vector = vector.squeeze()
260 ns = jnp.arange(vector.shape[-1])
261 out = jnp.polyval(
262 (vector / jnp.sqrt(factorial(ns)))[::-1],
263 grid,
264 )
265 out *= prefactor
266 return jnp.abs(out) ** 2
269def _qfunc_coherent_grid(xvec, yvec, g):
270 x, y = jnp.meshgrid(0.5 * g * xvec, 0.5 * g * yvec)
271 grid = jnp.empty(x.shape, dtype=jnp.complex128)
272 grid += x
273 # We produce the adjoint of the coherent states to save an operation
274 # later when computing dot products, hence the negative imaginary part.
275 grid += -y * 1.0j
276 prefactor = jnp.exp(-0.5 * (x * x + y * y)).astype(jnp.complex128)
277 return grid, prefactor