Coverage for jaxquantum / core / cfunctions.py: 100%

13 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1import jax.numpy as jnp 

2from jax import vmap 

3import jaxquantum as jqt 

4 

5 

def cf_wigner(psi, xvec, yvec):
    """Characteristic function of a state on a grid of displacements.

    Evaluates ``jqt.overlap(psi, D(alpha))`` for every point
    ``alpha = x + 1j * y`` of the grid spanned by ``xvec`` and ``yvec``,
    where ``D(alpha) = jqt.displace(N, alpha)`` is the displacement
    operator. This is the Wigner (symmetric-ordered) characteristic
    function; it is *not* the Wigner function itself.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix, optionally batched. Its first
        subsystem dimension ``psi.dims[0][0]`` sets the size of the
        displacement operators, so a single-mode state is expected.

    xvec : array_like
        Real parts of the displacement amplitudes ``alpha``.

    yvec : array_like
        Imaginary parts of the displacement amplitudes ``alpha``.


    Returns
    -------

    cf : array
        Complex characteristic-function values. The leading axes follow
        the batch dimensions of ``psi`` (one per entry of ``psi.bdims``),
        followed by ``(len(yvec), len(xvec))`` for 1-D ``xvec``/``yvec``
        (``jnp.meshgrid`` default ``'xy'`` indexing).
    """
    # NOTE(review): only the first subsystem's dimension is used, so a
    # multi-mode psi would not match the N-dimensional displacements.
    N = psi.dims[0][0]

    # 'xy' indexing: x and y have shape (len(yvec), len(xvec)).
    x, y = jnp.meshgrid(xvec, yvec)
    alpha = x + 1.0j * y

    # Batched displacement operators, one per grid point of alpha.
    displacement = jqt.displace(N, alpha)

    # Map over the two grid axes of `displacement` (outer vmap: rows / y,
    # inner vmap: columns / x) while sharing the same psi for all of them.
    overlap_fn = vmap(vmap(jqt.overlap, in_axes=(None, 0)), in_axes=(None, 0))

    # Add one more outer vmap per batch dimension of psi. Each wrapper
    # maps over the leading batch axis of psi and reuses the full
    # displacement grid, so psi's batch axes lead the output.
    for _ in psi.bdims:
        overlap_fn = vmap(overlap_fn, in_axes=(0, None))

    return overlap_fn(psi, displacement)