Coverage for jaxquantum / codes / gkp.py: 99%
111 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 20:38 +0000
1"""
Gottesman-Kitaev-Preskill (GKP) Qubit
3"""
5from typing import Tuple
6import warnings
8from jaxquantum.codes.base import BosonicQubit
9import jaxquantum as jqt
11from jax import jit, lax, vmap, debug
13import jax.numpy as jnp
class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) qubit encoded in a single bosonic mode.

    Parameters read from ``self.params``:
        delta: width of the finite-energy envelope
            ``exp(-delta**2 * a_dag @ a)``; defaults to 0.25 when absent.
        N: Fock-space dimension used when building the basis states.

    Parameters derived in ``_params_validation``:
        l: stabilizer displacement length, ``2 * sqrt(pi)``.
        epsilon: ``sinh(delta**2) * l``.
    """

    name = "gkp"

    def _params_validation(self):
        """Fill in the default ``delta`` and the derived ``l`` and ``epsilon``."""
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]

    def _gen_common_gates(self) -> None:
        """
        Extend the base-class gates with GKP-specific operators.

        Adds to ``self.common_gates``:
            x, p: quadratures built from ``a`` and ``a_dag``.
            E, E_inv: finite-energy envelope ``exp(-delta**2 a_dag a)`` and
                its inverse.
            X_0, Y_0, Z_0: ideal (infinite-energy) logical Paulis.
            X, Y, Z: the same Paulis conjugated by the envelope.
            Z_s_0, S_x_0, S_z_0, S_y_0: symmetrized versions
                ``(exp(op) + exp(-op)) / 2`` of the logical Z and stabilizers.
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis (subclasses override _get_axis to change the lattice)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: logical Paulis are half-lattice displacements
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates (stabilizers are full-lattice
        # displacements)
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th normalized Hermite function at ``q_points``.

            psi_n(q) = pi**(-1/4) * exp(-q**2 / 2) * H_n(q) / sqrt(2**n n!)

        The Gaussian factor is folded into the two seed values of the
        three-term recurrence, so every intermediate is itself a normalized
        Hermite function (bounded in magnitude by pi**(-1/4)). Running the
        recurrence on the bare polynomial and applying the Gaussian at the end
        lets the polynomial overflow to inf for large |q| and n, and
        ``inf * 0`` then yields NaN.

        Args:
            q_points: position-quadrature sample points.
            n: Hermite index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            Array of psi_n evaluated at ``q_points``.
        """
        q_points = q_points.T

        gaussian = jnp.pi ** (-0.25) * jnp.exp(-jnp.square(q_points) / 2)
        psi_0 = gaussian
        psi_1 = jnp.sqrt(2) * q_points * gaussian

        def recurrence_step(k, carry):
            # psi_k = sqrt(2/k) q psi_{k-1} - sqrt((k-1)/k) psi_{k-2}
            psi_km2, psi_km1 = carry
            psi_k = (
                jnp.sqrt(2 / k) * lax.mul(q_points, psi_km1)
                - jnp.sqrt((k - 1) / k) * psi_km2
            )
            return psi_km1, psi_k

        # The upper bound may be traced; clamping at 2 makes the loop a no-op
        # for n in {0, 1}, which are selected directly below.
        _, psi_last = lax.fori_loop(
            2, jnp.maximum(n + 1, 2), recurrence_step, (psi_0, psi_1)
        )

        return lax.select(n == 0, psi_0, lax.select(n == 1, psi_1, psi_last))

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Compute a finite-energy square-GKP Z-basis state in the Fock basis.

        Args:
            delta: finite-energy envelope width; Fock amplitude n is weighted
                by ``exp(-delta**2 * n)``.
            dim: Fock-space dimension of the returned state.
            mu: state index (0 or 1).
            series_trunc: number of non-negative lattice peaks
                ``q_k = sqrt(pi) * (2k + mu)`` summed over.

        Returns:
            GKP basis state (normalized).

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """

        # We choose the truncation of our series summation such that we
        # capture 6 sigmas of the envelope for a value of delta of 0.02:
        # delta * (series_trunc * 2 * sqrt(pi)) = 6
        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            # The peak set is symmetric about q = 0, so sum the k >= 0 half
            # twice; for mu == 0 the q = 0 peak is then double counted and is
            # subtracted once.
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0])

        # By that same symmetry only even Fock states are populated.
        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        psi = jnp.zeros(2 * psi_even.size, dtype=psi_even.dtype)
        psi = psi.at[::2].set(psi_even)

        psi = jqt.Qarray.create(psi[:dim])

        return psi.unit()

    @staticmethod
    def _check_delta_warning(d):
        """Warn when ``d`` (delta) is below the range the series is tuned for."""
        if d < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        Returns:
            ``(plus_z, minus_z)``, built with ``mu = 0`` and ``mu = 1``
            respectively, with dimension ``self.params["N"]``.
        """

        delta = self.params["delta"]
        dim = self.params["N"]

        # debug.callback so the warning also fires when delta is traced.
        debug.callback(GKPQubit._check_delta_warning, delta)

        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x, z) lattice generators; square lattice: (x, -p)."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: ``E @ op @ E_inv``."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return ``(expm(op) + expm(-op)) / 2``."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z."""
        return self.common_gates["Z"]
class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice with aspect ratio ``a``.

    The x generator is stretched by ``a`` and the z generator (``-p``) by
    ``1 / a``. ``a`` defaults to 0.8 when absent from ``self.params``.

    NOTE(review): the inherited ``_get_basis_z`` builds square-lattice states
    and does not read ``params["a"]``; confirm whether the basis states should
    be squeezed to match this lattice.
    """

    def _params_validation(self):
        super()._params_validation()
        self.params.setdefault("a", 0.8)

    def _get_axis(self):
        aspect = self.params["a"]
        return (
            aspect * self.common_gates["x"],
            -1 / aspect * self.common_gates["p"],
        )
class SquareGKPQubit(GKPQubit):
    """Square-lattice GKP qubit: always records a unit aspect ratio ``a``."""

    def _params_validation(self):
        super()._params_validation()
        params = self.params
        # Unit aspect ratio defines the square lattice; any user value is
        # overwritten.
        params["a"] = 1.0
class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    The x generator is tilted by pi/3 into the p direction, and both
    generators are scaled by ``sqrt(2 / sqrt(3))``.
    """

    def _get_axis(self):
        x_op = self.common_gates["x"]
        p_op = self.common_gates["p"]
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        angle = jnp.pi / 3.0
        tilted = jnp.sin(angle) * x_op + jnp.cos(angle) * p_op
        return scale * tilted, scale * (-p_op)
233## Citations
235# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States
236# Baptiste Royer, Shraddha Singh, and S. M. Girvin
237# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020
239# Quantum error correction of a qubit encoded in grid states of an oscillator.
240# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al.
241# Nature 584, 368–372 (2020).