Coverage for jaxquantum/codes/gkp.py: 99%
108 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""
GKP Code Qubit
3"""
5from typing import Tuple
6import warnings
8from jaxquantum.codes.base import BosonicQubit
9import jaxquantum as jqt
11from jax import jit, lax, vmap
13import jax.numpy as jnp
class GKPQubit(BosonicQubit):
    """
    GKP Qubit Class.

    Finite-energy Gottesman-Kitaev-Preskill qubit. This base class uses the
    phase-space generators ``x_axis = x`` and ``z_axis = -p`` (see
    ``_get_axis``); subclasses override ``_get_axis`` to change the lattice.

    Entries of ``self.params`` used by this class:
        delta: finite-energy envelope parameter (defaulted to 0.25).
        l: lattice constant, always set to ``2 * sqrt(pi)``.
        epsilon: always set to ``sinh(delta**2) * l``.
        N: number of Fock amplitudes in the states built by ``_get_basis_z``.
    """

    name = "gkp"

    def _params_validation(self):
        """Run base validation, then fill in ``delta``, ``l`` and ``epsilon``."""
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        Adds to ``self.common_gates``:
            x, p: quadrature operators built from ``a`` and ``a_dag``.
            E, E_inv: exp(-delta**2 a_dag a) and exp(+delta**2 a_dag a).
            X_0, Y_0, Z_0: ideal (infinite-energy) logical Paulis.
            X, Y, Z: the same Paulis conjugated as ``E @ P @ E_inv``.
            Z_s_0, S_x_0, S_z_0, S_y_0: symmetrized exponentials
                (expm(A) + expm(-A)) / 2 of the lattice generators.
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy envelope and its inverse
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # lattice generators (geometry is defined by _get_axis)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # logical gates: exponentials of the generators with amplitude l / 2
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0  # Y = i X Z
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th normalized Hermite function psi_n at ``q_points``.

        Uses the three-term recurrence
            psi_k(q) = sqrt(2/k) q psi_{k-1}(q) - sqrt((k-1)/k) psi_{k-2}(q)
        seeded with psi_{-1} = 0 and psi_0(q) = pi**(-1/4) exp(-q**2 / 2).

        The Gaussian envelope is folded into the seed instead of being
        multiplied in after the recurrence. The bare Hermite polynomials grow
        like q**n, so for large lattice points they overflow to inf. Multiplying
        inf by the underflowed envelope (0) would then give NaN.

        Args:
            q_points: evaluation points (transposed before use).
            n: Fock index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            psi_n evaluated at ``q_points.T``.
        """
        q_points = q_points.T

        psi_0 = jnp.pi ** (-0.25) * jnp.exp(-lax.pow(q_points, 2) / 2)

        def recurrence_step(k, carry):
            psi_prev, psi_curr = carry
            psi_next = (jnp.sqrt(2 / k) * (q_points * psi_curr)
                        - jnp.sqrt((k - 1) / k) * psi_prev)
            return psi_curr, psi_next

        # Seeding psi_{-1} = 0 makes the k = 1 step produce
        # psi_1 = sqrt(2) q psi_0, so one loop handles every n >= 0.
        # For n == 0 the body never runs and psi_0 is returned.
        _, psi_n = lax.fori_loop(1, n + 1, recurrence_step,
                                 (jnp.zeros_like(psi_0), psi_0))
        return psi_n

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Args:
            delta: finite-energy parameter; Fock amplitude n is weighted by
                exp(-delta**2 * n).
            dim: number of Fock amplitudes in the returned state (must be a
                static Python int under ``jit``).
            mu: state index (0 or 1)
            series_trunc: number of non-negative lattice points
                sqrt(pi) * (2k + mu), k = 0 .. series_trunc - 1, kept in the sum.

        Returns:
            GKP basis state with exactly ``dim`` Fock amplitudes, normalized.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """
        # We choose the truncation of our series summation such that we
        # capture 6 sigmas of the envelope for a value of delta of 0.02:
        # delta * (series_trunc * 2 * sqrt(pi)) >= 6

        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            # Sum psi_n over the lattice sqrt(pi) * (2k + mu), k in Z.
            # For even n, psi_n(-q) = psi_n(q), so each non-zero point is
            # counted twice. The point q = 0 (present only for mu == 0) is
            # counted once.
            quadvals = GKPQubit._q_quadrature(q_points, n)
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0])

        # Odd Fock amplitudes cancel by parity, so only even n are computed.
        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        # Size by ``dim`` rather than 2 * len(psi_even). For odd ``dim`` the
        # latter would produce dim + 1 amplitudes.
        psi = jnp.zeros(dim, dtype=psi_even.dtype).at[::2].set(psi_even)

        return jqt.Qarray.create(psi).unit()

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        Returns:
            (plus_z, minus_z), built with ``mu = 0`` and ``mu = 1``
            respectively. Each has ``self.params["N"]`` Fock amplitudes.
        """
        delta = self.params["delta"]
        dim = self.params["N"]

        if delta < 0.02:
            warnings.warn("State preparation with delta values lower than "
                          "0.02 might lead to loss of accuracy.")

        # ``dim`` fixes the output shape, so it must be static under jit.
        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x_axis, z_axis) lattice generators: ``x`` and ``-p``."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Return ``E @ op @ E_inv``, the finite-energy version of ``op``."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return ``(expm(op) + expm(-op)) / 2``."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X (``common_gates["X"]``)."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y (``common_gates["Y"]``)."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z (``common_gates["Z"]``)."""
        return self.common_gates["Z"]
class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit whose lattice generators are rescaled by an aspect ratio.

    The position generator is multiplied by ``a`` and the momentum generator
    by ``1 / a``, where ``a = self.params["a"]`` (default 0.8).
    """

    def _params_validation(self):
        """Validate base parameters, then default the aspect ratio ``a``."""
        super()._params_validation()
        if "a" in self.params:
            return
        self.params["a"] = 0.8

    def _get_axis(self):
        """Return the rescaled (x_axis, z_axis) generators."""
        aspect = self.params["a"]
        momentum_scale = -1 / aspect
        x_axis = aspect * self.common_gates["x"]
        z_axis = momentum_scale * self.common_gates["p"]
        return x_axis, z_axis
class SquareGKPQubit(GKPQubit):
    """GKP qubit on the square lattice; records an aspect ratio of exactly 1."""

    def _params_validation(self):
        """Validate base parameters and pin ``a`` to 1.0."""
        super()._params_validation()
        # Unconditionally overwritten: a square lattice has unit aspect ratio,
        # whatever value of "a" the caller may have supplied.
        self.params["a"] = 1.0
class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    Both generators are scaled by ``sqrt(2 / sqrt(3))``. The x generator mixes
    ``x`` and ``p`` with weights sin(pi/3) and cos(pi/3).
    """

    def _get_axis(self):
        """Return the (x_axis, z_axis) generators of the hexagonal lattice."""
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        angle = jnp.pi / 3.0
        x_op = self.common_gates["x"]
        p_op = self.common_gates["p"]
        x_axis = scale * (jnp.sin(angle) * x_op + jnp.cos(angle) * p_op)
        z_axis = scale * (-p_op)
        return x_axis, z_axis
232## Citations
234# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States
235# Baptiste Royer, Shraddha Singh, and S. M. Girvin
236# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020
238# Quantum error correction of a qubit encoded in grid states of an oscillator.
239# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al.
240# Nature 584, 368–372 (2020).