Coverage for jaxquantum/codes/binomial.py: 33%
36 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
"""
Binomial Code Qubit
"""
5from typing import Tuple
8from jaxquantum.utils.utils import comb
9from jaxquantum.codes.base import BosonicQubit
10import jaxquantum as jqt
12from jax import vmap
13from jax import config
14import jax.numpy as jnp
# Enable 64-bit (double-precision) dtypes in JAX. This is a process-wide JAX
# setting: it takes effect for all JAX code once this module is imported.
config.update("jax_enable_x64", True)
class BinomialQubit(BosonicQubit):
    """
    Binomial code qubit.

    Parameters (stored in ``self.params``) follow the L / G / D notation of
    https://arxiv.org/pdf/2010.08699.pdf; see that reference for their exact
    meaning. ``N`` is the Fock-space dimension passed to ``jqt.basis``.
    """

    name = "binomial"

    def _params_validation(self):
        """
        Validate parameters and fill in binomial-code defaults.

        Runs the base-class validation first, then sets L = 1, G = 0, D = 0
        for any of those keys not already present in ``self.params``.
        Existing values are left untouched.
        """
        super()._params_validation()

        # notation https://arxiv.org/pdf/2010.08699.pdf
        for key, default in (("L", 1), ("G", 0), ("D", 0)):
            if key not in self.params:
                self.params[key] = default

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z-basis codewords |+z> and |-z>.

        With S = L + G and M = max(L, G, 2 D), each codeword is the normalized
        superposition

            sum_p sqrt(C(M + 1, p)) |p (S + 1)>,

        where p runs over the even integers in [0, M + 1] for |+z> and over
        the odd integers in [1, M + 1] for |-z>.

        Returns:
            (plus_z, minus_z): the two codewords as normalized Qarrays.

        Raises:
            ValueError: if the Fock-space dimension N cannot hold the highest
                occupied Fock state (M + 1)(S + 1).
        """
        N = self.params["N"]
        L = self.params["L"]
        G = self.params["G"]
        D = self.params["D"]

        S = L + G  # occupied Fock states are spaced S + 1 apart
        M = max(L, G, 2 * D)

        # The largest p is M + 1, so the top occupied Fock index is
        # (M + 1)(S + 1). jqt.basis is called with a traced index under vmap
        # below, where an out-of-range index may not raise. Fail loudly here
        # instead of risking a silently truncated codeword.
        max_fock = (M + 1) * (S + 1)
        if max_fock >= N:
            raise ValueError(
                f"Fock-space dimension N={N} is too small for the binomial "
                f"code with L={L}, G={G}, D={D}: need N > {max_fock}."
            )

        def weighted_fock(p):
            # sqrt of the binomial weight times the Fock ket |p (S + 1)>.
            return jnp.sqrt(comb(M + 1, p)) * jqt.basis(N, p * (S + 1)).data

        def codeword(parity):
            # Sum over p = parity, parity + 2, ..., <= M + 1, then normalize.
            ps = jnp.arange(parity, M + 2, 2)
            ket = jnp.sum(vmap(weighted_fock)(ps), axis=0)
            return jqt.unit(jqt.Qarray.create(ket))

        plus_z = codeword(0)
        minus_z = codeword(1)
        return plus_z, minus_z