Coverage for jaxquantum/core/cfunctions.py: 100%
14 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1import jax.numpy as jnp
2from jax import vmap
3from jax.scipy.special import factorial
4import jaxquantum as jqt
def cf_wigner(psi, xvec, yvec):
    """Characteristic function of a state on the grid ``xvec + 1j * yvec``.

    For every grid point ``alpha = x + 1j * y`` this evaluates
    ``jqt.overlap(psi, D(alpha))``, where ``D(alpha) = jqt.displace(N, alpha)``
    is the displacement operator built with the dimension of the first
    subsystem of ``psi``. For a density matrix this is presumably
    ``Tr[rho D(alpha)]`` and for a ket ``<psi|D(alpha)|psi>`` -- TODO confirm
    against the definition of ``jqt.overlap``.

    Note: despite its name, this function does not compute the Wigner
    function; the name is kept for backward compatibility with callers.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix, optionally batched (its batch
        dimensions are read from ``psi.bdims``).
    xvec : array_like
        Real parts (x-coordinates) of the displacement amplitudes.
    yvec : array_like
        Imaginary parts (y-coordinates) of the displacement amplitudes.

    Returns
    -------
    cf : array
        Characteristic-function values. ``jnp.meshgrid`` uses ``"xy"``
        indexing, so for 1-D inputs the grid axes are
        ``(len(yvec), len(xvec))``; one leading axis is added per batch
        dimension of ``psi`` (assumes ``jqt.displace`` returns an operator
        batched over the shape of ``alpha`` -- TODO confirm).
    """
    # Only the first subsystem's dimension is used to build D(alpha).
    N = psi.dims[0][0]
    x, y = jnp.meshgrid(xvec, yvec)
    alpha = x + 1.0j * y
    displacement = jqt.displace(N, alpha)

    # Map the overlap over the two grid axes of `displacement`
    # (outer vmap: first grid axis, inner vmap: second), holding psi fixed.
    overlap_fn = vmap(
        vmap(jqt.overlap, in_axes=(None, 0)), in_axes=(None, 0)
    )
    # Then wrap once per batch dimension of psi, holding the grid fixed.
    for _ in psi.bdims:
        overlap_fn = vmap(overlap_fn, in_axes=(0, None))

    return overlap_fn(psi, displacement)