Coverage for jaxquantum/codes/cat.py: 50%
22 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
"""
Cat Code Qubit.

Bosonic qubit whose logical states are superpositions of coherent states.
"""
5from typing import Tuple
7from jaxquantum.codes.base import BosonicQubit
8import jaxquantum as jqt
10from jax import config
# Turn on JAX double-precision (64-bit) mode. This runs at import time, so
# importing this module changes JAX's global configuration for the caller.
config.update("jax_enable_x64", True)
class CatQubit(BosonicQubit):
    """
    Cat Qubit Class.

    Bosonic qubit whose logical Z-basis states are normalized sums of
    coherent states with amplitude ``alpha`` (read from
    ``self.params["alpha"]``, default 2).
    """

    name = "cat"

    @property
    def _non_device_params(self) -> list:
        """
        Parameter names reported as non-device parameters.

        Returns the parent class's list extended with ``"alpha"``.
        """
        # Build a new list instead of appending to the parent's result, so a
        # list the parent might cache or share is never mutated in place
        # (which would add another "alpha" on every property access).
        return [*super()._non_device_params, "alpha"]

    def _params_validation(self) -> None:
        """
        Validate parameters, filling in ``alpha = 2`` when it is absent.

        Runs the parent class's validation first.
        """
        super()._params_validation()
        if "alpha" not in self.params:
            self.params["alpha"] = 2

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z-basis states |+z> and |-z>.

        |+z> = unit(|alpha> + |-alpha>)
        |-z> = unit(|i*alpha> + |-i*alpha>)

        Returns:
            Tuple ``(plus_z, minus_z)`` of normalized states, each built from
            ``jqt.coherent(N, ...)`` with ``N = self.params["N"]``
            (presumably the Fock-space truncation — confirm in the base class).
        """
        # NOTE(review): for finite alpha these two states are not exactly
        # orthogonal; confirm downstream code tolerates the small overlap.
        N = self.params["N"]
        a = self.params["alpha"]
        plus_z = jqt.unit(jqt.coherent(N, a) + jqt.coherent(N, -1.0 * a))
        minus_z = jqt.unit(
            jqt.coherent(N, 1.0j * a) + jqt.coherent(N, -1.0j * a)
        )
        return plus_z, minus_z