Coverage for jaxquantum / codes / gkp.py: 99%
113 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 21:05 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-01-28 21:05 +0000
1"""
GKP (Gottesman-Kitaev-Preskill) Code Qubit
3"""
5from typing import Tuple
6import warnings
8from jaxquantum.codes.base import BosonicQubit
9import jaxquantum as jqt
11from jax import jit, lax, vmap, debug
13import jax.numpy as jnp
class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) qubit encoded in a single bosonic mode.

    ``self.params`` keys read by this class:
        delta: width parameter of the finite-energy envelope
            ``E = exp(-delta**2 * a_dag @ a)``; defaults to 0.25.
        N: number of Fock levels of the basis states built by
            ``_get_basis_z``.

    ``_params_validation`` additionally writes the derived keys ``l``
    (``2 * sqrt(pi)``), ``epsilon``, ``squeezing`` and ``squeezing_dB``.
    """

    name = "gkp"

    def _params_validation(self):
        """Fill in the default ``delta`` and the derived lattice parameters."""
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]
        self.params["squeezing"] = jnp.log(self.params["delta"])
        # 20 * log10(e^|ln delta|), i.e. -20 * log10(delta) when delta < 1.
        self.params["squeezing_dB"] = 20 * jnp.log10(
            jnp.exp(jnp.abs(self.params["squeezing"]))
        )

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        Adds quadratures (``x``, ``p``), the envelope operator ``E`` and its
        inverse ``E_inv``, ideal logical Paulis (``X_0``, ``Y_0``, ``Z_0``),
        their finite-energy versions (``X``, ``Y``, ``Z``) and symmetrized
        stabilizers/gates (``Z_s_0``, ``S_x_0``, ``S_z_0``, ``S_y_0``).
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis (subclasses override _get_axis to change the lattice shape)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: half-lattice displacements implement the logical Paulis
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th normalized Hermite function (position-space Fock
        wavefunction) at ``q_points``.

        Uses the three-term recurrence
            psi_k = sqrt(2/k) * q * psi_{k-1} - sqrt((k-1)/k) * psi_{k-2},
        seeded with psi_0 = pi**(-1/4) * exp(-q**2/2) and
        psi_1 = sqrt(2) * q * psi_0.

        The Gaussian factor is folded into the seeds instead of being applied
        after the loop. The bare polynomial part grows roughly like q**n and,
        for the large lattice points used by ``_compute_gkp_basis_z``
        (q up to ~350), overflows to ``inf`` at roughly n ~ 170 in float64
        (much earlier in float32). Then ``inf * exp(-q**2/2) = inf * 0``
        gives NaN and poisons the whole basis state. With the Gaussian in the
        seeds, the iterates are the wavefunction values themselves and stay
        of order one. Where ``exp(-q**2/2)`` underflows to 0, the result is 0.
        That is accurate as long as q is well outside the classically allowed
        region q**2 < 2n + 1.

        Args:
            q_points: positions at which to evaluate. They are transposed
                first, which is a no-op for 1-D input.
            n: Fock index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            Array shaped like ``q_points.T`` holding psi_n(q).
        """
        q_points = q_points.T

        psi_0 = jnp.pi ** (-0.25) * jnp.exp(-(q_points ** 2) / 2)
        psi_1 = jnp.sqrt(2.0) * q_points * psi_0

        def recurrence_step(k, carry):
            psi_km2, psi_km1 = carry
            psi_k = (
                jnp.sqrt(2.0 / k) * q_points * psi_km1
                - jnp.sqrt((k - 1.0) / k) * psi_km2
            )
            return psi_km1, psi_k

        # The upper bound is clamped to 2 so the loop is empty for n < 2;
        # those cases are resolved by the selection below.
        _, psi_n = lax.fori_loop(
            2, jnp.maximum(n + 1, 2), recurrence_step, (psi_0, psi_1)
        )

        return jnp.where(n == 0, psi_0, jnp.where(n == 1, psi_1, psi_n))

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Build a finite-energy GKP Z-basis state in the Fock basis.

        Args:
            delta: envelope parameter; Fock amplitude n is weighted by
                ``exp(-delta**2 * n)``.
            dim: number of Fock levels in the returned state (must be static
                under ``jit``).
            mu: state index (0 or 1); lattice points are
                ``sqrt(pi) * (2k + mu)``.
            series_trunc: number of non-negative lattice points summed.

        Returns:
            GKP basis state as a normalized ``jqt.Qarray``.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """

        # We choose the truncation of our series summation such that we
        # capture 6 sigmas of the envelope for a value of delta of 0.02:
        # delta * (series_trunc * 2 * sqrt(pi)) >= 6

        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            # Only non-negative lattice points are evaluated. Even-n Hermite
            # functions are even in q, so the negative half is accounted for
            # by doubling. For mu == 0, the point q = 0 would then be counted
            # twice, so one copy is removed.
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0]
            )

        # Only even Fock levels are populated; odd levels stay zero.
        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        psi = jnp.zeros(2 * psi_even.size, dtype=psi_even.dtype)
        psi = psi.at[::2].set(psi_even)

        # 2 * ceil(dim / 2) may exceed dim by one; trim back to dim levels.
        psi = jqt.Qarray.create(psi[:dim])

        return psi.unit()

    @staticmethod
    def _check_delta_warning(d):
        """Warn when ``d`` is below 0.02, the value the series truncation targets."""
        if d < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        Returns:
            ``(plus_z, minus_z)``, built with ``mu = 0`` and ``mu = 1``
            respectively, each with ``self.params["N"]`` Fock levels.
        """

        delta = self.params["delta"]
        dim = self.params["N"]

        # debug.callback lets the warning fire even if this runs under tracing.
        debug.callback(GKPQubit._check_delta_warning, delta)

        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x_axis, z_axis) quadratures defining the square lattice."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: ``E @ op @ E_inv``."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return ``(expm(op) + expm(-op)) / 2``."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z."""
        return self.common_gates["Z"]
class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    The lattice aspect is set by ``params["a"]`` (0.8 when absent): the
    x quadrature is scaled by ``a`` and the p quadrature by ``1 / a``.

    NOTE(review): ``_get_basis_z`` is inherited unchanged from ``GKPQubit``
    and does not read ``a``. The basis states therefore appear to be the
    square-lattice ones while the gates use the rectangular axes. Confirm
    this is intended.
    """

    def _params_validation(self):
        super()._params_validation()
        params = self.params
        if "a" not in params:
            params["a"] = 0.8

    def _get_axis(self):
        scale = self.params["a"]
        gates = self.common_gates
        return scale * gates["x"], -1 / scale * gates["p"]
class SquareGKPQubit(GKPQubit):
    """
    GKP qubit on the square lattice.

    Always records ``params["a"] = 1.0``. The inherited ``GKPQubit._get_axis``
    does not read ``a``, so this value is informational only.
    """

    def _params_validation(self):
        super()._params_validation()
        square_aspect = 1.0
        self.params["a"] = square_aspect
class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    The lattice axes are scaled by ``sqrt(2 / sqrt(3))``. The x axis is
    rotated to ``sin(pi/3) * x + cos(pi/3) * p``, and the z axis is ``-p``.

    NOTE(review): ``_get_basis_z`` is inherited unchanged from ``GKPQubit``
    and does not call ``_get_axis``. The basis states therefore appear to
    be the square-lattice ones. Confirm this is intended.
    """

    def _get_axis(self):
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        x_op = self.common_gates["x"]
        p_op = self.common_gates["p"]
        angle = jnp.pi / 3.0
        rotated = jnp.sin(angle) * x_op + jnp.cos(angle) * p_op
        return scale * rotated, scale * (-p_op)
235## Citations
237# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States
238# Baptiste Royer, Shraddha Singh, and S. M. Girvin
239# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020
241# Quantum error correction of a qubit encoded in grid states of an oscillator.
242# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al.
243# Nature 584, 368–372 (2020).