Coverage for jaxquantum / codes / cat.py: 100%
23 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
"""
Cat code qubit: a bosonic qubit whose logical states are built from
superpositions of coherent states.
"""
5from typing import Tuple
7from jaxquantum.codes.base import BosonicQubit
8import jaxquantum as jqt
10from jax import config
# Enable 64-bit (double-precision) float/complex types in JAX. This executes at
# import time and changes JAX's global configuration for the whole process.
config.update("jax_enable_x64", True)
class CatQubit(BosonicQubit):
    """
    Cat Qubit Class.

    Bosonic qubit whose logical Z basis states are normalized superpositions
    of coherent states with amplitude ``alpha``:

        |+z> ~ |alpha> + |-alpha>
        |-z> ~ |i*alpha> + |-i*alpha>

    Parameters read from ``self.params``:
        alpha: coherent-state amplitude; defaults to 2 when not supplied.
        N: dimension passed to ``jqt.coherent`` (presumably the Fock-space
            truncation; it is not set here, so it must be provided/validated
            elsewhere, e.g. by the parent class -- confirm).
    """

    # Parameter names specific to this code.
    PARAMETERS = ["alpha"]

    # Short identifier for this code.
    name = "cat"

    @property
    def _non_device_params(self) -> list:
        """
        Return the parent's non-device parameter names plus ``"alpha"``.

        The parent's sequence is copied before extending it so that a list
        owned/shared by the parent is never mutated in place (which would
        otherwise accumulate duplicate ``"alpha"`` entries on repeated access).
        """
        param_list = list(super()._non_device_params)
        param_list.append("alpha")
        return param_list

    def _params_validation(self) -> None:
        """
        Run the parent's validation, then default ``alpha`` to 2 if missing.
        """
        super()._params_validation()
        if "alpha" not in self.params:
            self.params["alpha"] = 2

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z basis states |+z> and |-z>.

        |+z> is the normalized cat |alpha> + |-alpha>; |-z> is the same cat
        rotated by 90 degrees in phase space, |i*alpha> + |-i*alpha>. Both
        states contain only even photon numbers, and they are not exactly
        orthogonal: in the untruncated limit their overlap is of order
        exp(-|alpha|^2), so it shrinks as |alpha| grows.

        Returns:
            Tuple ``(plus_z, minus_z)`` of normalized kets.
        """
        N = self.params["N"]
        alpha = self.params["alpha"]
        plus_z = jqt.unit(
            jqt.coherent(N, alpha) + jqt.coherent(N, -1.0 * alpha)
        )
        minus_z = jqt.unit(
            jqt.coherent(N, 1.0j * alpha) + jqt.coherent(N, -1.0j * alpha)
        )
        return plus_z, minus_z