Coverage for jaxquantum / codes / gkp.py: 99%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 21:51 +0000

1""" 

GKP (Gottesman-Kitaev-Preskill) Code Qubit

3""" 

4 

5from typing import Tuple 

6import warnings 

7 

8from jaxquantum.codes.base import BosonicQubit 

9import jaxquantum as jqt 

10 

11from jax import jit, lax, vmap, debug 

12 

13import jax.numpy as jnp 

14 

15 

class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) Qubit Class.

    This base class uses the square-lattice axes (x, -p); subclasses override
    ``_get_axis`` to use other lattices.

    Parameters (``self.params``):
        delta: finite-energy parameter entering E = exp(-delta**2 a_dag a);
            defaults to 0.25 when not supplied.

    Derived entries written by ``_params_validation``:
        l: ``2 * sqrt(pi)``.
        epsilon: ``sinh(delta**2) * l``.
        squeezing: ``log(delta)``.
        squeezing_dB: ``20 * log10(exp(|squeezing|))``.
    """

    PARAMETERS = ["delta"]

    name = "gkp"

    def _params_validation(self):
        """Fill in the default ``delta`` and derive the lattice/squeezing params."""
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25

        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]
        self.params["squeezing"] = jnp.log(self.params["delta"])
        self.params["squeezing_dB"] = 20 * jnp.log10(
            jnp.exp(jnp.abs(self.params["squeezing"]))
        )

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        Adds, on top of the base-class gates:
            x, p: quadratures built from ``a`` and ``a_dag``.
            E, E_inv: exp(-delta**2 a_dag a) and exp(+delta**2 a_dag a).
            X_0, Y_0, Z_0: ideal (infinite-energy) logical Paulis.
            X, Y, Z: finite-energy Paulis, ``E @ P_0 @ E_inv``.
            Z_s_0, S_x_0, S_z_0, S_y_0: symmetrized operators
                ``(expm(A) + expm(-A)) / 2`` for the logical Z and stabilizers.
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis (lattice-dependent; see _get_axis overrides in subclasses)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: logical Paulis are half-lattice-spacing displacements
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th normalized Hermite function psi_n at ``q_points.T``.

        Uses the three-term recurrence
            psi_k(q) = sqrt(2/k) q psi_{k-1}(q) - sqrt((k-1)/k) psi_{k-2}(q)
        with psi_0(q) = pi**(-1/4) exp(-q**2/2) and psi_1(q) = sqrt(2) q psi_0(q).

        Args:
            q_points: array of position-quadrature values (transposed first).
            n: Fock index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            psi_n evaluated element-wise at ``q_points.T``.
        """
        q_points = q_points.T

        # The Gaussian envelope is folded into the initial values instead of
        # being multiplied in after the recurrence. The recurrence is linear,
        # so the result is mathematically unchanged, but the bare polynomial
        # part can overflow to inf at large |q| and n while exp(-q**2/2)
        # underflows to 0, and inf * 0 would yield NaN.
        psi_0 = jnp.pi ** (-0.25) * jnp.exp(-lax.pow(q_points, 2) / 2)
        psi_1 = jnp.sqrt(2) * lax.mul(q_points, psi_0)

        def recurrence_step(k, carry):
            psi_km2, psi_km1 = carry
            psi_k = (
                jnp.sqrt(2 / k) * lax.mul(q_points, psi_km1)
                - jnp.sqrt((k - 1) / k) * psi_km2
            )
            return psi_km1, psi_k

        # Runs k = 2..n; for n < 2 the body never executes and the select
        # below picks the matching initial value.
        _, psi_n = lax.fori_loop(
            2, jnp.maximum(n + 1, 2), recurrence_step, (psi_0, psi_1)
        )

        return lax.select(n == 0, psi_0, lax.select(n == 1, psi_1, psi_n))

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Compute a finite-energy square-lattice GKP Z-basis state in Fock space.

        The ideal state is a comb of position eigenstates at
        q_k = sqrt(pi) * (2k + mu). The coefficient on |n> is
        exp(-delta**2 n) * sum_k psi_n(q_k), with psi_n the normalized Hermite
        functions; the sum over all integers k is assembled from the k >= 0
        terms by symmetry.

        Args:
            delta: finite-energy parameter.
            dim: Hilbert-space truncation; fixes array shapes, so it must be
                static under ``jit``.
            mu: state index (0 or 1).
            series_trunc: number of non-negative comb teeth k that are summed.

        Returns:
            GKP basis state as a normalized ``jqt.Qarray`` of length ``dim``.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """

        # We choose the truncation of our series summation such that we
        # capture 6 sigmas of the envelope for a value of delta of 0.02:
        # delta * (series_trunc * 2 * sqrt(pi)) >= 6  (about 7.1 for the
        # default series_trunc=100).
        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            # Doubling the k >= 0 half gives the full-lattice sum; for mu == 0
            # the q = 0 tooth would be counted twice, so one copy is removed.
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0]
            )

        # Only even Fock indices are evaluated; odd entries stay zero.
        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        psi = jnp.zeros(2 * psi_even.size, dtype=psi_even.dtype)
        psi = psi.at[::2].set(psi_even)

        psi = jqt.Qarray.create(psi[:dim])

        return psi.unit()

    @staticmethod
    def _check_delta_warning(d):
        """Warn when ``d`` is below the range the series truncation targets."""
        if d < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        Returns:
            ``(plus_z, minus_z)`` computed with mu = 0 and mu = 1.
        """

        delta = self.params["delta"]
        dim = self.params["N"]

        # Host callback so the check also works when delta is a tracer.
        debug.callback(GKPQubit._check_delta_warning, delta)

        # dim determines array shapes inside the computation, so it is static.
        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x_axis, z_axis) operators of the square lattice: (x, -p)."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: ``E @ op @ E_inv``."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return ``(expm(op) + expm(-op)) / 2``."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z."""
        return self.common_gates["Z"]

206 

207 

class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    The lattice axes are ``a * x`` and ``-(1 / a) * p``; the aspect parameter
    ``a`` defaults to 0.8 when it is not supplied.

    NOTE(review): basis states come from the inherited
    ``GKPQubit._get_basis_z``, whose q-comb is not rescaled by ``a`` --
    confirm this is intended for ``a != 1``.
    """

    PARAMETERS = ["delta", "a"]

    def _params_validation(self):
        super()._params_validation()
        # Only fill in the default when the caller left ``a`` unspecified.
        if "a" in self.params:
            return
        self.params["a"] = 0.8

    def _get_axis(self):
        aspect = self.params["a"]
        position = self.common_gates["x"]
        momentum = self.common_gates["p"]
        momentum_scale = -1 / aspect
        return aspect * position, momentum_scale * momentum

222 

223 

class SquareGKPQubit(GKPQubit):
    """
    GKP qubit on a square lattice.

    Uses the lattice axes inherited from ``GKPQubit._get_axis`` and records
    ``a = 1.0`` in ``self.params``.
    """

    def _params_validation(self):
        super()._params_validation()
        # Set unconditionally: any user-supplied ``a`` is overwritten.
        params = self.params
        params["a"] = 1.0

229 

230 

class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    With ``s = sqrt(2 / sqrt(3))`` the lattice axes are
        x_axis = s * (sin(pi/3) x + cos(pi/3) p)
        z_axis = s * (-p)
    """

    def _get_axis(self):
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        angle = jnp.pi / 3.0
        position = self.common_gates["x"]
        momentum = self.common_gates["p"]
        tilted = jnp.sin(angle) * position + jnp.cos(angle) * momentum
        return scale * tilted, scale * (-momentum)

241 

242 

243## Citations 

244 

245# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States 

246# Baptiste Royer, Shraddha Singh, and S. M. Girvin 

247# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020 

248 

249# Quantum error correction of a qubit encoded in grid states of an oscillator. 

250# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al. 

251# Nature 584, 368–372 (2020).