Coverage for jaxquantum/devices/superconducting/snail.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""SNAIL."""

2 

3from flax import struct 

4from jax import config 

5 

6import jax.numpy as jnp 

7 

8from jaxquantum.devices.base.base import BasisTypes, HamiltonianTypes 

9from jaxquantum.devices.superconducting.flux_base import FluxDevice 

10from jaxquantum.core.operators import identity, destroy, create 

11from jaxquantum.core.conversions import jnp2jqt 

12 

13config.update("jax_enable_x64", True) 

14 

15 

@struct.dataclass
class SNAIL(FluxDevice):
    """SNAIL device.

    With the ``full`` Hamiltonian in the ``charge`` basis, ``get_H_full``
    assembles

        H = 4 Ec (n - ng)^2
            - alpha Ej cos(φ)
            - m Ej [cos(2π phi_ext / m) cos(φ/m) + sin(2π phi_ext / m) sin(φ/m)]

    where the bracketed term equals ``cos((2π phi_ext - φ) / m)``.

    Keys read from ``self.params``: ``Ec``, ``Ej``, ``alpha``, ``m``,
    ``phi_ext`` and ``ng`` (``ng`` is defaulted by ``param_validation``).
    """

    DEFAULT_BASIS = BasisTypes.charge
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate the parameters and the Hamiltonian/basis combination.

        Args:
            N: Not used by this implementation.
            N_pre_diag: Size of the basis used before diagonalization.
            params: Parameter dict. Must contain an integer ``m >= 2``.
                Mutated in place: ``ng`` is set to ``0.0`` when absent.
            hamiltonian: A ``HamiltonianTypes`` member.
            basis: A ``BasisTypes`` member.

        Raises:
            AssertionError: If any of the checks below fails.
        """
        m = params["m"]
        assert m % 1 == 0, "m must be an integer."
        assert m >= 2, "m must be greater than or equal to 2."

        if hamiltonian == HamiltonianTypes.linear:
            assert basis == BasisTypes.fock, (
                "Linear Hamiltonian only works with Fock basis."
            )
        elif hamiltonian == HamiltonianTypes.truncated:
            assert basis == BasisTypes.fock, (
                "Truncated Hamiltonian only works with Fock basis."
            )
        elif hamiltonian == HamiltonianTypes.full:
            charge_basis_types = [BasisTypes.charge]
            assert basis in charge_basis_types, (
                "Full Hamiltonian only works with Cooper pair charge or "
                "single-electron charge bases."
            )

            # The previous expression, `(N_pre_diag - 1) % 2 * m == 0`, parsed
            # as `((N_pre_diag - 1) % 2) * m == 0` because `%` and `*` share
            # precedence and associate left-to-right. It therefore only
            # checked that N_pre_diag is odd. The parenthesised form enforces
            # what the message states.
            # NOTE(review): the size constraint only matters for the charge
            # basis arrays built in `common_ops`, so it is applied in this
            # branch — confirm against the original indentation.
            assert (N_pre_diag - 1) % (2 * m) == 0, (
                "(N_pre_diag - 1)/2 must be divisible by m."
            )

        # Set the gate offset charge to zero if not provided.
        # NOTE(review): applied for every Hamiltonian type; only the charge
        # basis reads "ng" in this class.
        if "ng" not in params:
            params["ng"] = 0.0

    def common_ops(self):
        """Return a dict of operators written in the configured basis.

        Fock basis keys: ``id``, ``a``, ``a_dag``, ``phi``, ``n``.
        Charge basis keys: ``id``, ``cos(φ/m)``, ``sin(φ/m)``, ``cos(φ)``,
        ``sin(φ)``, ``n``, ``H_charge``.

        For any other basis the returned dict is empty.
        """
        ops = {}
        N = self.N_pre_diag

        if self.basis == BasisTypes.fock:
            ops["id"] = identity(N)
            ops["a"] = destroy(N)
            ops["a_dag"] = create(N)
            ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
            ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        elif self.basis == BasisTypes.charge:
            # Charge-basis representation of
            #   H = 4 Ec (n - ng)^2 - alpha Ej cos(φ)
            #       - m Ej cos((2π phi_ext - φ) / m).
            # Neighbouring basis states differ by 1/m in `n` (see `n_array`),
            # so cos/sin(φ/m) couple nearest neighbours and cos/sin(φ) couple
            # states m apart.
            m = self.params["m"]

            def cos_sin_pair(offset):
                """Return (cos, sin) built from the +/-offset diagonals."""
                upper = jnp.eye(N, k=offset)
                lower = jnp.eye(N, k=-offset)
                return (
                    0.5 * (jnp2jqt(upper + lower)),
                    0.5j * (jnp2jqt(upper - lower)),
                )

            ops["id"] = identity(N)
            ops["cos(φ/m)"], ops["sin(φ/m)"] = cos_sin_pair(1)
            ops["cos(φ)"], ops["sin(φ)"] = cos_sin_pair(m)

            n_max = (N - 1) // 2
            n_array = jnp.arange(-n_max, n_max + 1) / self.params["m"]
            ops["n"] = jnp2jqt(jnp.diag(n_array))

            # `jnp.ones(N)` also forces the length of `n_array` to equal N,
            # which only holds for odd N.
            n_minus_ng_array = n_array - self.params["ng"] * jnp.ones(N)
            ops["H_charge"] = jnp2jqt(
                jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2)
            )

        return ops

    @property
    def Ej(self):
        """Return ``params["Ej"]``."""
        return self.params["Ej"]

    def phi_zpf(self):
        """Return Phase ZPF, ``(2 Ec / Ej) ** 0.25``."""
        # NOTE(review): uses Ej alone, without alpha or m — confirm this is
        # the intended linearization for the SNAIL potential.
        return (2 * self.params["Ec"] / self.Ej) ** (0.25)

    def n_zpf(self):
        """Return Charge ZPF, ``(Ej / (32 Ec)) ** 0.25``."""
        return (self.Ej / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Get frequency of linear terms, ``sqrt(8 Ec Ej)``."""
        return jnp.sqrt(8 * self.params["Ec"] * self.Ej)

    def get_H_linear(self):
        """Return linear terms in H: ``ω a_dag a``.

        Requires the ``a`` and ``a_dag`` operators, which ``common_ops``
        only provides in the Fock basis.
        """
        w = self.get_linear_ω()
        return w * self.original_ops["a_dag"] @ self.original_ops["a"]

    def get_H_full(self):
        """Return full H in specified basis.

        Sums the charging term ``H_charge`` and the inductive term
        ``-alpha Ej cos(φ) - m Ej cos((2π phi_ext - φ) / m)``, the latter
        expanded with the cosine difference identity so that it can be built
        from the ``cos``/``sin`` operators of ``common_ops``.
        """
        α = self.params["alpha"]
        m = self.params["m"]
        phi_ext = self.params["phi_ext"]
        Ej = self.Ej
        ops = self.original_ops

        # External-flux phase divided across the m-fold term.
        theta_ext = 2 * jnp.pi * phi_ext / m

        H_inductive = -α * Ej * ops["cos(φ)"] - m * Ej * (
            jnp.cos(theta_ext) * ops["cos(φ/m)"]
            + jnp.sin(theta_ext) * ops["sin(φ/m)"]
        )
        return ops["H_charge"] + H_inductive

    def get_H_truncated(self):
        """Return truncated H in specified basis.

        Raises:
            NotImplementedError: Always; no truncated expansion exists yet.
        """
        # TODO: derive the series expansion of the SNAIL potential before
        # enabling this Hamiltonian type.
        raise NotImplementedError("Truncated Hamiltonian not implemented for SNAIL.")

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original specified basis.

        Dispatches on ``self.hamiltonian``. Returns ``None`` implicitly when
        the type is none of linear, full or truncated.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        elif self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        elif self.hamiltonian == HamiltonianTypes.truncated:
            return self.get_H_truncated()

    def potential(self, phi):
        """Return potential energy for a given phi.

        ``phi`` is multiplied by ``2π`` before entering the trigonometric
        terms, as is ``phi_ext``.

        Raises:
            NotImplementedError: For the truncated Hamiltonian.

        Returns ``None`` implicitly for an unrecognized Hamiltonian type.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return 0.5 * self.Ej * (2 * jnp.pi * phi) ** 2
        elif self.hamiltonian == HamiltonianTypes.full:
            α = self.params["alpha"]
            m = self.params["m"]
            phi_ext = self.params["phi_ext"]

            return -α * self.Ej * jnp.cos(2 * jnp.pi * phi) - (
                m * self.Ej * jnp.cos(2 * jnp.pi * (phi_ext - phi) / m)
            )
        elif self.hamiltonian == HamiltonianTypes.truncated:
            raise NotImplementedError("Truncated potential not implemented for SNAIL.")