Coverage for jaxquantum/codes/base.py: 51%
139 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""
2Base Bosonic Qubit Class
3"""
5from typing import Dict, Optional, Tuple
6from abc import abstractmethod, ABCMeta
8from jaxquantum.utils.utils import device_put_params
9import jaxquantum as jqt
11from jax import config
12import numpy as np
13import jax.numpy as jnp
14import matplotlib.pyplot as plt
# Turn on 64-bit (double-precision) floats/complex numbers in JAX. This is a
# global JAX config change, applied as a side effect of importing this module.
config.update("jax_enable_x64", True)
class BosonicQubit(metaclass=ABCMeta):
    """
    Base class for Bosonic Qubits.

    Subclasses must implement ``_get_basis_z`` (returning the logical |+z>
    and |-z> states). The remaining logical basis states (|±x>, |±y>), the
    logical Pauli gates and the logical Hadamard gate are all derived from
    those two states. Subclasses may additionally override
    ``_params_validation``, ``_gen_common_gates``, ``_prepare_state_plot``,
    ``_non_device_params`` and the ``*_H`` Hamiltonian properties.
    """

    name = "bqubit"

    @property
    def _non_device_params(self):
        """
        Keys of ``self.params`` passed to ``device_put_params`` as its second
        argument (presumably params that should not be placed on a JAX
        device -- confirm against ``device_put_params``).

        Can be overridden in child classes.
        """
        return ["N"]

    def __init__(
        self, params: Optional[Dict[str, float]] = None, name: Optional[str] = None
    ):
        """
        Args:
            params: code parameters. The dict is copied, so defaults filled in
                by ``_params_validation`` (e.g. "N" = 50) do not leak back
                into the caller's dict.
            name: optional instance name overriding the class-level ``name``.
        """
        if name is not None:
            self.name = name

        # Copy so that ``_params_validation`` never mutates the caller's dict.
        self.params = dict(params) if params else {}
        self._params_validation()

        self.params = device_put_params(self.params, self._non_device_params)

        self.common_gates: Dict[str, jqt.Qarray] = {}
        self._gen_common_gates()

        # Grid of points handed to ``jqt.plot_qp`` by the plotting helpers.
        self.wigner_pts = jnp.linspace(-4.5, 4.5, 61)

        self.basis = self._get_basis_states()

        # NOTE: ``assert`` is stripped under ``python -O``; kept as-is so the
        # raised exception type (AssertionError) is unchanged for callers.
        for basis_state in ["+x", "-x", "+y", "-y", "+z", "-z"]:
            assert basis_state in self.basis, (
                f"Please set the {basis_state} basis state."
            )

    def _params_validation(self):
        """
        Validate ``self.params`` and fill in default values.

        The base implementation defaults "N" (the dimension passed to
        ``jqt.create`` / ``jqt.destroy`` in ``_gen_common_gates``) to 50.
        Override this method to add additional validation to params; call
        ``super()._params_validation()`` to keep the "N" default.

        E.g.
            if "N" not in self.params:
                self.params["N"] = 50
        """
        if "N" not in self.params:
            self.params["N"] = 50

    def _gen_common_gates(self):
        """
        Populate ``self.common_gates`` with frequently used operators.

        The base implementation adds the creation ("a_dag") and annihilation
        ("a") operators of dimension ``self.params["N"]``. Override this
        method to add additional common gates, e.g.

            super()._gen_common_gates()
            self.common_gates["n"] = (
                self.common_gates["a_dag"] @ self.common_gates["a"]
            )
        """
        N = self.params["N"]
        self.common_gates["a_dag"] = jqt.create(N)
        self.common_gates["a"] = jqt.destroy(N)

    @abstractmethod
    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical z-basis states of the code.

        Returns:
            plus_z (jqt.Qarray), minus_z (jqt.Qarray): z basis states
        """

    def _get_basis_states(self) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z>.

        Returns:
            Dict keyed by "+x", "-x", "+y", "-y", "+z", "-z".
        """
        plus_z, minus_z = self._get_basis_z()
        return self._gen_basis_states_from_z(plus_z, minus_z)

    def _gen_basis_states_from_z(
        self, plus_z: jqt.Qarray, minus_z: jqt.Qarray
    ) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z> from |+-z>.

        |±x> = unit(|+z> ± |-z>),  |±y> = unit(|+z> ± i|-z>).
        """
        basis: Dict[str, jqt.Qarray] = {}

        # Important: plus_z / minus_z should be kets (column vectors).
        # Transposing a 1D vector does nothing, so ``.dag()`` would not
        # produce a bra and the projectors built from these states would be
        # wrong.
        basis["+z"] = plus_z
        basis["-z"] = minus_z

        basis["+x"] = jqt.unit(basis["+z"] + basis["-z"])
        basis["-x"] = jqt.unit(basis["+z"] - basis["-z"])
        basis["+y"] = jqt.unit(basis["+z"] + 1j * basis["-z"])
        basis["-y"] = jqt.unit(basis["+z"] - 1j * basis["-z"])
        return basis

    def jqt2qt(self, state):
        """Convert ``state`` with ``jqt.jqt2qt`` (presumably to a QuTiP object)."""
        return jqt.jqt2qt(state)

    # gates
    # ======================================================
    # @abstractmethod
    # def stabilize(self) -> None:
    #     """
    #     Stabilizing/measuring syndromes.
    #     """

    @property
    def x_U(self) -> jqt.Qarray:
        """
        Logical X unitary gate.
        """
        return self._gen_pauli_U("x")

    @property
    def x_H(self) -> Optional[jqt.Qarray]:
        """
        Logical X hamiltonian (None unless overridden).
        """
        return None

    @property
    def y_U(self) -> jqt.Qarray:
        """
        Logical Y unitary gate.
        """
        return self._gen_pauli_U("y")

    @property
    def y_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Y hamiltonian (None unless overridden).
        """
        return None

    @property
    def z_U(self) -> jqt.Qarray:
        """
        Logical Z unitary gate.
        """
        return self._gen_pauli_U("z")

    @property
    def z_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Z hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Hadamard hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_U(self) -> jqt.Qarray:
        """
        Logical Hadamard unitary gate: |+x><+z| + |-x><-z|.
        """
        return (
            self.basis["+x"] @ self.basis["+z"].dag()
            + self.basis["-x"] @ self.basis["-z"].dag()
        )

    def _gen_pauli_U(self, basis_state: str) -> jqt.Qarray:
        """
        Generates unitary for Pauli X, Y, Z.

        If the matching ``<basis_state>_H`` property is not None, the gate is
        ``expm(1j * H)``. Otherwise it is built from the code states as
        |+b><+b| - |-b><-b|.

        Args:
            basis_state (str): "x", "y", "z"

        Returns:
            U (jqt.Qarray): Pauli unitary
        """
        H = getattr(self, basis_state + "_H")
        if H is not None:
            return jqt.expm(1.0j * H)

        gate = (
            self.basis["+" + basis_state] @ self.basis["+" + basis_state].dag()
            - self.basis["-" + basis_state] @ self.basis["-" + basis_state].dag()
        )

        return gate

    @property
    def projector(self):
        """
        |+z><+z| + |-z><-z|: the projector onto the code space, assuming
        |±z> are orthonormal.
        """
        return (
            self.basis["+z"] @ self.basis["+z"].dag()
            + self.basis["-z"] @ self.basis["-z"].dag()
        )

    @property
    def maximally_mixed_state(self):
        """
        Maximally mixed logical state, projector / 2.
        """
        # ``projector`` is a property: access it, don't call it.
        return (1 / 2.0) * self.projector

    # Plotting
    # ======================================================
    def _prepare_state_plot(self, state):
        """
        Map a code state to the state that should be plotted.

        Can be overridden. The base implementation returns ``state`` unchanged.

        E.g. in the case of cavity x transmon system
            return qt.ptrace(state, 0)
        """
        return state

    def plot(self, state, ax=None, qp_type=jqt.WIGNER, **kwargs) -> None:
        """
        Plot the quasi-probability distribution of ``state`` with a colorbar.

        Args:
            state: state to plot (passed as-is to ``_plot_single``).
            ax: matplotlib axes to draw on; a new figure is created if None.
            qp_type: ``jqt.WIGNER`` or ``jqt.QFUNC``.
            **kwargs: forwarded to ``_plot_single`` (e.g. ``contour``).

        Raises:
            ValueError: if ``qp_type`` is neither ``jqt.WIGNER`` nor
                ``jqt.QFUNC``.
        """
        # Validate before creating any figure so a bad qp_type leaves no
        # stray empty figure behind.
        if qp_type == jqt.WIGNER:
            vmin, vmax = -1, 1
        elif qp_type == jqt.QFUNC:
            vmin, vmax = 0, 1
        else:
            raise ValueError(
                f"Unsupported qp_type {qp_type!r}; expected jqt.WIGNER or jqt.QFUNC."
            )

        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
        fig = ax.get_figure()

        # ``_plot_single`` returns ``jqt.plot_qp``'s result, a 2-tuple (see
        # ``plot_code_states``). Only its second element is the object
        # ``fig.colorbar`` can use; passing the whole tuple fails.
        _, w_plt = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)

        ax.set_title(qp_type.capitalize() + " Quasi-Probability Dist.")
        ticks = np.linspace(vmin, vmax, 5)
        fig.colorbar(w_plt, ax=ax, ticks=ticks)
        ax.set_xlabel(r"Re$(\alpha)$")
        ax.set_ylabel(r"Im$(\alpha)$")
        fig.tight_layout()

        plt.show()

    def _plot_single(self, state, ax=None, contour=True, qp_type=jqt.WIGNER):
        """
        Plot one state's quasi-probability distribution on ``ax``.

        Assumes state has same dims as initial_state.

        Returns:
            Whatever ``jqt.plot_qp`` returns (unpacked by callers as a
            2-tuple whose second element is passed to ``fig.colorbar``).
        """
        state = self.jqt2qt(state)

        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)

        return jqt.plot_qp(
            state, self.wigner_pts, ax=ax, contour=contour, qp_type=qp_type
        )

    def plot_code_states(self, qp_type: str = jqt.WIGNER, **kwargs):
        """
        Plot |±x⟩, |±y⟩, |±z⟩ code states on a 2x3 grid with a shared colorbar.

        Args:
            qp_type (str):
                WIGNER or QFUNC
            **kwargs: forwarded to ``_plot_single`` (e.g. ``contour``).

        Returns:
            None. The figure is displayed with ``plt.show()``.

        Raises:
            ValueError: if ``qp_type`` is neither ``jqt.WIGNER`` nor
                ``jqt.QFUNC``.
        """
        if qp_type == jqt.WIGNER:
            cbar_title = r"$\frac{\pi}{2} W(\alpha)$"
            vmin, vmax = -1, 1
        elif qp_type == jqt.QFUNC:
            cbar_title = r"$\pi Q(\alpha)$"
            vmin, vmax = 0, 1
        else:
            raise ValueError(
                f"Unsupported qp_type {qp_type!r}; expected jqt.WIGNER or jqt.QFUNC."
            )

        fig, axs = plt.subplots(2, 3, figsize=(12, 6), dpi=200)

        for i, label in enumerate(["+z", "+x", "+y", "-z", "-x", "-y"]):
            state = self._prepare_state_plot(self.basis[label])
            ax = axs[i // 3, i % 3]
            _, w_plt = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)
            ax.set_title(f"|{label}" + r"$\rangle$")
            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")

        fig.suptitle(self.name)
        fig.tight_layout()
        # Leave room on the right for the shared colorbar axes.
        fig.subplots_adjust(right=0.8, hspace=0.2, wspace=0.2)
        fig.align_xlabels(axs)
        fig.align_ylabels(axs)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])

        ticks = np.linspace(vmin, vmax, 5)
        fig.colorbar(w_plt, cax=cbar_ax, ticks=ticks)

        cbar_ax.set_title(cbar_title, pad=20)
        # NOTE(review): tight_layout after a manually added axes may warn and
        # re-expand the subplots over the colorbar region -- confirm visually.
        fig.tight_layout()
        plt.show()