Coverage for jaxquantum / core / cfunctions.py: 100%
13 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1import jax.numpy as jnp
2from jax import vmap
3import jaxquantum as jqt
def cf_wigner(psi, xvec, yvec):
    """Characteristic function of a state vector or density matrix at points
    `xvec + i * yvec`.

    For every grid point ``alpha = x + 1j * y`` this evaluates
    ``jqt.overlap(psi, jqt.displace(N, alpha))``, i.e. the overlap of the
    state with the displacement operator D(alpha). This is the (Wigner)
    characteristic function chi(alpha), assuming ``jqt.overlap`` returns
    <psi|D|psi> for kets and Tr[rho D] for density matrices.

    Parameters
    ----------
    psi : Qarray
        A state vector or density matrix. Only ``psi.dims[0][0]`` is used as
        the Hilbert-space dimension, so a single-mode state is assumed
        (TODO: confirm behavior for multi-mode states). Any batch dimensions
        (``psi.bdims``) are mapped over and become leading axes of the
        result.

    xvec : array_like
        x-coordinates (real part of alpha) at which to evaluate.

    yvec : array_like
        y-coordinates (imaginary part of alpha) at which to evaluate.

    Returns
    -------
    cf : array
        Characteristic-function values. The leading axes follow
        ``psi.bdims``, followed by the ``jnp.meshgrid(xvec, yvec)`` grid
        axes, which are ``(len(yvec), len(xvec))`` for the default 'xy'
        indexing. This assumes ``jqt.displace`` returns an operator batched
        over that grid and ``jqt.overlap`` returns one value per pair.
    """
    # Hilbert-space dimension taken from the first subsystem of psi.
    N = psi.dims[0][0]

    # meshgrid uses 'xy' indexing: rows follow yvec, columns follow xvec.
    x, y = jnp.meshgrid(xvec, yvec)
    alpha = x + 1.0j * y
    displacement = jqt.displace(N, alpha)

    # Map over the two grid axes of the batched displacement operator while
    # holding psi fixed.
    overlap_fn = vmap(vmap(jqt.overlap, in_axes=(None, 0)), in_axes=(None, 0))

    # Wrap once per batch dimension of psi (mapping psi's leading axis,
    # holding the displacement grid fixed) so batched states are supported.
    for _ in psi.bdims:
        overlap_fn = vmap(overlap_fn, in_axes=(0, None))

    cf = overlap_fn(psi, displacement)
    return cf