Coverage for jaxquantum/devices/superconducting/transmon.py: 0%

128 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Transmon.""" 

2 

3from flax import struct 

4from jax import config 

5 

6import jax.numpy as jnp 

7 

8from jaxquantum.devices.base.base import BasisTypes, HamiltonianTypes 

9from jaxquantum.devices.superconducting.flux_base import FluxDevice 

10from jaxquantum.core.operators import identity, destroy, create 

11from jaxquantum.core.conversions import jnp2jqt 

12 

# Enable JAX 64-bit (double-precision) mode as soon as this module is imported.
# NOTE(review): this is a JAX-wide configuration flag, so importing this module
# presumably switches every other JAX computation in the process to x64 as well.
config.update("jax_enable_x64", True)

14 

15 

@struct.dataclass
class Transmon(FluxDevice):
    """
    Transmon Device.

    Supported Hamiltonian / basis combinations (enforced by ``param_validation``):

    * ``HamiltonianTypes.linear``: harmonic approximation ``ω a†a`` with
      ``ω = sqrt(8 Ec Ej)``; Fock basis only.
    * ``HamiltonianTypes.truncated``: the linear term plus the 4th- and
      6th-order terms of the expansion of ``-Ej cos(φ)``; Fock basis only.
    * ``HamiltonianTypes.full``: ``H_charge - Ej cos(φ)``, written in the
      Cooper-pair charge basis or one of the single-electron charge bases.

    Parameters read from ``self.params``: ``"Ej"``, ``"Ec"`` and, for the
    charge-type bases, ``"ng"`` (gate offset charge, defaulted to 0.0 by
    ``param_validation`` for the full Hamiltonian).
    """

    DEFAULT_BASIS = BasisTypes.charge
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate a Hamiltonian / basis combination. This can be overridden by subclasses.

        Args:
            N: Truncated dimension (not checked here).
            N_pre_diag: Basis dimension before diagonalization.
            params: Device parameter dict. For the full Hamiltonian, ``"ng"`` is
                inserted with value 0.0 when absent (the dict is mutated in place).
            hamiltonian: A ``HamiltonianTypes`` member.
            basis: A ``BasisTypes`` member.

        Raises:
            AssertionError: If the basis does not suit the Hamiltonian, or if the
                parity of ``N_pre_diag`` does not suit the chosen charge basis.
        """
        if hamiltonian == HamiltonianTypes.linear:
            assert basis == BasisTypes.fock, "Linear Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.truncated:
            assert basis == BasisTypes.fock, "Truncated Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.full:
            charge_basis_types = [
                BasisTypes.charge,
                BasisTypes.singlecharge,
                BasisTypes.singlecharge_even,
                BasisTypes.singlecharge_odd,
            ]
            assert basis in charge_basis_types, "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."

            # Set the gate offset charge to zero if not provided
            if "ng" not in params:
                params["ng"] = 0.0

            # The parity constraints only concern the charge-type bases: the
            # Cooper-pair basis spans the symmetric range -n_max..n_max (odd
            # size), while the single-electron bases are built with an even size.
            if basis in [BasisTypes.singlecharge, BasisTypes.singlecharge_even, BasisTypes.singlecharge_odd]:
                assert (N_pre_diag) % 2 == 0, "N_pre_diag must be even for single charge bases."
            else:
                assert (N_pre_diag - 1) % 2 == 0, "N_pre_diag must be odd."

    @staticmethod
    def _cos_sin_pair(dim, k):
        """Return ``(0.5 (S + S†), 0.5j (S - S†))`` where S is the dim x dim k-th superdiagonal shift."""
        upper = jnp.eye(dim, k=k)
        lower = jnp.eye(dim, k=-k)
        return 0.5 * jnp2jqt(upper + lower), 0.5j * jnp2jqt(upper - lower)

    def common_ops(self):
        """Build the device operators, written in the specified basis.

        All operators have dimension ``self.N_pre_diag``.

        Returns:
            dict: Operator name -> operator. Keys per basis:

            * fock: ``"id"``, ``"a"``, ``"a_dag"``, ``"phi"``, ``"n"``.
            * charge, singlecharge_even, singlecharge_odd: ``"id"``,
              ``"cos(φ)"``, ``"sin(φ)"``, ``"cos(2φ)"``, ``"sin(2φ)"``,
              ``"n"``, ``"H_charge"``.
            * singlecharge: ``"id"``, ``"cos(φ)"``, ``"sin(φ)"``,
              ``"cos(φ/2)"``, ``"sin(φ/2)"``, ``"n"``, ``"H_charge"``.

            For any other basis the dict is empty.
        """
        ops = {}
        N = self.N_pre_diag

        if self.basis == BasisTypes.fock:
            ops["id"] = identity(N)
            ops["a"] = destroy(N)
            ops["a_dag"] = create(N)
            ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
            ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        elif self.basis == BasisTypes.charge:
            # Here H = 4 Ec (n - ng)² - Ej cos(φ) in the Cooper pair charge basis,
            # with n = -n_max, ..., n_max Cooper pairs (N odd).
            ops["id"] = identity(N)
            ops["cos(φ)"], ops["sin(φ)"] = self._cos_sin_pair(N, 1)
            ops["cos(2φ)"], ops["sin(2φ)"] = self._cos_sin_pair(N, 2)

            n_max = (N - 1) // 2
            n_array = jnp.arange(-n_max, n_max + 1)
            ops["n"] = jnp2jqt(jnp.diag(n_array))
            n_minus_ng_array = n_array - self.params["ng"] * jnp.ones(N)
            ops["H_charge"] = jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2))

        elif self.basis in [BasisTypes.singlecharge_even, BasisTypes.singlecharge_odd]:
            # Only even (or only odd) electron numbers are kept, so adjacent basis
            # states differ by one Cooper pair and e^{±iφ} is a shift by one index.
            # "n" counts electrons: -N, -N+2, ..., N-2 (even) or -N+1, ..., N-1 (odd).
            if self.basis == BasisTypes.singlecharge_even:
                n_array = jnp.arange(-N, N, 2)
            else:
                n_array = jnp.arange(-N + 1, N, 2)

            ops["id"] = identity(N)
            ops["cos(φ)"], ops["sin(φ)"] = self._cos_sin_pair(N, 1)
            ops["cos(2φ)"], ops["sin(2φ)"] = self._cos_sin_pair(N, 2)

            ops["n"] = jnp2jqt(jnp.diag(n_array))
            # Ec (n - 2ng)² with n in electrons equals 4 Ec (n/2 - ng)² in Cooper pairs.
            n_minus_ng_array = n_array - 2 * self.params["ng"] * jnp.ones(N)
            ops["H_charge"] = jnp2jqt(jnp.diag(self.params["Ec"] * n_minus_ng_array**2))

        elif self.basis == BasisTypes.singlecharge:
            # Here H = Ec (n - 2ng)² - Ej cos(φ) in the single-electron charge basis.
            # Using Eq. (5.36) of Kyle Serniak's thesis,
            #     H = Ec ∑ₙ (n - 2ng)² |n⟩⟨n| - Ej/2 ∑ₙ (|n⟩⟨n+2| + h.c.),
            # where n counts electrons, not Cooper pairs. We use 2ng instead of ng
            # to match the gate offset charge convention of the transmon (as done
            # in Kyle's thesis). Hence e^{±iφ} shifts by two indices and e^{±iφ/2}
            # by one.
            n_max = N // 2

            ops["id"] = identity(N)
            ops["cos(φ)"], ops["sin(φ)"] = self._cos_sin_pair(N, 2)
            ops["cos(φ/2)"], ops["sin(φ/2)"] = self._cos_sin_pair(N, 1)

            n_array = jnp.arange(-n_max, n_max)
            ops["n"] = jnp2jqt(jnp.diag(n_array))
            n_minus_ng_array = n_array - 2 * self.params["ng"] * jnp.ones(N)
            ops["H_charge"] = jnp2jqt(jnp.diag(self.params["Ec"] * n_minus_ng_array**2))

        return ops

    @property
    def Ej(self):
        """Josephson energy, ``params["Ej"]``."""
        return self.params["Ej"]

    def phi_zpf(self):
        """Return the phase zero-point fluctuation ``(2 Ec / Ej)^(1/4)``."""
        return (2 * self.params["Ec"] / self.Ej) ** (0.25)

    def n_zpf(self):
        """Return the charge zero-point fluctuation ``(Ej / (32 Ec))^(1/4)``.

        Together with ``phi_zpf`` this satisfies ``phi_zpf * n_zpf == 1/2``.
        """
        return (self.Ej / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Return the plasma frequency ``sqrt(8 Ec Ej)`` of the linear (harmonic) term."""
        return jnp.sqrt(8 * self.params["Ec"] * self.Ej)

    def get_H_linear(self):
        """Return the linear term ``ω a†a`` (requires the Fock basis operators)."""
        w = self.get_linear_ω()
        return w * self.original_ops["a_dag"] @ self.original_ops["a"]

    def get_H_full(self):
        """Return the full ``H_charge - Ej cos(φ)`` in the specified charge-type basis."""
        return self.original_ops["H_charge"] - self.Ej * self.original_ops["cos(φ)"]

    def get_H_truncated(self):
        """Return the linear term plus the φ⁴ and φ⁶ terms of ``-Ej cos(φ)`` (Fock basis).

        The φ² part of the cosine is already contained in ``get_H_linear``.
        """
        phi_op = self.original_ops["phi"]
        # Reuse lower powers instead of chaining 4 and 6 matrix products.
        phi2 = phi_op @ phi_op
        phi4 = phi2 @ phi2
        phi6 = phi4 @ phi2
        fourth_order_term = -(1 / 24) * self.Ej * phi4
        sixth_order_term = (1 / 720) * self.Ej * phi6
        return self.get_H_linear() + fourth_order_term + sixth_order_term

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original specified basis. This can be overridden by subclasses.

        Raises:
            ValueError: If ``self.hamiltonian`` is not linear, full or truncated.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        elif self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        elif self.hamiltonian == HamiltonianTypes.truncated:
            return self.get_H_truncated()
        raise ValueError(f"Unsupported Hamiltonian type: {self.hamiltonian}")

    def potential(self, phi):
        """Return the potential energy at ``phi``.

        ``phi`` is multiplied by 2π before use, i.e. it is measured in units of 2π.
        The model follows ``self.hamiltonian``: harmonic (linear), ``-Ej cos``
        (full), or the cosine expanded up to 6th order (truncated, without the
        constant term).

        Raises:
            ValueError: If ``self.hamiltonian`` is not linear, full or truncated.
        """
        phi_scaled = 2 * jnp.pi * phi
        if self.hamiltonian == HamiltonianTypes.linear:
            return 0.5 * self.Ej * phi_scaled ** 2
        elif self.hamiltonian == HamiltonianTypes.full:
            return - self.Ej * jnp.cos(phi_scaled)
        elif self.hamiltonian == HamiltonianTypes.truncated:
            second_order = 0.5 * self.Ej * phi_scaled ** 2
            fourth_order = -(1 / 24) * self.Ej * phi_scaled ** 4
            sixth_order = (1 / 720) * self.Ej * phi_scaled ** 6
            return second_order + fourth_order + sixth_order
        raise ValueError(f"Unsupported Hamiltonian type: {self.hamiltonian}")

    def calculate_wavefunctions(self, phi_vals):
        """Calculate eigenstate wavefunctions at the phases ``phi_vals``.

        In the charge-type bases, eigenvector k with components c_n is mapped to
        ``ψ_k(φ) = i^k / sqrt(2π) * Σ_n c_n exp(i n φ)``, where n is the
        Cooper-pair number (half the electron number in the single-electron
        even/odd bases). The Fock basis defers to ``FluxDevice``.

        TODO: this is not currently being used for plotting... needs to be updated!

        Args:
            phi_vals: 1D array-like of phase values. NOTE(review): φ enters here
                as exp(i n φ) (radians), whereas ``potential`` multiplies its
                argument by 2π; confirm which convention callers expect.

        Returns:
            Complex array of shape ``(N_pre_diag, len(phi_vals))``; row k is ψ_k.

        Raises:
            NotImplementedError: For the ``singlecharge`` basis or an unknown basis.
        """
        if self.basis == BasisTypes.fock:
            return super().calculate_wavefunctions(phi_vals)
        if self.basis == BasisTypes.singlecharge:
            raise NotImplementedError("Wavefunctions for single charge basis not yet implemented.")
        if self.basis not in [BasisTypes.charge, BasisTypes.singlecharge_even, BasisTypes.singlecharge_odd]:
            raise NotImplementedError(f"Wavefunctions for basis {self.basis} not implemented.")

        phi_vals = jnp.array(phi_vals)
        n_labels = jnp.diag(self.original_ops["n"].data)
        if self.basis in [BasisTypes.singlecharge_even, BasisTypes.singlecharge_odd]:
            # Convert electron number to Cooper-pair number.
            n_labels = 0.5 * n_labels

        N = self.N_pre_diag
        vecs = self.eig_systems["vecs"][:, :N]
        # basis_phases[n, p] = exp(i * n_labels[n] * phi_vals[p])
        basis_phases = jnp.exp(1j * jnp.outer(n_labels, phi_vals))
        # Python-side powers of 1j keep the phase factors exactly in {1, i, -1, -i}.
        prefactors = jnp.array([1j ** k for k in range(N)]) / jnp.sqrt(2 * jnp.pi)
        return prefactors[:, None] * (vecs.T @ basis_phases)