Coverage for jaxquantum/core/cfunctions.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1import jax.numpy as jnp 

2from jax import vmap 

3from jax.scipy.special import factorial 

4import jaxquantum as jqt 

5 

6 

def cf_wigner(psi, xvec, yvec):
    """Characteristic function of a state on the grid of points ``xvec + i * yvec``.

    For every grid point ``alpha = x + 1j * y`` this evaluates
    ``jqt.overlap(psi, D(alpha))`` with ``D(alpha) = jqt.displace(N, alpha)``,
    i.e. the (Wigner) characteristic function of ``psi``. Note that this is
    the characteristic function, not the Wigner function itself.

    Parameters
    ----------

    psi : Qarray
        A state vector or density matrix. Each of its batch dimensions
        (``psi.bdims``) becomes a leading axis of the output.

    xvec : array_like
        x-coordinates (real part of ``alpha``) at which to evaluate.

    yvec : array_like
        y-coordinates (imaginary part of ``alpha``) at which to evaluate.


    Returns
    -------

    cf : array
        Values of ``jqt.overlap(psi, D(alpha))`` over the grid. Because
        ``jnp.meshgrid`` uses ``'xy'`` indexing, the grid axes are ordered
        ``(len(yvec), len(xvec))``. They are preceded by one axis per entry
        of ``psi.bdims`` (assuming ``jqt.overlap`` returns a scalar per
        state/operator pair -- confirm).


    """
    # NOTE(review): only the first entry of dims[0] is used, so this
    # presumably assumes a single-mode ket or operator; confirm behavior for
    # multi-mode or bra-shaped inputs.
    N = psi.dims[0][0]
    x, y = jnp.meshgrid(xvec, yvec)
    alpha = x + 1.0j * y
    # Displacement operators for every grid point. Assumes jqt.displace
    # batches over alpha's (len(yvec), len(xvec)) shape -- confirm.
    displacement = jqt.displace(N, alpha)

    # Map the overlap over both grid axes of the batched displacement
    # operators while sharing the same psi (in_axes=None) across the grid.
    overlap_fn = vmap(vmap(jqt.overlap, in_axes=(None, 0)), in_axes=(None, 0))
    # Add one outer vmap per batch dimension of psi so that each batched state
    # is evaluated independently against the same displacement grid.
    for _ in psi.bdims:
        overlap_fn = vmap(overlap_fn, in_axes=(0, None))

    cf = overlap_fn(psi, displacement)
    return cf