Coverage for jaxquantum/circuits/library/oscillator.py: 73%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Oscillator gates.""" 

2 

3from jaxquantum.core.operators import displace, basis, destroy, create, num 

4from jaxquantum.circuits.gates import Gate 

5from jax.scipy.special import factorial 

6import jax.numpy as jnp 

7from jaxquantum import Qarray 

8 

9 

def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate on a single oscillator mode.

    Args:
        N: Hilbert-space dimension (Fock truncation) of the mode.
        alpha: Complex displacement amplitude, stored as ``params["alpha"]``.
        ts: Optional time grid. When given, a constant drive Hamiltonian is
            generated that acts over the duration ``ts[-1] - ts[0]``.
        c_ops: Optional collapse operators returned by the gate's
            ``gen_c_ops``; an empty ``Qarray`` list is used when ``None``.

    Returns:
        A ``Gate`` named ``"D"`` acting on one mode, whose unitary is
        ``displace(N, params["alpha"])``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # Constant drive H = conj(amp) a + amp a^dag with amp = i*alpha/dt.
            # Assuming evolution exp(-i H dt), this gives
            # exp(alpha a^dag - conj(alpha) a), i.e. the displacement by alpha.
            # Read alpha from params (like gen_U) so both generators agree.
            amp = 1j * params["alpha"] / delta_t
            H = jnp.conj(amp) * a + amp * a.dag()
            return lambda t: H

        gen_Ht = _gen_Ht

    def _gen_c_ops(params):
        return Qarray.from_list([]) if c_ops is None else c_ops

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=lambda params: displace(N, params["alpha"]),
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=_gen_c_ops,
        num_modes=1,
    )

29 

30 

def CD(N, beta, ts=None):
    """Conditional displacement between a qubit and an oscillator.

    The qubit (dimension 2) comes first, the oscillator (dimension ``N``)
    second. The unitary displaces the oscillator by ``+beta/2`` when the
    qubit is in ``|g> = basis(2, 0)`` and by ``-beta/2`` when it is in
    ``|e> = basis(2, 1)``.

    Args:
        N: Hilbert-space dimension (Fock truncation) of the oscillator.
        beta: Complex conditional-displacement amplitude, stored as
            ``params["beta"]``.
        ts: Optional time grid. When given, a constant drive Hamiltonian is
            generated that acts over the duration ``ts[-1] - ts[0]``.

    Returns:
        A two-mode ``Gate`` named ``"CD"`` with dims ``[2, N]``.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    # Qubit projectors |g><g| and |e><e|.
    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _drive(z):
            # Oscillator drive conj(z) a + z a^dag.
            return jnp.conj(z) * a + z * a.dag()

        def _gen_Ht(params):
            # Half amplitude per branch: +beta/2 on |g>, -beta/2 on |e>,
            # mirroring gen_U below.
            amp = 1j * params["beta"] / delta_t / 2
            # Each tensor product must be parenthesized: in Python ``^``
            # binds looser than ``+``, so ``gg ^ X + ee ^ Y`` would parse as
            # ``gg ^ (X + ee) ^ Y``.
            H = (gg ^ _drive(amp)) + (ee ^ _drive(-amp))
            return lambda t: H

        gen_Ht = _gen_Ht

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

58 

59 

def _Kraus_Op(N, err_prob, l):
    """Return the Kraus operator for ``l``-photon loss.

    Computes ``K_l = sqrt(p**l / l!) * (1 - p)**(n/2) * a**l`` in an
    ``N``-dimensional Fock space, where ``p = err_prob``, ``n = num(N)`` and
    ``a = destroy(N)``.

    Args:
        N: Hilbert-space dimension (Fock truncation).
        err_prob: Photon-loss probability ``p``.
        l: Number of photons lost (non-negative integer).

    Returns:
        The ``N x N`` Kraus operator ``K_l`` as a ``Qarray``.

    Note:
        The damping factor is built as ``exp(n * log(sqrt(1 - p)))``. For
        ``err_prob == 1`` the log is ``-inf``, and ``0 * -inf`` presumably
        produces NaNs, so callers should keep ``err_prob < 1``.
    """
    prefactor = jnp.sqrt(jnp.power(err_prob, l) / factorial(l))
    # (1 - p)**(n/2) as the exponential of a diagonal matrix.
    damping = (num(N) * jnp.log(jnp.sqrt(1 - err_prob))).expm()
    # Operator product: damping @ a^l.
    return prefactor * (damping @ destroy(N).powm(l))

68 

69 

def Amp_Damp(N, err_prob, max_l):
    """Amplitude-damping (photon-loss) channel on one oscillator mode.

    The channel is given as a Kraus map made of ``_Kraus_Op(N, p, l)`` for
    ``l = 0 .. max_l``. Because ``a**l`` vanishes for ``l >= N`` in an
    ``N``-dimensional space, the truncation is exact when
    ``max_l >= N - 1``; smaller ``max_l`` drops higher-loss terms.

    Args:
        N: Hilbert-space dimension (Fock truncation).
        err_prob: Photon-loss probability, stored as ``params["err_prob"]``.
        max_l: Largest number of lost photons kept, stored as
            ``params["max_l"]``.

    Returns:
        A one-mode ``Gate`` named ``"Amp_Damp"`` with a Kraus-map generator.
    """

    def kmap(params):
        # err_prob is read from params, matching how the other gates'
        # generators use their params. max_l stays the Python int argument
        # because it sets how many operators range() builds.
        p = params["err_prob"]
        return Qarray.from_list(
            [_Kraus_Op(N, p, n_lost) for n_lost in range(max_l + 1)]
        )

    return Gate.create(
        N,
        name="Amp_Damp",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

81 

def selfKerr(N, K):
    """Self-Kerr unitary on one oscillator mode.

    Returns the gate ``U = exp(-i * K/2 * a^dag a^dag a a)``. No time
    variable appears, so ``K`` presumably already includes the evolution
    time (confirm with callers).

    Args:
        N: Hilbert-space dimension (Fock truncation).
        K: Kerr strength, stored as ``params["Kerr"]``.

    Returns:
        A one-mode ``Gate`` named ``"selfKerr"``.
    """
    a = destroy(N)
    a_dag = a.dag()
    # The operator a^dag a^dag a a does not depend on params; build it once.
    kerr_op = a_dag @ a_dag @ a @ a

    def gen_U(params):
        # Read K from params, consistent with the other gates' gen_U.
        return (-1.0j * params["Kerr"] / 2 * kerr_op).expm()

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        gen_U=gen_U,
        num_modes=1,
    )