Coverage for jaxquantum / codes / gkp.py: 99%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-01-28 21:05 +0000

"""
GKP (Gottesman-Kitaev-Preskill) Code Qubit
"""

4 

5from typing import Tuple 

6import warnings 

7 

8from jaxquantum.codes.base import BosonicQubit 

9import jaxquantum as jqt 

10 

11from jax import jit, lax, vmap, debug 

12 

13import jax.numpy as jnp 

14 

15 

class GKPQubit(BosonicQubit):
    """
    GKP Qubit Class.

    Finite-energy Gottesman-Kitaev-Preskill qubit encoded in a bosonic mode.
    The finite-energy envelope is exp(-delta**2 * a_dag a), with ``delta``
    defaulting to 0.25 when absent from ``params``. The lattice geometry is
    set by ``_get_axis`` (x and -p here); subclasses override it for other
    lattices.
    """

    name = "gkp"

    def _params_validation(self):
        """
        Fill in GKP-specific entries of ``self.params``.

        Sets:
            delta: envelope parameter (default 0.25 if not supplied).
            l: lattice spacing, 2 * sqrt(pi).
            epsilon: sinh(delta**2) * l.
            squeezing: ln(delta).
            squeezing_dB: 20 * log10(exp(|ln(delta)|)).
        """
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]
        self.params["squeezing"] = jnp.log(self.params["delta"])
        self.params["squeezing_dB"] = 20 * jnp.log10(
            jnp.exp(jnp.abs(self.params["squeezing"]))
        )

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        Adds quadratures ("x", "p"), the finite-energy envelope ("E") and its
        inverse ("E_inv"), ideal Paulis ("X_0", "Y_0", "Z_0"), their
        finite-energy versions ("X", "Y", "Z"), and symmetrized
        stabilizers/gates ("Z_s_0", "S_x_0", "S_z_0", "S_y_0").
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: X displaces along z_axis, Z along x_axis, and Y = i X Z.
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th harmonic-oscillator eigenfunction psi_n(q) at each
        point of ``q_points``.

        Uses the normalized three-term recurrence
            psi_k = sqrt(2/k) q psi_{k-1} - sqrt((k-1)/k) psi_{k-2},
        seeded with psi_0 = pi^(-1/4) exp(-q^2/2) and psi_1 = sqrt(2) q psi_0.

        Args:
            q_points: array of positions (transposed before use).
            n: Fock index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            Array of psi_n(q) values with the shape of ``q_points.T``.
        """
        q_points = q_points.T

        # Seed the recurrence with the Gaussian envelope already applied.
        # The recurrence is linear and homogeneous, so this equals applying
        # the envelope at the end. It also keeps every intermediate bounded
        # (|psi_k(q)| <~ 1). Applying the envelope last lets the polynomial
        # part overflow to inf at large |q| while exp(-q^2/2) underflows to
        # 0, and inf * 0 turns the whole result into NaN.
        envelope = jnp.pi ** (-0.25) * jnp.exp(-lax.pow(q_points, 2) / 2)
        psi_0 = envelope
        psi_1 = jnp.sqrt(2) * q_points * envelope

        def recurrence_step(k, carry):
            psi_prev, psi_curr = carry
            psi_next = (jnp.sqrt(2 / k) * lax.mul(q_points, psi_curr)
                        - jnp.sqrt((k - 1) / k) * psi_prev)
            return psi_curr, psi_next

        # For n < 2 the loop runs zero times; those cases are picked below.
        _, psi_n = lax.fori_loop(2, jnp.maximum(n + 1, 2),
                                 recurrence_step, (psi_0, psi_1))

        return lax.select(n == 0, psi_0,
                          lax.select(n == 1, psi_1, psi_n))

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Compute the finite-energy GKP Z-basis state |mu> in the Fock basis.

        Args:
            delta: envelope parameter; the amplitude of Fock level n is
                damped by exp(-delta**2 * n).
            dim: number of Fock levels kept.
            mu: state index (0 or 1).
            series_trunc: number of non-negative comb points
                q_k = sqrt(pi) * (2k + mu), k = 0 .. series_trunc - 1.

        Returns:
            GKP basis state as a normalized (``.unit()``) ``jqt.Qarray``.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """

        # The series truncation is chosen so that, for delta = 0.02, the comb
        # reaches several envelope widths (1/delta):
        # delta * series_trunc * 2 * sqrt(pi) ~= 7 (>= 6 sigmas) at the
        # default series_trunc = 100.
        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def compute_pop(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            # psi_n(-q) = psi_n(q) for even n, so the comb over all integers
            # is twice the sum over k >= 0. The q = 0 point (present only
            # when mu == 0) would then be counted twice, so subtract it once.
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0])

        # Odd Fock amplitudes vanish by parity; compute only the even ones.
        psi_even = vmap(compute_pop)(jnp.arange(0, dim, 2))

        psi = jnp.zeros(2 * psi_even.size, dtype=psi_even.dtype)

        psi = psi.at[::2].set(psi_even)

        psi = jqt.Qarray.create(jnp.array(psi)[:dim])

        return psi.unit()

    @staticmethod
    def _check_delta_warning(d):
        """Warn when ``d`` is below 0.02, where state preparation may lose accuracy."""
        if d < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct basis states |+-z>.

        Returns:
            (plus_z, minus_z), the mu = 0 and mu = 1 GKP states for the
            current ``delta`` and truncation ``N``.
        """

        delta = self.params["delta"]
        dim = self.params["N"]

        # debug.callback so the warning also fires when delta is traced.
        debug.callback(GKPQubit._check_delta_warning, delta)

        # dim sets array shapes, so it must be static under jit.
        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x_axis, z_axis) lattice generators: x and -p."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the finite-energy envelope: E @ op @ E_inv."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return (expm(op) + expm(-op)) / 2."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X gate."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y gate."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z gate."""
        return self.common_gates["Z"]

203 

204 

class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    ``params["a"]`` (default 0.8) is the aspect ratio: ``_get_axis`` returns
    ``a * x`` and ``-(1 / a) * p``.
    """

    def _params_validation(self):
        super()._params_validation()
        # Only supply the default; a caller-provided aspect ratio is kept.
        if "a" not in self.params:
            self.params["a"] = 0.8

    def _get_axis(self):
        """Return the (x_axis, z_axis) generators scaled by a and -1/a."""
        aspect = self.params["a"]
        inverse_aspect = -1 / aspect
        return (
            aspect * self.common_gates["x"],
            inverse_aspect * self.common_gates["p"],
        )

216 

217 

class SquareGKPQubit(GKPQubit):
    """
    GKP qubit on a square lattice.

    Keeps the inherited axes and records an aspect ratio of 1.0 in
    ``params["a"]``. Any caller-supplied value is overwritten.
    """

    def _params_validation(self):
        super()._params_validation()
        params = self.params
        # Square lattice: aspect ratio is fixed, not a free parameter.
        params["a"] = 1.0

222 

223 

class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    Both generators are scaled by sqrt(2 / sqrt(3)). The x-axis generator is
    sin(pi/3) * x + cos(pi/3) * p, and the z-axis generator is -p.
    """

    def _get_axis(self):
        """Return the (x_axis, z_axis) generators of the hexagonal lattice."""
        angle = jnp.pi / 3.0
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        position = self.common_gates["x"]
        momentum = self.common_gates["p"]
        tilted = jnp.sin(angle) * position + jnp.cos(angle) * momentum
        return scale * tilted, scale * (-momentum)

233 

234 

235## Citations 

236 

237# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States 

238# Baptiste Royer, Shraddha Singh, and S. M. Girvin 

239# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020 

240 

241# Quantum error correction of a qubit encoded in grid states of an oscillator. 

242# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al. 

243# Nature 584, 368–372 (2020).