Coverage for jaxquantum / codes / gkp.py: 99%
115 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
1"""
GKP (Gottesman-Kitaev-Preskill) Code Qubits
3"""
5from typing import Tuple
6import warnings
8from jaxquantum.codes.base import BosonicQubit
9import jaxquantum as jqt
11from jax import jit, lax, vmap, debug
13import jax.numpy as jnp
class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) Qubit Class.

    The finite-energy code is parametrized by ``delta``:
        - ideal operators are conjugated by the envelope
          E = exp(-delta^2 a_dag a);
        - basis-state Fock amplitudes are weighted by exp(-delta^2 n).

    Subclasses change the lattice geometry by overriding ``_get_axis``.
    """

    PARAMETERS = ["delta"]

    name = "gkp"

    def _params_validation(self):
        """
        Validate parameters and populate the derived ones.

        Sets ``delta`` to 0.25 when it is absent, then derives:
            l: 2 * sqrt(pi), the length used in the stabilizers below.
            epsilon: sinh(delta^2) * l.
            squeezing: log(delta).
            squeezing_dB: 20 * log10(exp(|log(delta)|)).
        """
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25

        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]
        self.params["squeezing"] = jnp.log(self.params["delta"])
        self.params["squeezing_dB"] = 20 * jnp.log10(
            jnp.exp(jnp.abs(self.params["squeezing"]))
        )

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        Adds, on top of the base-class gates:
            x, p: quadratures built from ``a`` and ``a_dag``.
            E, E_inv: envelope exp(-delta^2 a_dag a) and its inverse.
            X_0, Y_0, Z_0: ideal (infinite-energy) logical operators.
            X, Y, Z: the same operators conjugated as E @ op @ E_inv.
            Z_s_0, S_x_0, S_z_0, S_y_0: symmetrized versions
                (expm(op) + expm(-op)) / 2 of the logical Z and stabilizers.
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis (lattice geometry; overridden by subclasses)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: ideal logical operators are exponentials of the lattice axes
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0  # mirrors the Pauli identity Y = iXZ
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th oscillator eigenfunction at ``q_points``.

        The value is pi^(-1/4) * exp(-q^2 / 2) * F_n(q), with F_n given by
            F_0 = 1,  F_1 = sqrt(2) q,
            F_k = sqrt(2 / k) q F_{k-1} - sqrt((k - 1) / k) F_{k-2}.

        Why the rescaling:
            F_n grows like (sqrt(2) |q|)^n / sqrt(n!), while the Gaussian
            factor decays like exp(-q^2 / 2). Evaluated separately, they
            overflow / underflow for large n and |q| (around n ~ 170 at
            q ~ 350 in float64, much earlier in float32), and inf * 0 gives
            NaN.

        How it works:
            - Whenever the recurrence state exceeds sqrt(finfo.max), both
              retained terms are divided by their peak magnitude.
            - The log of that divisor is accumulated.
            - The accumulated log is added to the Gaussian exponent before
              exponentiating.

        When no rescaling is triggered, the divisor is exactly 1.0 and the
        computation reduces to the unscaled formula.

        Args:
            q_points: array of evaluation points.
            n: non-negative integer order; may be a traced value
                (e.g. under vmap).

        Returns:
            Array shaped like ``q_points.T``.
        """
        q_points = q_points.T

        F_0_init = jnp.ones_like(q_points)
        F_1_init = jnp.sqrt(2) * q_points
        log_scale_init = jnp.zeros_like(F_1_init)

        # Rescale only far from overflow, so the common case is unchanged
        # and one further recurrence step cannot overflow.
        threshold = jnp.sqrt(jnp.finfo(F_1_init.dtype).max)

        def scan_body(k, carry):
            F_prev, F_curr, log_scale = carry
            F_next = jnp.sqrt(2 / k) * lax.mul(q_points, F_curr) - jnp.sqrt(
                (k - 1) / k
            ) * F_prev
            peak = jnp.maximum(jnp.abs(F_curr), jnp.abs(F_next))
            scale = jnp.where(peak > threshold, peak, 1.0)
            return F_curr / scale, F_next / scale, log_scale + jnp.log(scale)

        # Runs k = 2..n. For n < 2 it is a no-op, so the carry still holds
        # (F_0, F_1) with zero log-scale; F_1 is therefore already correct
        # for n == 1, and only n == 0 needs a separate selection.
        _, F_n, log_scale = lax.fori_loop(
            2,
            jnp.maximum(n + 1, 2),
            scan_body,
            (F_0_init, F_1_init, log_scale_init),
        )

        q_quad = lax.select(n == 0, F_0_init, F_n)

        return jnp.pi ** (-0.25) * lax.mul(
            jnp.exp(log_scale - lax.pow(q_points, 2) / 2), q_quad
        )

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Compute a finite-energy GKP Z-basis state in the Fock basis.

        The amplitude of even Fock state n is
            exp(-delta^2 n) * sum_k psi_n(sqrt(pi) * (2k + mu)),
        where psi_n is evaluated by ``_q_quadrature``.

        Evaluation details:
            - Only the non-negative lattice points k = 0..series_trunc-1 are
              evaluated. That sum is doubled and, for mu = 0, the q = 0 term
              is subtracted once so it is not counted twice.
            - Odd Fock amplitudes are left at zero.

        Args:
            delta: envelope parameter.
            dim: Fock-space truncation; must be a static (Python) int, since
                it sets array sizes.
            mu: state index (0 or 1).
            series_trunc: number of non-negative lattice points kept.

        Returns:
            GKP basis state as a ``jqt.Qarray`` of length ``dim``,
            normalized with ``Qarray.unit()``.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """

        # We choose the truncation of our series summation such that we
        # capture 6 sigmas of the envelope for a value of delta of 0.02:
        # delta * (series_trunc * 2 * sqrt(pi)) = 6  =>  series_trunc ~ 85,
        # rounded up to 100.
        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0]
            )

        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        psi = jnp.zeros(2 * psi_even.size, dtype=psi_even.dtype)
        psi = psi.at[::2].set(psi_even)

        psi = jqt.Qarray.create(psi[:dim])

        return psi.unit()

    @staticmethod
    def _check_delta_warning(d):
        """Warn when ``d`` (delta) is below 0.02; used via ``debug.callback``."""
        if d < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        ``plus_z`` uses mu = 0 and ``minus_z`` uses mu = 1 in
        ``_compute_gkp_basis_z``.
        """

        delta = self.params["delta"]
        dim = self.params["N"]

        # debug.callback so the warning also fires when delta is traced.
        debug.callback(GKPQubit._check_delta_warning, delta)

        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return (x_axis, z_axis) for the square lattice: (x, -p)."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: E @ op @ E_inv."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return (expm(op) + expm(-op)) / 2."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z."""
        return self.common_gates["Z"]
class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    The aspect-ratio parameter ``a`` (default 0.8 when not supplied) scales
    the x axis by ``a`` and the z axis (built from -p) by ``1 / a``.
    """

    PARAMETERS = ["delta", "a"]

    def _params_validation(self):
        super()._params_validation()
        # Only fill in the aspect ratio when the caller did not provide one.
        if "a" not in self.params:
            self.params["a"] = 0.8

    def _get_axis(self):
        aspect = self.params["a"]
        gates = self.common_gates
        x_axis = aspect * gates["x"]
        z_axis = -1 / aspect * gates["p"]
        return x_axis, z_axis
class SquareGKPQubit(GKPQubit):
    """
    GKP qubit on a square lattice.

    Records the aspect ratio ``a`` as 1.0, overwriting any supplied value.
    The lattice axes are the ones inherited from ``GKPQubit``.
    """

    def _params_validation(self):
        super()._params_validation()
        # A square lattice is the rectangular one with unit aspect ratio.
        self.params["a"] = 1.0
class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.
    """

    def _get_axis(self):
        # Normalization: scale^2 * sin(pi/3) = 1, so with [x, p] = i the
        # commutator [x_axis, z_axis] equals -i, as for the square code.
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        position = self.common_gates["x"]
        momentum = self.common_gates["p"]
        angle = jnp.pi / 3.0
        x_axis = scale * (
            jnp.sin(angle) * position
            + jnp.cos(angle) * momentum
        )
        z_axis = scale * (-momentum)
        return x_axis, z_axis
## Citations

# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States
# Baptiste Royer, Shraddha Singh, and S. M. Girvin
# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020

# Quantum error correction of a qubit encoded in grid states of an oscillator.
# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al.
# Nature 584, 368–372 (2020).