Coverage for jaxquantum/devices/base/base.py: 0%
144 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Base device."""
3from abc import abstractmethod, ABC
4from enum import Enum
5from typing import Dict, Any, List
7from flax import struct
8from jax import config, Array
9import jax.numpy as jnp
11from jaxquantum.core.qarray import Qarray
12from jaxquantum.core.dims import Qtypes
# Turn on JAX's `jax_enable_x64` flag (64-bit dtypes) at import time so the
# Hamiltonians and eigendecompositions below are computed in double precision.
# NOTE(review): this mutates global JAX configuration for the whole process,
# not just this module.
config.update("jax_enable_x64", True)
class BasisTypes(str, Enum):
    """Bases in which a device's operators can be constructed.

    Members compare equal to other members with the same value and to their
    plain string value, e.g. ``BasisTypes.fock == "fock"``.
    """

    fock = "fock"
    charge = "charge"
    # NOTE(review): this value is spelled "single_charge" while the even/odd
    # variants use "singlecharge_*"; kept as-is so from_str inputs still work.
    singlecharge = "single_charge"
    singlecharge_even = "singlecharge_even"
    singlecharge_odd = "singlecharge_odd"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string`` (ValueError if none)."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Enum members (anything exposing ``.value``) use
        # that attribute; other objects such as a plain ``str`` or ``None``
        # are compared directly instead of raising AttributeError.
        return self.value == getattr(other, "value", other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash the value so hashing stays consistent with __eq__, including
        # equality with the plain string value.
        return hash(self.value)
class HamiltonianTypes(str, Enum):
    """Which Hamiltonian a device is built from.

    The base ``Device`` handles ``linear`` and ``full``; ``truncated`` must be
    handled by subclasses. Members compare equal to other members with the
    same value and to their plain string value, e.g.
    ``HamiltonianTypes.full == "full"``.
    """

    linear = "linear"
    truncated = "truncated"
    full = "full"

    @classmethod
    def from_str(cls, string: str):
        """Return the member whose value is ``string`` (ValueError if none)."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Compare by value. Enum members (anything exposing ``.value``) use
        # that attribute; other objects such as a plain ``str`` or ``None``
        # are compared directly instead of raising AttributeError.
        return self.value == getattr(other, "value", other)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash the value so hashing stays consistent with __eq__, including
        # equality with the plain string value.
        return hash(self.value)
@struct.dataclass
class Device(ABC):
    """Abstract base class for a device with a finite-dimensional Hilbert space.

    Fields:
        N: dimension of the space the device is exposed in, i.e. the number
            of lowest-energy eigenstates kept after diagonalization.
        N_pre_diag: dimension used to build the Hamiltonian before
            diagonalization; ``create`` enforces ``N_pre_diag >= N``.
        params: device parameters; the only field not marked
            ``pytree_node=False``.
        _label, _basis, _hamiltonian: static metadata exposed through the
            ``label``, ``basis`` and ``hamiltonian`` properties.

    Subclasses implement ``common_ops``, ``get_linear_ω``, ``get_H_linear``
    and ``get_H_full``. This base class provides diagonalization and
    basis-change helpers on top of them.
    """

    DEFAULT_BASIS = BasisTypes.fock
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    N: int = struct.field(pytree_node=False)
    N_pre_diag: int = struct.field(pytree_node=False)
    params: Dict[str, Any]
    _label: int = struct.field(pytree_node=False)
    _basis: BasisTypes = struct.field(pytree_node=False)
    _hamiltonian: HamiltonianTypes = struct.field(pytree_node=False)

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """This can be overridden by subclasses."""
        pass

    @classmethod
    def create(
        cls,
        N,
        params,
        label=0,
        N_pre_diag=None,
        use_linear=False,
        hamiltonian: HamiltonianTypes = None,
        basis: BasisTypes = None,
    ):
        """Create a device.

        Args:
            N (int): dimension of Hilbert space.
            params (dict): parameters of the device.
            label (int, optional): label for the device. Defaults to 0. This is useful when you have multiple of the same device type in the same system.
            N_pre_diag (int, optional): dimension of Hilbert space before diagonalization. Defaults to None, in which case it is set to N. This must be greater than or equal to N.
            use_linear (bool): whether to use the linearized device. Defaults to False. This will override the hamiltonian keyword argument. This is a bit redundant with hamiltonian, but it is kept for backwards compatibility.
            hamiltonian (HamiltonianTypes, optional): type of Hamiltonian. Defaults to None, in which case cls.DEFAULT_HAMILTONIAN (full for the base class) is used.
            basis (BasisTypes, optional): type of basis. Defaults to None, in which case cls.DEFAULT_BASIS (fock for the base class) is used.

        Raises:
            AssertionError: if ``N_pre_diag < N``.
        """
        if N_pre_diag is None:
            N_pre_diag = N

        assert N_pre_diag >= N, "N_pre_diag must be greater than or equal to N."

        _basis = basis if basis is not None else cls.DEFAULT_BASIS
        _hamiltonian = (
            hamiltonian if hamiltonian is not None else cls.DEFAULT_HAMILTONIAN
        )

        # use_linear takes precedence over any explicit hamiltonian argument.
        if use_linear:
            _hamiltonian = HamiltonianTypes.linear

        cls.param_validation(N, N_pre_diag, params, _hamiltonian, _basis)

        return cls(N, N_pre_diag, params, label, _basis, _hamiltonian)

    @property
    def basis(self):
        return self._basis

    @property
    def hamiltonian(self):
        return self._hamiltonian

    @property
    def label(self):
        """Class name followed by the integer label, e.g. ``"Device0"``."""
        return self.__class__.__name__ + str(self._label)

    @property
    def linear_ops(self):
        return self.common_ops()

    @property
    def original_ops(self):
        return self.common_ops()

    @property
    def ops(self):
        return self.full_ops()

    @abstractmethod
    def common_ops(self) -> Dict[str, Qarray]:
        """Set up common ops in the specified basis."""

    @abstractmethod
    def get_linear_ω(self):
        """Get frequency of linear terms."""

    @abstractmethod
    def get_H_linear(self):
        """Return linear terms in H."""

    @abstractmethod
    def get_H_full(self):
        """Return full H."""

    def get_H(self):
        """
        Return diagonalized H. Explicitly keep only diagonal elements of matrix.
        """
        return self.get_op_in_H_eigenbasis(
            self._get_H_in_original_basis()
        ).keep_only_diag_elements()

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original specified basis.

        This can be overridden by subclasses. The base implementation
        supports ``HamiltonianTypes.linear`` and ``HamiltonianTypes.full``.

        Raises:
            NotImplementedError: for any other Hamiltonian type.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        if self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        # Previously this fell through and returned None, which only failed
        # later as an opaque AttributeError (None.data) during diagonalization.
        raise NotImplementedError(
            f"{type(self).__name__} does not support Hamiltonian type "
            f"'{self.hamiltonian}'; override _get_H_in_original_basis."
        )

    def _calculate_eig_systems(self):
        """Diagonalize H (Hermitian) and return (vals, vecs) sorted ascending."""
        evs, evecs = jnp.linalg.eigh(self._get_H_in_original_basis().data)  # Hermitian
        idxs_sorted = jnp.argsort(evs)
        return evs[idxs_sorted], evecs[:, idxs_sorted]

    @property
    def eig_systems(self):
        """Eigendecomposition of the original-basis Hamiltonian.

        Returns a dict with ``"vals"`` (ascending eigenvalues) and ``"vecs"``
        (matching eigenvectors as columns). It is recomputed on every access.
        """
        vals, vecs = self._calculate_eig_systems()
        return {"vals": vals, "vecs": vecs}

    def get_op_in_H_eigenbasis(self, op: Qarray):
        """Return ``V^† op V`` as an N x N Qarray.

        ``V`` holds the N lowest-energy eigenvectors as columns.
        """
        evecs = self.eig_systems["vecs"][:, : self.N]
        dims = [[self.N], [self.N]]
        return get_op_in_new_basis(op, evecs, dims)

    def get_op_data_in_H_eigenbasis(self, op: Array):
        """Raw-array counterpart of ``get_op_in_H_eigenbasis``."""
        evecs = self.eig_systems["vecs"][:, : self.N]
        return get_op_data_in_new_basis(op, evecs)

    def get_vec_in_H_eigenbasis(self, vec: Qarray):
        """Express a ket or bra in the truncated H eigenbasis.

        Kets (column vectors) map to ``V^† ket``. Anything else is treated as
        a bra (row vector) and maps to ``bra V``, the conjugate transpose of
        the ket rule.
        """
        evecs = self.eig_systems["vecs"][:, : self.N]
        if vec.qtype == Qtypes.ket:
            return get_vec_in_new_basis(vec, evecs, [[self.N], [1]])
        # A bra is a row vector: left-multiplying by V^† (as the ket rule
        # does) is a shape mismatch, so right-multiply by V instead.
        return Qarray.create(vec.data @ evecs, dims=[[1], [self.N]])

    def get_vec_data_in_H_eigenbasis(self, vec: Array):
        """Raw-array ket counterpart of ``get_vec_in_H_eigenbasis``."""
        evecs = self.eig_systems["vecs"][:, : self.N]
        return get_vec_data_in_new_basis(vec, evecs)

    def full_ops(self):
        """Return every op in ``linear_ops`` transformed into the H eigenbasis."""
        # TODO: use JAX vmap here
        # NOTE: each get_op_in_H_eigenbasis call re-diagonalizes H through
        # eig_systems.
        return {
            name: self.get_op_in_H_eigenbasis(op)
            for name, op in self.linear_ops.items()
        }
def get_op_in_new_basis(op: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Change the basis of operator ``op`` to the columns of ``evecs``.

    Args:
        op: operator whose ``.data`` is transformed as ``evecs^† @ op @ evecs``.
        evecs: matrix whose columns are the new basis vectors.
        dims: Qarray dims of the resulting operator.

    Returns:
        A new Qarray holding the transformed data with the given ``dims``.
    """
    return Qarray.create(get_op_data_in_new_basis(op.data, evecs), dims=dims)
def get_op_data_in_new_basis(op_data: Array, evecs: Array) -> Array:
    """Return ``evecs^† @ op_data @ evecs``.

    Args:
        op_data: square operator matrix, optionally with leading batch
            dimensions ``(..., n, n)``.
        evecs: ``(n, m)`` matrix whose columns are the new basis vectors
            (``m < n`` truncates to the first ``m`` of them).

    Returns:
        The transformed operator, shape ``(..., m, m)``.
    """
    # matmul (@) broadcasts over leading batch dims. jnp.dot would instead
    # contract against the wrong axis for >2-D inputs. For 2-D inputs the two
    # are identical.
    evecs_dag = jnp.conjugate(evecs).T
    return evecs_dag @ (op_data @ evecs)
def get_vec_in_new_basis(vec: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Project state ``vec`` onto the columns of ``evecs``.

    Args:
        vec: state whose ``.data`` is transformed as ``evecs^† @ data``.
        evecs: matrix whose columns are the new basis vectors.
        dims: Qarray dims of the resulting state.

    Returns:
        A new Qarray holding the projected data with the given ``dims``.
    """
    projected = get_vec_data_in_new_basis(vec.data, evecs)
    return Qarray.create(projected, dims=dims)
def get_vec_data_in_new_basis(vec_data: Array, evecs: Array) -> Array:
    """Return ``evecs^† @ vec_data`` (column-vector / ket transform).

    Args:
        vec_data: vector of shape ``(n,)`` or ``(n, k)``, optionally with
            leading batch dimensions ``(..., n, k)``.
        evecs: ``(n, m)`` matrix whose columns are the new basis vectors.

    Returns:
        The projected vector, shape ``(m,)`` / ``(..., m, k)``.
    """
    # matmul (@) broadcasts over leading batch dims. jnp.dot would place the
    # batch axis in the middle for >2-D inputs. For 1-D and 2-D inputs the
    # two are identical.
    return jnp.conjugate(evecs).T @ vec_data