Coverage for jaxquantum/circuits/library/oscillator.py: 0%

18 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import displace, basis, destroy, num 

4from jaxquantum.circuits.gates import Gate 

5from jax.scipy.special import factorial 

6import jax.numpy as jnp 

7from jaxquantum import Qarray 

8 

9 

def D(N, alpha):
    """Build a single-mode displacement gate.

    Args:
        N: Truncation dimension of the oscillator mode.
        alpha: Displacement amplitude; stored in the gate's params under
            the key ``"alpha"``.

    Returns:
        Gate: a one-mode gate named ``"D"`` whose unitary is generated by
        ``displace(N, params["alpha"])``.
    """

    def _make_unitary(p):
        # Unitary is regenerated from the params dict handed in by Gate.
        return displace(N, p["alpha"])

    gate_params = {"alpha": alpha}
    return Gate.create(
        N,
        name="D",
        params=gate_params,
        gen_U=_make_unitary,
        num_modes=1,
    )

18 

19 

def CD(N, beta):
    """Build a qubit-controlled displacement gate on a (qubit, oscillator) pair.

    The qubit is the first subsystem (dimension 2) and the oscillator the
    second (dimension N). The generated unitary is

        |g><g| (x) displace(N, beta/2) + |e><e| (x) displace(N, -beta/2)

    where |g> = basis(2, 0) and |e> = basis(2, 1).

    Args:
        N: Truncation dimension of the oscillator mode.
        beta: Controlled-displacement amplitude; stored in the gate's params
            under the key ``"beta"``.

    Returns:
        Gate: a two-mode gate named ``"CD"`` with dims ``[2, N]``.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)

    # Qubit projectors selecting which displacement branch is applied.
    proj_g = ket_g @ ket_g.dag()
    proj_e = ket_e @ ket_e.dag()

    def _make_unitary(p):
        half_beta = p["beta"] / 2
        branch_g = proj_g ^ displace(N, half_beta)
        branch_e = proj_e ^ displace(N, -half_beta)
        return branch_g + branch_e

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=_make_unitary,
        num_modes=2,
    )

35 

36 

def _Kraus_Op(N, err_prob, l):
    """Return the Kraus operator for l-photon loss in an N-level oscillator.

    Computes

        K_l = sqrt(err_prob**l / l!) * (1 - err_prob)**(n / 2) @ a**l

    with n = num(N) and a = destroy(N).

    Args:
        N: Hilbert-space truncation of the oscillator.
        err_prob: Photon-loss probability; expected to lie in [0, 1].
        l: Number of photons lost (non-negative integer).

    Returns:
        Qarray: the N x N Kraus operator K_l.
    """
    # (1 - err_prob)**(n/2) is diagonal in the Fock basis, so build its
    # diagonal directly instead of exponentiating n * log(sqrt(1 - err_prob)).
    # The log form yields 0 * -inf = NaN in every zero entry of num(N) when
    # err_prob == 1 (poisoning the whole operator) and also runs a needless
    # dense matrix exponential. jnp.power gives 0**0 == 1 for the vacuum
    # entry, which is the correct err_prob -> 1 limit.
    fock_levels = jnp.arange(N)
    damping = Qarray.create(
        jnp.diag(jnp.power(jnp.sqrt(1 - err_prob), fock_levels))
    )
    prefactor = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    return prefactor * (damping @ destroy(N).powm(l))

45 

46 

def Amp_Damp(N, err_prob, max_l):
    """Build a single-mode photon-loss (amplitude damping) channel as a gate.

    The channel is represented by the Kraus operators
    ``_Kraus_Op(N, err_prob, l)`` for ``l = 0, 1, ..., max_l``, i.e. loss of
    up to ``max_l`` photons is modelled explicitly.

    Args:
        N: Hilbert-space truncation of the oscillator.
        err_prob: Photon-loss probability; stored in params as ``"err_prob"``.
        max_l: Largest number of lost photons to include; must be a concrete
            Python int because it determines how many Kraus operators exist.

    Returns:
        Gate: a one-mode gate named ``"Amp_Damp"`` carrying the Kraus map.
    """

    def kmap(params):
        # Build the Kraus operators from the params dict (as D and CD build
        # their unitaries) rather than from the closed-over argument, so the
        # generator is a genuine function of the gate parameters. max_l stays
        # a closure constant: range() needs a concrete int, and it fixes the
        # structure (number of operators), not a tunable value.
        return Qarray.from_list(
            [_Kraus_Op(N, params["err_prob"], l) for l in range(max_l + 1)]
        )

    return Gate.create(
        N,
        name="Amp_Damp",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )