Coverage for jaxquantum/codes/qubit.py: 50%
36 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""
2Qubit
3"""
5from typing import Tuple
6import warnings
8from jaxquantum.codes.base import BosonicQubit
9import jaxquantum as jqt
11from jax import config
12import matplotlib.pyplot as plt
13import qutip as qt
15config.update("jax_enable_x64", True)
class Qubit(BosonicQubit):
    """
    Two-level qubit code built on ``BosonicQubit``.

    Parameter validation always pins ``params["N"]`` to 2, and the
    ``x_U`` / ``y_U`` / ``z_U`` properties return the Pauli matrices
    provided by ``jaxquantum``.
    """

    def _params_validation(self):
        """Run the parent validation, then force ``params["N"]`` to 2."""
        super()._params_validation()
        # Overwrites whatever "N" held after the parent validation.
        self.params["N"] = 2

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Build the z-basis pair.

        Returns:
            ``(plus_z, minus_z)``, i.e. ``jqt.basis(N, 0)`` and
            ``jqt.basis(N, 1)`` with ``N = int(self.params["N"])``.
        """
        dim = int(self.params["N"])
        return jqt.basis(dim, 0), jqt.basis(dim, 1)

    @property
    def x_U(self) -> jqt.Qarray:
        """Return ``jqt.sigmax()``."""
        pauli_x = jqt.sigmax()
        return pauli_x

    @property
    def y_U(self) -> jqt.Qarray:
        """Return ``jqt.sigmay()``."""
        pauli_y = jqt.sigmay()
        return pauli_y

    @property
    def z_U(self) -> jqt.Qarray:
        """Return ``jqt.sigmaz()``."""
        pauli_z = jqt.sigmaz()
        return pauli_z

    def plot(self, state, ax=None, qp_type="", **kwargs) -> None:
        """
        Show ``state`` on a freshly created QuTiP Bloch sphere.

        Args:
            state: State to draw; it is passed through ``self.jqt2qt``
                before being handed to QuTiP.
            ax: Accepted but not used by this implementation.
            qp_type: Accepted but not used by this implementation.
            **kwargs: Accepted but not used by this implementation.
        """
        qt_state = self.jqt2qt(state)
        # NOTE(review): all calls below are kept inside the warnings
        # context; confirm the nesting against the original file.
        with warnings.catch_warnings():
            # TODO: suppressing deprecation warnings, deal with this
            # "ignore" with no category silences every warning raised here.
            warnings.simplefilter("ignore")
            sphere = qt.Bloch()
            sphere.add_states(qt_state)
            sphere.render()
            sphere.show()
            plt.tight_layout()
            plt.show()