Coverage for jaxquantum/devices/base/base.py: 0%
142 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Base device."""
3from abc import abstractmethod, ABC
4from enum import Enum
5from typing import Dict, Any, List
7from flax import struct
8from jax import config, Array
9import jax.numpy as jnp
11from jaxquantum.core.qarray import Qarray
12from jaxquantum.core.dims import Qtypes
# Turn on JAX's 64-bit mode ("jax_enable_x64") at import time so arrays can
# use double precision. Note: this is a module-level side effect that changes
# JAX's process-wide configuration for every importer of this module.
config.update("jax_enable_x64", True)
class BasisTypes(str, Enum):
    """Basis in which a device's operators are set up.

    Members are ``str`` instances (str mixin) and compare equal both to other
    enum members with the same value and to their plain-string value, e.g.
    ``BasisTypes.fock == "fock"``. Hashing uses the value, so members and
    their string values are interchangeable as dict/set keys.
    """

    fock = "fock"
    charge = "charge"
    single_charge = "single_charge"

    @classmethod
    def from_str(cls, string: str):
        """Look up the member whose value is ``string``.

        Raises:
            ValueError: if no member has that value (standard Enum lookup).
        """
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Enum members (of any enum) compare by value, mirroring the original
        # behaviour; plain strings compare directly. Anything else defers to
        # the other operand instead of raising AttributeError on `.value`.
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal

    def __hash__(self):
        # Consistent with __eq__: equal to a plain str => same hash.
        return hash(self.value)
class HamiltonianTypes(str, Enum):
    """Kind of Hamiltonian a device uses.

    Members are ``str`` instances (str mixin) and compare equal both to other
    enum members with the same value and to their plain-string value, e.g.
    ``HamiltonianTypes.full == "full"``. Hashing uses the value, so members
    and their string values are interchangeable as dict/set keys.
    """

    linear = "linear"
    truncated = "truncated"
    full = "full"

    @classmethod
    def from_str(cls, string: str):
        """Look up the member whose value is ``string``.

        Raises:
            ValueError: if no member has that value (standard Enum lookup).
        """
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Enum members (of any enum) compare by value, mirroring the original
        # behaviour; plain strings compare directly. Anything else defers to
        # the other operand instead of raising AttributeError on `.value`.
        if isinstance(other, Enum):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal

    def __hash__(self):
        # Consistent with __eq__: equal to a plain str => same hash.
        return hash(self.value)
@struct.dataclass
class Device(ABC):
    """Abstract base class for a device described by a Hamiltonian.

    Subclasses implement ``common_ops``, ``get_linear_ω``, ``get_H_linear``
    and ``get_H_full``. On top of those, this class diagonalizes the
    Hamiltonian built in a space of dimension ``N_pre_diag`` and re-expresses
    operators and states in the span of the ``N`` lowest-energy eigenvectors.

    Build instances with :meth:`create`, which resolves defaults and runs
    :meth:`param_validation`.
    """

    # Class-level defaults. Deliberately left unannotated so the dataclass
    # machinery does not turn them into fields.
    DEFAULT_BASIS = BasisTypes.fock
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    # Fields marked pytree_node=False are static metadata for JAX; ``params``
    # is the only field treated as a pytree child.
    N: int = struct.field(pytree_node=False)
    N_pre_diag: int = struct.field(pytree_node=False)
    params: Dict[str, Any]
    _label: int = struct.field(pytree_node=False)
    _basis: BasisTypes = struct.field(pytree_node=False)
    _hamiltonian: HamiltonianTypes = struct.field(pytree_node=False)

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate constructor arguments. No-op hook; override in subclasses.

        Called by :meth:`create` after ``hamiltonian`` and ``basis`` defaults
        have been resolved.
        """
        pass

    @classmethod
    def create(
        cls,
        N,
        params,
        label=0,
        N_pre_diag=None,
        use_linear=False,
        hamiltonian: HamiltonianTypes = None,
        basis: BasisTypes = None,
    ):
        """Create a device.

        Args:
            N (int): dimension of Hilbert space.
            params (dict): parameters of the device.
            label (int, optional): label for the device. Defaults to 0. This is useful when you have multiple of the same device type in the same system.
            N_pre_diag (int, optional): dimension of Hilbert space before diagonalization. Defaults to None, in which case it is set to N. This must be greater than or equal to N.
            use_linear (bool): whether to use the linearized device. Defaults to False. This will override the hamiltonian keyword argument. This is a bit redundant with hamiltonian, but it is kept for backwards compatibility.
            hamiltonian (HamiltonianTypes, optional): type of Hamiltonian. Defaults to None, in which case cls.DEFAULT_HAMILTONIAN (full for the base class) is used.
            basis (BasisTypes, optional): type of basis. Defaults to None, in which case cls.DEFAULT_BASIS (fock for the base class) is used.

        Returns:
            An instance of ``cls``.

        Raises:
            AssertionError: if ``N_pre_diag < N`` (note: asserts are stripped
                when Python runs with ``-O``).
        """
        if N_pre_diag is None:
            N_pre_diag = N

        assert N_pre_diag >= N, "N_pre_diag must be greater than or equal to N."

        _basis = basis if basis is not None else cls.DEFAULT_BASIS
        _hamiltonian = (
            hamiltonian if hamiltonian is not None else cls.DEFAULT_HAMILTONIAN
        )

        # Kept for backwards compatibility; takes precedence over `hamiltonian`.
        if use_linear:
            _hamiltonian = HamiltonianTypes.linear

        cls.param_validation(N, N_pre_diag, params, _hamiltonian, _basis)

        return cls(N, N_pre_diag, params, label, _basis, _hamiltonian)

    @property
    def basis(self):
        """The basis type this device was created with."""
        return self._basis

    @property
    def hamiltonian(self):
        """The Hamiltonian type this device was created with."""
        return self._hamiltonian

    @property
    def label(self):
        """Display label: the class name followed by the integer label."""
        return self.__class__.__name__ + str(self._label)

    @property
    def linear_ops(self):
        """Operators in the original basis (currently ``common_ops()``)."""
        return self.common_ops()

    @property
    def original_ops(self):
        """Operators in the original basis (currently ``common_ops()``)."""
        return self.common_ops()

    @property
    def ops(self):
        """Operators expressed in the truncated Hamiltonian eigenbasis."""
        return self.full_ops()

    @abstractmethod
    def common_ops(self) -> Dict[str, Qarray]:
        """Set up common ops in the specified basis."""

    @abstractmethod
    def get_linear_ω(self):
        """Get frequency of linear terms."""

    @abstractmethod
    def get_H_linear(self):
        """Return linear terms in H."""

    @abstractmethod
    def get_H_full(self):
        """Return full H."""

    def get_H(self):
        """
        Return diagonalized H. Explicitly keep only diagonal elements of matrix.
        """
        return self.get_op_in_H_eigenbasis(
            self._get_H_in_original_basis()
        ).keep_only_diag_elements()

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original (pre-diagonalization) basis.

        Handles the ``linear`` and ``full`` Hamiltonian types. Subclasses that
        support other types should override this.

        Raises:
            NotImplementedError: for any other Hamiltonian type, instead of
                returning ``None`` and failing later with an opaque error.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        if self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        raise NotImplementedError(
            f"{type(self).__name__} does not support Hamiltonian type "
            f"'{self.hamiltonian}'; override _get_H_in_original_basis."
        )

    def _calculate_eig_systems(self):
        """Diagonalize the original-basis Hamiltonian.

        Returns:
            tuple: ``(eigenvalues, eigenvectors)`` sorted by ascending
            eigenvalue; eigenvectors are the columns of the second element.
        """
        # eigh assumes a Hermitian matrix. It already returns ascending
        # eigenvalues; the explicit sort is kept as a safeguard.
        evs, evecs = jnp.linalg.eigh(self._get_H_in_original_basis().data)
        idxs_sorted = jnp.argsort(evs)
        return evs[idxs_sorted], evecs[:, idxs_sorted]

    @property
    def eig_systems(self):
        """Eigendecomposition of the Hamiltonian in the original basis.

        Returns:
            dict: ``{"vals": eigenvalues, "vecs": eigenvectors}``, sorted by
            ascending eigenvalue, with eigenvectors stored as columns.

        Note: nothing is cached, so every access re-runs the decomposition.
        """
        vals, vecs = self._calculate_eig_systems()
        return {"vals": vals, "vecs": vecs}

    def _get_truncated_evecs(self):
        """Return the eigenvectors of the ``N`` lowest eigenvalues, as columns."""
        return self.eig_systems["vecs"][:, : self.N]

    def get_op_in_H_eigenbasis(self, op: Qarray):
        """Return ``op`` projected onto the truncated eigenbasis (N x N Qarray)."""
        evecs = self._get_truncated_evecs()
        dims = [[self.N], [self.N]]
        return get_op_in_new_basis(op, evecs, dims)

    def get_op_data_in_H_eigenbasis(self, op: Array):
        """Raw-array variant of :meth:`get_op_in_H_eigenbasis`."""
        evecs = self._get_truncated_evecs()
        return get_op_data_in_new_basis(op, evecs)

    def get_vec_in_H_eigenbasis(self, vec: Qarray):
        """Return the ket or bra ``vec`` expressed in the truncated eigenbasis.

        A ket transforms as ``evecs^† |psi>``. Any non-ket is treated as a bra,
        which transforms as ``<psi| evecs``.
        """
        evecs = self._get_truncated_evecs()
        if vec.qtype == Qtypes.ket:
            return get_vec_in_new_basis(vec, evecs, [[self.N], [1]])
        # A bra is the conjugate transpose of a ket, so it multiplies evecs from
        # the right. Left-multiplying by evecs^† (the ket rule) is a shape error
        # for a row vector. Assumes bra data is (..., 1, N_pre_diag) — confirm
        # against Qarray's layout.
        data = jnp.matmul(vec.data, evecs)
        return Qarray.create(data, dims=[[1], [self.N]])

    def get_vec_data_in_H_eigenbasis(self, vec: Array):
        """Raw-array variant for kets: returns ``evecs^† @ vec``."""
        evecs = self._get_truncated_evecs()
        return get_vec_data_in_new_basis(vec, evecs)

    def full_ops(self):
        """Return ``linear_ops`` re-expressed in the truncated eigenbasis.

        Note: each ``get_op_in_H_eigenbasis`` call reads ``eig_systems``, which
        recomputes the eigendecomposition.
        """
        # TODO: use JAX vmap here
        return {
            name: self.get_op_in_H_eigenbasis(op)
            for name, op in self.linear_ops.items()
        }
def get_op_in_new_basis(op: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Change the basis of operator ``op`` to the one spanned by ``evecs``.

    Args:
        op: operator to transform.
        evecs: new basis vectors, stored as columns.
        dims: ``dims`` to attach to the resulting Qarray.

    Returns:
        Qarray wrapping ``evecs^† @ op.data @ evecs``.
    """
    return Qarray.create(get_op_data_in_new_basis(op.data, evecs), dims=dims)
def get_op_data_in_new_basis(op_data: Array, evecs: Array) -> Array:
    """Return ``evecs^† @ op_data @ evecs`` (the change of basis of a matrix).

    Args:
        op_data: square operator matrix in the original basis.
        evecs: new basis vectors, stored as columns.
    """
    evecs_dagger = jnp.conjugate(evecs.transpose())
    op_times_evecs = jnp.dot(op_data, evecs)
    return jnp.dot(evecs_dagger, op_times_evecs)
def get_vec_in_new_basis(vec: Qarray, evecs: Array, dims: List[List[int]]) -> Qarray:
    """Change the basis of ket ``vec`` to the one spanned by ``evecs``.

    Args:
        vec: state to transform.
        evecs: new basis vectors, stored as columns.
        dims: ``dims`` to attach to the resulting Qarray.

    Returns:
        Qarray wrapping ``evecs^† @ vec.data``.
    """
    new_data = get_vec_data_in_new_basis(vec.data, evecs)
    return Qarray.create(new_data, dims=dims)
def get_vec_data_in_new_basis(vec_data: Array, evecs: Array) -> Array:
    """Return ``evecs^† @ vec_data``: a column (ket) vector in the new basis.

    Args:
        vec_data: vector data in the original basis.
        evecs: new basis vectors, stored as columns.
    """
    evecs_dagger = jnp.conjugate(evecs.transpose())
    return jnp.dot(evecs_dagger, vec_data)