Coverage for jaxquantum/devices/superconducting/ats.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

"""ATS device model (presumably the asymmetrically threaded SQUID).

Defines ``ATS``, a ``FluxDevice`` whose Hamiltonian is a harmonic part set by
``Ec``/``El`` plus flux-biased Josephson terms parameterized by ``Ej``,
``dEj``, ``Ej2``, ``phi_sum_ext`` and ``phi_delta_ext``.
"""

2 

3from flax import struct 

4from jax import config 

5 

6import jax.numpy as jnp 

7 

8from jaxquantum.devices.superconducting.flux_base import FluxDevice 

9from jaxquantum.core.operators import identity, destroy, create 

10from jaxquantum.core.qarray import cosm, sinm 

11 

# Enable 64-bit floats/complex numbers in JAX. This is JAX's global
# configuration and runs at import time, so it affects every JAX computation in
# the process, not only the operators built in this module.
config.update("jax_enable_x64", True)

13 

14 

@struct.dataclass
class ATS(FluxDevice):
    """ATS device (presumably an asymmetrically threaded SQUID).

    The Hamiltonian is a harmonic part with frequency ``sqrt(8 * El * Ec)``
    plus flux-biased Josephson terms; see ``get_H_nonlinear_static`` and
    ``potential`` for the exact expressions.

    ``self.params`` must provide the keys ``"Ec"``, ``"El"``, ``"Ej"``,
    ``"dEj"``, ``"Ej2"``, ``"phi_sum_ext"`` and ``"phi_delta_ext"``.
    """

    def common_ops(self):
        """Return the basic operators of the linear (harmonic-oscillator) basis.

        All operators are built with dimension ``self.N_pre_diag``.

        Returns:
            dict with keys ``"id"``, ``"a"``, ``"a_dag"``, ``"phi"`` and
            ``"n"``, where ``phi = phi_zpf() * (a + a_dag)`` and
            ``n = 1j * n_zpf() * (a_dag - a)``.
        """
        ops = {}

        N = self.N_pre_diag
        ops["id"] = identity(N)
        ops["a"] = destroy(N)
        ops["a_dag"] = create(N)
        ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
        ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])
        return ops

    def phi_zpf(self):
        """Return the phase zero-point fluctuation ``(2 * Ec / El) ** 0.25``.

        Used as the prefactor of ``a + a_dag`` in the ``"phi"`` operator.
        Note that ``phi_zpf() * n_zpf() == 1/2``.
        """
        return (2 * self.params["Ec"] / self.params["El"]) ** (0.25)

    def n_zpf(self):
        """Return the charge zero-point fluctuation ``(El / (32 * Ec)) ** 0.25``.

        Used as the prefactor of ``1j * (a_dag - a)`` in the ``"n"`` operator.
        """
        return (self.params["El"] / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Return ``sqrt(8 * El * Ec)``, the frequency of the harmonic part."""
        return jnp.sqrt(8 * self.params["El"] * self.params["Ec"])

    def get_H_linear(self):
        """Return the harmonic part ``w * (a_dag @ a + 0.5 * id)``.

        ``w`` is ``get_linear_ω()``; the operators come from
        ``self.linear_ops``.
        """
        w = self.get_linear_ω()
        return w * (
            self.linear_ops["a_dag"] @ self.linear_ops["a"]
            + 0.5 * self.linear_ops["id"]
        )

    @staticmethod
    def get_H_nonlinear_static(phi_op, Ej, dEj, Ej2, phi_sum, phi_delta):
        """Return the nonlinear (Josephson) part of the Hamiltonian.

        Builds, without reading any instance state::

            H_nl = - 2 * Ej  * cos(phi_op + 2π phi_delta)   * cos(2π phi_sum)
                   + 2 * dEj * sin(phi_op + 2π phi_delta)   * sin(2π phi_sum)
                   + 2 * Ej2 * cos(2 phi_op + 4π phi_delta) * cos(4π phi_sum)

        The shifted cos/sin are expanded with angle-addition identities. As a
        result, the matrix functions ``cosm``/``sinm`` are evaluated only
        once, on ``phi_op`` itself. The external fluxes enter only through
        scalar ``jnp.cos``/``jnp.sin`` factors.

        Args:
            phi_op: Phase operator the Josephson terms are evaluated on.
            Ej: Prefactor of the first-harmonic cosine term.
            dEj: Prefactor of the first-harmonic sine term.
            Ej2: Prefactor of the second-harmonic cosine term.
            phi_sum: External "sum" flux, in units of 2π (it is multiplied by
                ``2 * jnp.pi`` before use).
            phi_delta: External "delta" flux, in units of 2π.

        Returns:
            The operator ``H_nl`` above, built from ``cosm(phi_op)`` and
            ``sinm(phi_op)``.
        """
        cos_phi_op = cosm(phi_op)
        sin_phi_op = sinm(phi_op)

        # cos(phi_op) and sin(phi_op) are functions of the same operator and
        # therefore commute, so the scalar double-angle identities hold.
        # `*` and `@` share precedence and associate left: (2 * cos) @ sin.
        cos_2phi_op = cos_phi_op @ cos_phi_op - sin_phi_op @ sin_phi_op
        sin_2phi_op = 2 * cos_phi_op @ sin_phi_op

        # cos / sin of (phi_op + 2π phi_delta) via angle addition.
        cos_shifted = (
            cos_phi_op * jnp.cos(2 * jnp.pi * phi_delta)
            - sin_phi_op * jnp.sin(2 * jnp.pi * phi_delta)
        )
        sin_shifted = (
            sin_phi_op * jnp.cos(2 * jnp.pi * phi_delta)
            + cos_phi_op * jnp.sin(2 * jnp.pi * phi_delta)
        )
        # cos of (2 phi_op + 4π phi_delta) via angle addition.
        cos_2shifted = (
            cos_2phi_op * jnp.cos(4 * jnp.pi * phi_delta)
            - sin_2phi_op * jnp.sin(4 * jnp.pi * phi_delta)
        )

        H_nl_Ej = -2 * Ej * cos_shifted * jnp.cos(2 * jnp.pi * phi_sum)
        H_nl_dEj = 2 * dEj * sin_shifted * jnp.sin(2 * jnp.pi * phi_sum)
        H_nl_Ej2 = 2 * Ej2 * cos_2shifted * jnp.cos(4 * jnp.pi * phi_sum)

        return H_nl_Ej + H_nl_dEj + H_nl_Ej2

    def get_H_nonlinear(self, phi_op):
        """Return the Josephson terms for this device's parameters.

        Forwards ``Ej``, ``dEj``, ``Ej2``, ``phi_sum_ext`` and
        ``phi_delta_ext`` from ``self.params`` to ``get_H_nonlinear_static``.
        """
        Ej = self.params["Ej"]
        dEj = self.params["dEj"]
        Ej2 = self.params["Ej2"]

        phi_sum = self.params["phi_sum_ext"]
        phi_delta = self.params["phi_delta_ext"]

        return ATS.get_H_nonlinear_static(phi_op, Ej, dEj, Ej2, phi_sum, phi_delta)

    def get_H_full(self):
        """Return the full Hamiltonian in the linear basis.

        It is ``get_H_linear()`` plus the Josephson terms evaluated on
        ``self.linear_ops["phi"]``.
        """
        phi_b = self.linear_ops["phi"]
        H_nl = self.get_H_nonlinear(phi_b)
        H = self.get_H_linear() + H_nl
        return H

    def potential(self, phi):
        """Return the potential energy at ``phi``.

        ``phi`` is expressed in units of 2π: every term multiplies it by
        ``2 * jnp.pi``. The result is the inductive energy
        ``0.5 * El * (2π phi) ** 2`` plus the same three Josephson terms as
        ``get_H_nonlinear_static``, with the operator ``phi_op`` replaced by
        the number ``2π phi``. Only elementwise ``jnp`` operations are used,
        so ``phi`` may be a scalar or an array.
        """
        El = self.params["El"]
        Ej = self.params["Ej"]
        dEj = self.params["dEj"]
        Ej2 = self.params["Ej2"]
        phi_delta_ext = self.params["phi_delta_ext"]
        phi_sum_ext = self.params["phi_sum_ext"]

        V = 0.5 * El * (2 * jnp.pi * phi) ** 2
        V += (
            -2
            * Ej
            * jnp.cos(2 * jnp.pi * (phi + phi_delta_ext))
            * jnp.cos(2 * jnp.pi * phi_sum_ext)
        )
        V += (
            2
            * dEj
            * jnp.sin(2 * jnp.pi * (phi + phi_delta_ext))
            * jnp.sin(2 * jnp.pi * phi_sum_ext)
        )
        V += (
            2
            * Ej2
            * jnp.cos(4 * jnp.pi * (phi + phi_delta_ext))
            * jnp.cos(4 * jnp.pi * phi_sum_ext)
        )

        return V