Coverage for jaxquantum/devices/base/base.py: 0%

144 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Base device.""" 

2 

3from abc import abstractmethod, ABC 

4from enum import Enum 

5from typing import Dict, Any, List 

6 

7from flax import struct 

8from jax import config, Array 

9import jax.numpy as jnp 

10 

11from jaxquantum.core.qarray import Qarray 

12from jaxquantum.core.dims import Qtypes 

13 

# Enable 64-bit (double precision) dtypes in JAX. Note: this mutates JAX's
# process-wide configuration at import time, so it also affects any other code
# running in the same process.
config.update("jax_enable_x64", True)

15 

16 

class BasisTypes(str, Enum):
    """Basis in which a device's operators are constructed.

    Members compare and hash by their string value. A member therefore equals
    its plain-string value (e.g. ``BasisTypes.fock == "fock"``) as well as any
    other enum member carrying the same value.
    """

    fock = "fock"
    charge = "charge"
    singlecharge = "single_charge"
    singlecharge_even = "singlecharge_even"
    singlecharge_odd = "singlecharge_odd"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string`` (ValueError if none)."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Plain strings are compared directly instead of
        # raising AttributeError on ``other.value``. Unrelated types return
        # NotImplemented, so Python falls back to its default (== is False).
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        # Defined explicitly; otherwise str.__ne__ would be used.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Consistent with __eq__: members hash like their string value.
        return hash(self.value)

42 

43 

class HamiltonianTypes(str, Enum):
    """Selects which Hamiltonian a device builds.

    Members compare and hash by their string value. A member therefore equals
    its plain-string value (e.g. ``HamiltonianTypes.full == "full"``) as well
    as any other enum member carrying the same value.
    """

    linear = "linear"
    truncated = "truncated"
    full = "full"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string`` (ValueError if none)."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Plain strings are compared directly instead of
        # raising AttributeError on ``other.value``. Unrelated types return
        # NotImplemented, so Python falls back to its default (== is False).
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        # Defined explicitly; otherwise str.__ne__ would be used.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Consistent with __eq__: members hash like their string value.
        return hash(self.value)

67 

68 

@struct.dataclass
class Device(ABC):
    """Abstract base class for a single quantum device.

    Subclasses provide operators (``common_ops``) and Hamiltonians
    (``get_H_linear`` / ``get_H_full``) in an "original" basis of dimension
    ``N_pre_diag``. This base class diagonalizes the Hamiltonian and maps
    operators and states onto its ``N`` lowest-energy eigenstates.
    """

    DEFAULT_BASIS = BasisTypes.fock
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    # Dimension kept after projecting onto the Hamiltonian eigenbasis.
    N: int = struct.field(pytree_node=False)
    # Dimension of the original basis in which H is built (validated >= N).
    N_pre_diag: int = struct.field(pytree_node=False)
    # Device parameters. This is the only field that is a pytree node; all
    # other fields are static metadata.
    params: Dict[str, Any]
    # Index distinguishing several devices of the same class (see `label`).
    _label: int = struct.field(pytree_node=False)
    _basis: BasisTypes = struct.field(pytree_node=False)
    _hamiltonian: HamiltonianTypes = struct.field(pytree_node=False)

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate constructor arguments; a no-op by default.

        Called by ``create`` after defaults have been resolved. Subclasses can
        override this to reject invalid parameter combinations.
        """

    @classmethod
    def create(
        cls,
        N,
        params,
        label=0,
        N_pre_diag=None,
        use_linear=False,
        hamiltonian: HamiltonianTypes = None,
        basis: BasisTypes = None,
    ):
        """Create a device.

        Args:
            N (int): dimension of Hilbert space.
            params (dict): parameters of the device.
            label (int, optional): label for the device. Defaults to 0. This is
                useful when you have multiple devices of the same type in the
                same system.
            N_pre_diag (int, optional): dimension of Hilbert space before
                diagonalization. Defaults to None, in which case it is set to
                N. This must be greater than or equal to N.
            use_linear (bool): whether to use the linearized device. Defaults
                to False. This overrides the hamiltonian keyword argument. It
                is redundant with hamiltonian but kept for backwards
                compatibility.
            hamiltonian (HamiltonianTypes, optional): type of Hamiltonian.
                Defaults to None, in which case ``cls.DEFAULT_HAMILTONIAN`` is
                used.
            basis (BasisTypes, optional): type of basis. Defaults to None, in
                which case ``cls.DEFAULT_BASIS`` is used.

        Raises:
            ValueError: if ``N_pre_diag < N``.
        """

        if N_pre_diag is None:
            N_pre_diag = N

        # Raise explicitly rather than `assert`, which is stripped under -O.
        if N_pre_diag < N:
            raise ValueError("N_pre_diag must be greater than or equal to N.")

        _basis = basis if basis is not None else cls.DEFAULT_BASIS
        _hamiltonian = (
            hamiltonian if hamiltonian is not None else cls.DEFAULT_HAMILTONIAN
        )

        # Legacy flag: takes precedence over the `hamiltonian` argument.
        if use_linear:
            _hamiltonian = HamiltonianTypes.linear

        cls.param_validation(N, N_pre_diag, params, _hamiltonian, _basis)

        return cls(N, N_pre_diag, params, label, _basis, _hamiltonian)

    @property
    def basis(self):
        return self._basis

    @property
    def hamiltonian(self):
        return self._hamiltonian

    @property
    def label(self):
        """Class name followed by the numeric label, e.g. ``"Device0"``."""
        return self.__class__.__name__ + str(self._label)

    @property
    def linear_ops(self):
        return self.common_ops()

    @property
    def original_ops(self):
        return self.common_ops()

    @property
    def ops(self):
        """Operators expressed in the truncated Hamiltonian eigenbasis."""
        return self.full_ops()

    @abstractmethod
    def common_ops(self) -> Dict[str, Qarray]:
        """Set up common ops in the specified basis."""

    @abstractmethod
    def get_linear_ω(self):
        """Get frequency of linear terms."""

    @abstractmethod
    def get_H_linear(self):
        """Return linear terms in H."""

    @abstractmethod
    def get_H_full(self):
        """Return full H."""

    def get_H(self):
        """
        Return diagonalized H. Explicitly keep only diagonal elements of matrix.
        """
        return self.get_op_in_H_eigenbasis(
            self._get_H_in_original_basis()
        ).keep_only_diag_elements()

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original (pre-diagonalization) basis.

        The base implementation handles ``linear`` and ``full``. Subclasses
        supporting other Hamiltonian types should override this.

        Raises:
            NotImplementedError: for any other Hamiltonian type. This replaces
                a silent ``None`` return that would otherwise fail later with
                an unhelpful AttributeError.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        if self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        raise NotImplementedError(
            f"{type(self).__name__} does not build a Hamiltonian of type "
            f"'{self.hamiltonian}' in the original basis."
        )

    def _calculate_eig_systems(self):
        """Diagonalize the original-basis Hamiltonian.

        Returns:
            (evs, evecs): eigenvalues in ascending order and the matching
            eigenvectors as the columns of ``evecs``.
        """
        evs, evecs = jnp.linalg.eigh(self._get_H_in_original_basis().data)  # Hermitian
        # Sort explicitly so the ordering does not rely on backend behavior.
        idxs_sorted = jnp.argsort(evs)
        return evs[idxs_sorted], evecs[:, idxs_sorted]

    @property
    def eig_systems(self):
        """Eigendecomposition of the original-basis Hamiltonian.

        Returns a dict with ``"vals"`` (ascending eigenvalues) and ``"vecs"``
        (matching eigenvectors as columns). It is recomputed on every access.
        """
        vals, vecs = self._calculate_eig_systems()
        return {"vals": vals, "vecs": vecs}

    def _get_truncated_evecs(self):
        """Return eigenvectors (columns) of the N lowest-energy eigenstates."""
        return self.eig_systems["vecs"][:, : self.N]

    def get_op_in_H_eigenbasis(self, op: Qarray):
        """Project ``op`` onto the N lowest eigenstates of H (dims [[N], [N]])."""
        evecs = self._get_truncated_evecs()
        dims = [[self.N], [self.N]]
        return get_op_in_new_basis(op, evecs, dims)

    def get_op_data_in_H_eigenbasis(self, op: Array):
        """Raw-array version of ``get_op_in_H_eigenbasis``."""
        evecs = self._get_truncated_evecs()
        return get_op_data_in_new_basis(op, evecs)

    def get_vec_in_H_eigenbasis(self, vec: Qarray):
        """Express a ket or bra in the N lowest eigenstates of H.

        Kets are mapped to ``evecs^† @ psi`` with dims ``[[N], [1]]``. Any
        other qtype is treated as a bra (row vector) and mapped to
        ``bra @ evecs`` with dims ``[[1], [N]]``, the conjugate transpose of
        the ket transformation.
        """
        evecs = self._get_truncated_evecs()
        if vec.qtype == Qtypes.ket:
            return get_vec_in_new_basis(vec, evecs, [[self.N], [1]])
        # A bra's data is a row vector, so evecs must multiply from the right.
        # Applying evecs^† from the left, as for kets, gives mismatched shapes.
        return Qarray.create(jnp.matmul(vec.data, evecs), dims=[[1], [self.N]])

    def get_vec_data_in_H_eigenbasis(self, vec: Array):
        """Raw-array version of the ket transformation ``evecs^† @ vec``."""
        evecs = self._get_truncated_evecs()
        return get_vec_data_in_new_basis(vec, evecs)

    def full_ops(self):
        """Return every operator in ``linear_ops`` in the truncated H eigenbasis.

        The eigendecomposition is computed once and shared by all operators,
        instead of being recomputed for each one.
        """
        # TODO: use JAX vmap here
        evecs = self._get_truncated_evecs()
        dims = [[self.N], [self.N]]
        return {
            name: get_op_in_new_basis(op, evecs, dims)
            for name, op in self.linear_ops.items()
        }

226 

227 

def get_op_in_new_basis(op: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Change the basis of a Qarray operator.

    Args:
        op: operator whose ``data`` is transformed.
        evecs: matrix whose columns are the new basis vectors.
        dims: dims of the returned Qarray.

    Returns:
        Qarray wrapping ``evecs^† @ op.data @ evecs`` with the given dims.
    """
    return Qarray.create(get_op_data_in_new_basis(op.data, evecs), dims=dims)

231 

232 

def get_op_data_in_new_basis(op_data: Array, evecs: Array) -> Array:
    """Change the basis of an operator matrix: ``evecs^† @ op_data @ evecs``.

    Args:
        op_data: operator matrix of shape ``(..., M, M)``. Leading batch axes
            are broadcast.
        evecs: ``(M, K)`` matrix whose columns are the new basis vectors.

    Returns:
        ``(..., K, K)`` array of the operator's matrix elements in the new basis.
    """
    # matmul contracts the last two axes and broadcasts leading batch axes.
    # jnp.dot would interleave the batch axis into the middle of the result.
    # For 2-D inputs both give the same result.
    evecs_dag = jnp.conjugate(evecs).T
    return jnp.matmul(evecs_dag, jnp.matmul(op_data, evecs))

235 

236 

def get_vec_in_new_basis(vec: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Change the basis of a Qarray ket.

    Args:
        vec: state whose ``data`` is transformed as ``evecs^† @ vec.data``.
        evecs: matrix whose columns are the new basis vectors.
        dims: dims of the returned Qarray.

    Returns:
        Qarray holding the transformed data with the given dims.
    """
    new_data = get_vec_data_in_new_basis(vec.data, evecs)
    return Qarray.create(new_data, dims=dims)

239 

240 

def get_vec_data_in_new_basis(vec_data: Array, evecs: Array) -> Array:
    """Change the basis of ket data: ``evecs^† @ vec_data``.

    Args:
        vec_data: ket data of shape ``(M,)``, ``(M, 1)`` or ``(..., M, 1)``.
            Leading batch axes are broadcast.
        evecs: ``(M, K)`` matrix whose columns are the new basis vectors.

    Returns:
        Coordinates of the ket in the new basis: ``(K,)``, ``(K, 1)`` or
        ``(..., K, 1)``, matching the input layout.
    """
    # matmul broadcasts batch axes; for 1-D or 2-D input it matches jnp.dot.
    evecs_dag = jnp.conjugate(evecs).T
    return jnp.matmul(evecs_dag, vec_data)