Coverage for jaxquantum/devices/base/base.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Base device.""" 

2 

3from abc import abstractmethod, ABC 

4from enum import Enum 

5from typing import Dict, Any, List 

6 

7from flax import struct 

8from jax import config, Array 

9import jax.numpy as jnp 

10 

11from jaxquantum.core.qarray import Qarray 

12from jaxquantum.core.dims import Qtypes 

13 

# Turn on JAX's 64-bit mode ("jax_enable_x64": float64/complex128 arrays).
# This executes as a side effect of importing this module and changes the
# global JAX configuration for the whole process.
config.update("jax_enable_x64", True)

15 

16 

class BasisTypes(str, Enum):
    """Basis in which a device's operators are constructed.

    Members compare equal to, and hash like, their string value. They are
    therefore interchangeable with the plain strings ``"fock"``, ``"charge"``
    and ``"single_charge"``, including as dictionary keys.
    """

    fock = "fock"
    charge = "charge"
    single_charge = "single_charge"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string``.

        Raises:
            ValueError: if no member has that value.
        """
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Plain strings must be accepted too: __hash__ equals
        # the str hash, so str-keyed dict lookups end up calling this method
        # with a bare str. Unknown types defer to Python's default handling.
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        return hash(self.value)

40 

41 

class HamiltonianTypes(str, Enum):
    """Which form of a device Hamiltonian to use.

    Members compare equal to, and hash like, their string value. They are
    therefore interchangeable with the plain strings ``"linear"``,
    ``"truncated"`` and ``"full"``, including as dictionary keys.
    """

    linear = "linear"
    truncated = "truncated"
    full = "full"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string``.

        Raises:
            ValueError: if no member has that value.
        """
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Plain strings must be accepted too: __hash__ equals
        # the str hash, so str-keyed dict lookups end up calling this method
        # with a bare str. Unknown types defer to Python's default handling.
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        return hash(self.value)

65 

66 

@struct.dataclass
class Device(ABC):
    """Abstract base class for a device described by a (diagonalized) Hamiltonian.

    Fields:
        N: dimension of the Hilbert space kept after diagonalization (the
            number of lowest-energy eigenvectors used for the eigenbasis).
        N_pre_diag: dimension of the Hilbert space in which H is built before
            diagonalization; must be ``>= N``.
        params: device parameters. This is the only pytree-leaf field; all
            other fields are declared with ``pytree_node=False`` (static).
        _label: integer suffix used by :attr:`label`.
        _basis: basis type, exposed as :attr:`basis`.
        _hamiltonian: Hamiltonian type, exposed as :attr:`hamiltonian`.

    Subclasses implement :meth:`common_ops`, :meth:`get_linear_ω`,
    :meth:`get_H_linear` and :meth:`get_H_full`. Prefer :meth:`create` over the
    raw constructor: it fills in defaults and runs :meth:`param_validation`.
    """

    DEFAULT_BASIS = BasisTypes.fock
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    N: int = struct.field(pytree_node=False)
    N_pre_diag: int = struct.field(pytree_node=False)
    params: Dict[str, Any]
    _label: int = struct.field(pytree_node=False)
    _basis: BasisTypes = struct.field(pytree_node=False)
    _hamiltonian: HamiltonianTypes = struct.field(pytree_node=False)

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate constructor arguments; a no-op here.

        Called by :meth:`create` after defaults are resolved. This can be
        overridden by subclasses.
        """
        pass

    @classmethod
    def create(
        cls,
        N,
        params,
        label=0,
        N_pre_diag=None,
        use_linear=False,
        hamiltonian: HamiltonianTypes = None,
        basis: BasisTypes = None,
    ):
        """Create a device.

        Args:
            N (int): dimension of Hilbert space.
            params (dict): parameters of the device.
            label (int, optional): label for the device. Defaults to 0. This is
                useful when you have multiple of the same device type in the
                same system.
            N_pre_diag (int, optional): dimension of Hilbert space before
                diagonalization. Defaults to None, in which case it is set to N.
                This must be greater than or equal to N.
            use_linear (bool): whether to use the linearized device. Defaults to
                False. This will override the hamiltonian keyword argument. This
                is a bit redundant with hamiltonian, but it is kept for
                backwards compatibility.
            hamiltonian (HamiltonianTypes or str, optional): type of
                Hamiltonian. Defaults to None, in which case
                ``cls.DEFAULT_HAMILTONIAN`` is used.
            basis (BasisTypes or str, optional): type of basis. Defaults to
                None, in which case ``cls.DEFAULT_BASIS`` is used.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: if ``N_pre_diag < N`` (via ``assert``, so the check
                is skipped when Python runs with ``-O``).
            ValueError: if ``hamiltonian`` or ``basis`` is not a valid value of
                its enum.
        """
        if N_pre_diag is None:
            N_pre_diag = N

        assert N_pre_diag >= N, "N_pre_diag must be greater than or equal to N."

        # Coerce through the enum constructors so plain strings ("fock",
        # "full", ...) are accepted; an existing member is returned unchanged.
        _basis = BasisTypes(basis) if basis is not None else cls.DEFAULT_BASIS
        _hamiltonian = (
            HamiltonianTypes(hamiltonian)
            if hamiltonian is not None
            else cls.DEFAULT_HAMILTONIAN
        )

        if use_linear:
            _hamiltonian = HamiltonianTypes.linear

        cls.param_validation(N, N_pre_diag, params, _hamiltonian, _basis)

        return cls(N, N_pre_diag, params, label, _basis, _hamiltonian)

    @property
    def basis(self):
        """Basis type the device's operators are built in."""
        return self._basis

    @property
    def hamiltonian(self):
        """Hamiltonian type selected for this device."""
        return self._hamiltonian

    @property
    def label(self):
        """Class name followed by the integer label, e.g. ``"Device0"``."""
        return self.__class__.__name__ + str(self._label)

    @property
    def linear_ops(self):
        """Operators transformed by :meth:`full_ops`; defaults to :meth:`common_ops`."""
        return self.common_ops()

    @property
    def original_ops(self):
        """Operators in the original basis; defaults to :meth:`common_ops`."""
        return self.common_ops()

    @property
    def ops(self):
        """Operators in the truncated H eigenbasis (see :meth:`full_ops`)."""
        return self.full_ops()

    @abstractmethod
    def common_ops(self) -> Dict[str, Qarray]:
        """Set up common ops in the specified basis."""

    @abstractmethod
    def get_linear_ω(self):
        """Get frequency of linear terms."""

    @abstractmethod
    def get_H_linear(self):
        """Return linear terms in H."""

    @abstractmethod
    def get_H_full(self):
        """Return full H."""

    def get_H(self):
        """
        Return diagonalized H. Explicitly keep only diagonal elements of matrix.

        H is transformed into its own eigenbasis (truncated to ``N`` levels)
        and the off-diagonal elements are discarded.
        """
        return self.get_op_in_H_eigenbasis(
            self._get_H_in_original_basis()
        ).keep_only_diag_elements()

    def _get_H_in_original_basis(self):
        """This returns the Hamiltonian in the original specified basis. This can be overridden by subclasses.

        Raises:
            NotImplementedError: for Hamiltonian types other than ``linear`` and
                ``full``. Previously this fell through and returned ``None``,
                which surfaced later as an opaque ``None.data`` error.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        if self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        raise NotImplementedError(
            f"{type(self).__name__} does not provide the '{self.hamiltonian}' "
            "Hamiltonian in the original basis; override _get_H_in_original_basis."
        )

    def _calculate_eig_systems(self):
        """Diagonalize the original-basis Hamiltonian.

        Returns:
            tuple: ``(evs, evecs)`` with eigenvalues in ascending order and the
            matching eigenvectors as the columns of ``evecs``.
        """
        # H is assumed Hermitian, as required by eigh.
        evs, evecs = jnp.linalg.eigh(self._get_H_in_original_basis().data)
        # Defensive explicit ascending sort of eigenvalues (and matching columns).
        idxs_sorted = jnp.argsort(evs)
        return evs[idxs_sorted], evecs[:, idxs_sorted]

    @property
    def eig_systems(self):
        """Eigen-decomposition of the original-basis Hamiltonian.

        Recomputed on every access (nothing is cached).

        Returns:
            dict: ``"vals"`` (ascending eigenvalues) and ``"vecs"``
            (eigenvectors as columns, in the same order).
        """
        vals, vecs = self._calculate_eig_systems()
        return {"vals": vals, "vecs": vecs}

    def _truncated_evecs(self):
        """Return the ``N`` lowest-energy eigenvectors as the columns of a matrix."""
        return self.eig_systems["vecs"][:, : self.N]

    def get_op_in_H_eigenbasis(self, op: Qarray):
        """Return ``op`` expressed in the truncated H eigenbasis, with dims ``[[N], [N]]``."""
        evecs = self._truncated_evecs()
        dims = [[self.N], [self.N]]
        return get_op_in_new_basis(op, evecs, dims)

    def get_op_data_in_H_eigenbasis(self, op: Array):
        """Raw-array version of :meth:`get_op_in_H_eigenbasis`."""
        return get_op_data_in_new_basis(op, self._truncated_evecs())

    def get_vec_in_H_eigenbasis(self, vec: Qarray):
        """Return ket or bra ``vec`` expressed in the truncated H eigenbasis.

        A ket maps to dims ``[[N], [1]]``; anything else is treated as a bra
        and maps to dims ``[[1], [N]]``.
        """
        evecs = self._truncated_evecs()
        if vec.qtype == Qtypes.ket:
            return get_vec_in_new_basis(vec, evecs, [[self.N], [1]])
        # Bra <psi| transforms as <psi| V, the adjoint of V^dagger |psi>.
        # get_vec_in_new_basis only computes V^dagger @ data, which is a shape
        # mismatch for a row vector. Assumes bra data is a row vector
        # (..., 1, N_pre_diag), consistent with the [[1], [N]] result dims.
        return Qarray.create(vec.data @ evecs, dims=[[1], [self.N]])

    def get_vec_data_in_H_eigenbasis(self, vec: Array):
        """Raw-array ket version of :meth:`get_vec_in_H_eigenbasis`."""
        return get_vec_data_in_new_basis(vec, self._truncated_evecs())

    def full_ops(self):
        """Return every operator from :attr:`linear_ops` in the truncated H eigenbasis.

        Returns:
            dict: same keys as :attr:`linear_ops`, values transformed by
            :meth:`get_op_in_H_eigenbasis`.
        """
        # TODO: use JAX vmap here
        # NOTE: each get_op_in_H_eigenbasis call re-diagonalizes H, because
        # eig_systems is recomputed on every access.
        return {
            name: self.get_op_in_H_eigenbasis(op)
            for name, op in self.linear_ops.items()
        }

224 

225 

def get_op_in_new_basis(op: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Express the operator ``op`` in the basis formed by the columns of ``evecs``.

    Args:
        op: operator in the original basis.
        evecs: change-of-basis matrix whose columns are the new basis vectors.
        dims: ``dims`` attached to the returned Qarray.

    Returns:
        Qarray built from ``get_op_data_in_new_basis(op.data, evecs)`` with the
        given ``dims``.
    """
    return Qarray.create(get_op_data_in_new_basis(op.data, evecs), dims=dims)

229 

230 

def get_op_data_in_new_basis(op_data: Array, evecs: Array) -> Array:
    """Change the basis of raw operator data: return ``V^dagger @ op_data @ V``.

    Args:
        op_data: operator matrix in the original basis, shape ``(..., M, M)``.
            Any leading batch axes are broadcast.
        evecs: matrix ``V`` whose columns are the new basis vectors, shape
            ``(..., M, N)``. ``N < M`` projects onto the first ``N`` columns.

    Returns:
        Array of shape ``(..., N, N)``.
    """
    # Use swapaxes + ``@`` rather than ``.transpose()`` + ``jnp.dot``: the
    # former only touches the last two axes and batches over leading axes,
    # while the latter does not. For 2-D inputs the two are identical.
    evecs_dag = jnp.conjugate(jnp.swapaxes(evecs, -1, -2))
    return evecs_dag @ (op_data @ evecs)

233 

234 

def get_vec_in_new_basis(vec: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Express the ket ``vec`` in the basis formed by the columns of ``evecs``.

    Args:
        vec: state vector in the original basis.
        evecs: change-of-basis matrix whose columns are the new basis vectors.
        dims: ``dims`` attached to the returned Qarray.

    Returns:
        Qarray built from ``get_vec_data_in_new_basis(vec.data, evecs)`` with
        the given ``dims``.
    """
    transformed = get_vec_data_in_new_basis(vec.data, evecs)
    return Qarray.create(transformed, dims=dims)

237 

238 

def get_vec_data_in_new_basis(vec_data: Array, evecs: Array) -> Array:
    """Change the basis of raw ket data: return ``V^dagger @ vec_data``.

    Args:
        vec_data: column ket(s) in the original basis, shape ``(..., M, 1)``.
            Any leading batch axes are broadcast. A 1-D ``(M,)`` vector also works.
        evecs: matrix ``V`` whose columns are the new basis vectors, shape
            ``(..., M, N)``.

    Returns:
        Array of shape ``(..., N, 1)``, or ``(N,)`` for a 1-D input.
    """
    # swapaxes + ``@`` batch over leading axes; identical to the 2-D dot product.
    return jnp.conjugate(jnp.swapaxes(evecs, -1, -2)) @ vec_data