Coverage for jaxquantum / codes / base.py: 85%
130 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
1"""
2Base Bosonic Qubit Class
3"""
5from typing import Dict, Optional, Tuple
6from abc import abstractmethod, ABCMeta
8from jaxquantum.utils.utils import device_put_params
9import jaxquantum as jqt
11from jax import config
12import numpy as np
13import jax.numpy as jnp
14import matplotlib.pyplot as plt
# Module-level side effect: importing this module turns on JAX's
# ``jax_enable_x64`` flag (JAX's double-precision mode).
config.update("jax_enable_x64", True)
class BosonicQubit(metaclass=ABCMeta):
    """
    Base class for Bosonic Qubits.

    Subclasses must implement ``_get_basis_z`` (returning the logical |+z>
    and |-z> code states). The remaining logical basis states, the logical
    Pauli/Hadamard gates, the code-space projector and the plotting helpers
    are all derived from those two states.

    Subclasses may extend ``PARAMETERS`` with extra allowed parameter names
    and override ``_params_validation`` / ``_gen_common_gates`` /
    ``_prepare_state_plot`` to customize behavior.
    """

    # Parameter names accepted by every bosonic qubit. "N" is the dimension
    # passed to jqt.create / jqt.destroy in _gen_common_gates.
    BASE_PARAMETERS = ["N"]
    # Extra parameter names accepted by a subclass (extended by children).
    PARAMETERS = []

    name = "bqubit"

    @property
    def _non_device_params(self):
        """
        Parameter names handed to ``device_put_params`` as its second argument.

        Presumably these are the params kept off-device (e.g. the integer
        truncation ``N``). TODO: confirm against ``device_put_params``.
        Can be overridden in child classes.
        """
        return ["N"]

    def __init__(
        self, params: Optional[Dict[str, float]] = None, name: Optional[str] = None
    ):
        """
        Build the code: validate params, generate common gates and basis states.

        Args:
            params: Code parameters. The base validation rejects keys that
                are not in ``BASE_PARAMETERS + PARAMETERS`` and fills in
                ``N = 50`` when missing. The mapping is copied, so the
                caller's dict is never mutated.
            name: Optional instance name overriding the class-level ``name``.

        Raises:
            ValueError: from the base ``_params_validation`` on unknown keys.
            AssertionError: if the subclass does not provide all six logical
                basis states.
        """
        if name is not None:
            self.name = name

        # Copy so that defaults written by _params_validation (e.g. "N") do
        # not leak back into the dict the caller passed in.
        self.params = dict(params) if params else {}
        self._params_validation()

        self.params = device_put_params(self.params, self._non_device_params)

        self.common_gates: Dict[str, jqt.Qarray] = {}
        self._gen_common_gates()

        # Phase-space grid (real and imaginary axes) used for quasi-probability plots.
        self.wigner_pts = jnp.linspace(-4.5, 4.5, 61)

        self.basis = self._get_basis_states()

        # Guard against subclasses that override basis construction and
        # forget one of the logical states.
        for basis_state in ["+x", "-x", "+y", "-y", "+z", "-z"]:
            assert basis_state in self.basis, (
                f"Please set the {basis_state} basis state."
            )

    def _params_validation(self):
        """
        Validate ``self.params`` in place and fill in defaults.

        Raises ``ValueError`` for any key not in
        ``BASE_PARAMETERS + PARAMETERS`` and sets ``N`` to 50 when absent.

        Override this method to add additional validation to params, e.g.
            if "N" not in self.params:
                self.params["N"] = 50
        """
        allowed = self.BASE_PARAMETERS + self.PARAMETERS
        for key in self.params:
            if key not in allowed:
                raise ValueError(
                    f"Invalid parameter {key}. Allowed parameters are {allowed}"
                )

        if "N" not in self.params:
            self.params["N"] = 50

    def _gen_common_gates(self):
        """
        Populate ``self.common_gates`` with the ladder operators.

        Stores ``jqt.create(N)`` under "a_dag" and ``jqt.destroy(N)`` under
        "a", with ``N = self.params["N"]``.

        Override this method to add additional common gates, e.g.
            super()._gen_common_gates()
            self.common_gates["n"] = (
                self.common_gates["a_dag"] @ self.common_gates["a"]
            )
        """
        N = self.params["N"]
        self.common_gates["a_dag"] = jqt.create(N)
        self.common_gates["a"] = jqt.destroy(N)

    @abstractmethod
    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Returns:
            plus_z (jqt.Qarray), minus_z (jqt.Qarray): z basis states
        """

    def _get_basis_states(self) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z>.
        """
        plus_z, minus_z = self._get_basis_z()
        return self._gen_basis_states_from_z(plus_z, minus_z)

    def _gen_basis_states_from_z(
        self, plus_z: jqt.Qarray, minus_z: jqt.Qarray
    ) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z> from |+-z>.

        |+-x> = unit(|+z> +- |-z>) and |+-y> = unit(|+z> +- i|-z>).
        """
        basis: Dict[str, jqt.Qarray] = {}

        # Important: plus_z / minus_z must be column-vector kets. The outer
        # products built elsewhere in this class (h_U, projector, Pauli
        # gates) rely on ``.dag()`` turning a ket into a row vector.
        basis["+z"] = plus_z
        basis["-z"] = minus_z

        basis["+x"] = jqt.unit(basis["+z"] + basis["-z"])
        basis["-x"] = jqt.unit(basis["+z"] - basis["-z"])
        basis["+y"] = jqt.unit(basis["+z"] + 1j * basis["-z"])
        basis["-y"] = jqt.unit(basis["+z"] - 1j * basis["-z"])
        return basis

    def jqt2qt(self, state):
        """Convert a jaxquantum object via ``jqt.jqt2qt``."""
        return jqt.jqt2qt(state)

    # gates
    # ======================================================
    # TODO: consider an abstract ``stabilize`` method for measuring syndromes.

    @property
    def x_U(self) -> jqt.Qarray:
        """
        Logical X unitary gate.
        """
        return self._gen_pauli_U("x")

    @property
    def x_H(self) -> Optional[jqt.Qarray]:
        """
        Logical X hamiltonian (None unless overridden).
        """
        return None

    @property
    def y_U(self) -> jqt.Qarray:
        """
        Logical Y unitary gate.
        """
        return self._gen_pauli_U("y")

    @property
    def y_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Y hamiltonian (None unless overridden).
        """
        return None

    @property
    def z_U(self) -> jqt.Qarray:
        """
        Logical Z unitary gate.
        """
        return self._gen_pauli_U("z")

    @property
    def z_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Z hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Hadamard hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_U(self) -> jqt.Qarray:
        """
        Logical Hadamard unitary gate: |+x><+z| + |-x><-z|.
        """
        return (
            self.basis["+x"] @ self.basis["+z"].dag()
            + self.basis["-x"] @ self.basis["-z"].dag()
        )

    def _gen_pauli_U(self, basis_state: str) -> jqt.Qarray:
        """
        Generates unitary for Pauli X, Y, Z.

        If the matching ``<basis_state>_H`` property is not None, returns
        ``expm(1j * H)``. Otherwise returns |+b><+b| - |-b><-b| built from
        the logical basis states.

        Args:
            basis_state (str): "x", "y", "z"

        Returns:
            U (jqt.Qarray): Pauli unitary
        """
        H = getattr(self, basis_state + "_H")
        if H is not None:
            return jqt.expm(1.0j * H)

        plus = self.basis["+" + basis_state]
        minus = self.basis["-" + basis_state]
        return plus @ plus.dag() - minus @ minus.dag()

    @property
    def projector(self):
        """
        Sum of |+z><+z| and |-z><-z| (the code-space projector, assuming the
        two z states are orthonormal).
        """
        return (
            self.basis["+z"] @ self.basis["+z"].dag()
            + self.basis["-z"] @ self.basis["-z"].dag()
        )

    @property
    def maximally_mixed_state(self):
        """Half of ``projector``."""
        return (1 / 2.0) * self.projector

    # Plotting
    # ======================================================
    def _prepare_state_plot(self, state):
        """
        Hook applied to each code state before plotting in
        ``plot_code_states``. Returns the state unchanged by default.

        Can be overridden, e.g. in the case of a cavity x transmon system:
            return qt.ptrace(state, 0)
        """
        return state

    def plot(self, state, ax=None, qp_type=jqt.WIGNER, **kwargs) -> None:
        """
        Plot a quasi-probability distribution of ``state`` with a colorbar.

        Args:
            state: State passed directly to ``_plot_single``.
            ax: Axes to draw on. A new figure is created when None.
            qp_type: ``jqt.WIGNER`` or ``jqt.QFUNC``. Any other value is
                plotted with automatically chosen colorbar ticks instead of
                raising ``UnboundLocalError``.
            **kwargs: Forwarded to ``_plot_single`` (e.g. ``contour``).
        """
        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
        fig = ax.get_figure()

        # Fixed colorbar ticks over each distribution's nominal range.
        if qp_type == jqt.WIGNER:
            ticks = np.linspace(-1, 1, 5)
        elif qp_type == jqt.QFUNC:
            ticks = np.linspace(0, 1, 5)
        else:
            ticks = None  # let matplotlib choose

        # _plot_single returns (axes, mappable); only the mappable feeds the
        # colorbar (same unpacking as in plot_code_states).
        _, w_plt = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)

        ax.set_title(qp_type.capitalize() + " Quasi-Probability Dist.")
        fig.colorbar(w_plt, ax=ax, ticks=ticks)
        ax.set_xlabel(r"Re$(\alpha)$")
        ax.set_ylabel(r"Im$(\alpha)$")
        fig.tight_layout()

        plt.show()

    def _plot_single(self, state, ax=None, contour=True, qp_type=jqt.WIGNER):
        """
        Draw one quasi-probability plot of ``state`` on ``self.wigner_pts``.

        Assumes state has same dims as initial_state.

        Returns:
            Whatever ``jqt.plot_qp`` returns. Callers in this class unpack it
            as ``(axes, mappable)``.
        """
        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)

        return jqt.plot_qp(
            state, self.wigner_pts, axs=ax, contour=contour, qp_type=qp_type
        )

    def plot_code_states(self, qp_type: str = jqt.WIGNER, **kwargs):
        """
        Plot |±x⟩, |±y⟩, |±z⟩ code states.

        Each state goes through ``_prepare_state_plot`` first.

        Args:
            qp_type (str):
                WIGNER or QFUNC
            **kwargs: Forwarded to ``_plot_single`` (e.g. ``contour``).

        Return:
            axs: 2x3 array of Axes (top row "+" states, bottom row "-").
        """
        fig, axs = plt.subplots(2, 3, figsize=(12, 6), dpi=200)

        labels = ["+z", "+x", "+y", "-z", "-x", "-y"]
        for i, label in enumerate(labels):
            ax = axs[i // 3, i % 3]
            state = self._prepare_state_plot(self.basis[label])
            self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)
            ax.set_title(f"|{label}" + r"$\rangle$")
            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")

        fig.suptitle(self.name)
        fig.align_xlabels(axs)
        fig.align_ylabels(axs)
        fig.tight_layout()
        plt.show()
        return axs