Coverage for jaxquantum/codes/gkp.py: 99%

108 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1""" 

GKP (Gottesman-Kitaev-Preskill) Code Qubit

3""" 

4 

5from typing import Tuple 

6import warnings 

7 

8from jaxquantum.codes.base import BosonicQubit 

9import jaxquantum as jqt 

10 

11from jax import jit, lax, vmap 

12 

13import jax.numpy as jnp 

14 

15 

class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) qubit encoded in a bosonic mode.

    The logical operators and stabilizers are exponentials of the lattice axes
    returned by ``_get_axis``; subclasses override that method to obtain
    rectangular or hexagonal lattices.
    """

    name = "gkp"

    def _params_validation(self):
        """
        Complete ``self.params`` with the GKP-specific entries.

        Sets ``delta`` (the finite-energy envelope parameter) to 0.25 when it is
        absent, the lattice spacing ``l = 2 * sqrt(pi)``, and
        ``epsilon = sinh(delta**2) * l``.
        """
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]

    def _gen_common_gates(self) -> None:
        """
        Extend the base-class gates with GKP-specific operators.

        Expects ``common_gates["a"]`` and ``common_gates["a_dag"]`` to be present
        (presumably populated by the base-class call below) and adds:

        - "x", "p": quadrature operators (a_dag + a)/sqrt(2), i(a_dag - a)/sqrt(2).
        - "E", "E_inv": finite-energy envelope exp(-+delta^2 a_dag a).
        - "X_0", "Z_0", "Y_0": bare logical unitaries exp(i l/2 * axis).
        - "X", "Z", "Y": the same operators conjugated by the envelope.
        - "Z_s_0", "S_x_0", "S_z_0", "S_y_0": symmetrized (cosine-like)
          exponentials (expm(A) + expm(-A)) / 2 of the logical/stabilizer
          generators.
        """
        super()._gen_common_gates()

        # phase space
        self.common_gates["x"] = (
            self.common_gates["a_dag"] + self.common_gates["a"]
        ) / jnp.sqrt(2.0)
        self.common_gates["p"] = (
            1.0j * (self.common_gates["a_dag"] - self.common_gates["a"]) / jnp.sqrt(2.0)
        )

        # finite energy
        self.common_gates["E"] = jqt.expm(
            -(self.params["delta"] ** 2)
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )
        self.common_gates["E_inv"] = jqt.expm(
            self.params["delta"] ** 2
            * self.common_gates["a_dag"]
            @ self.common_gates["a"]
        )

        # axis
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # gates: X is generated by the z axis and Z by the x axis.
        X_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * self.params["l"] / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates (stabilizers use the full spacing l)
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * self.params["l"] * y_axis
        )

    @staticmethod
    def _q_quadrature(q_points, n):
        """
        Evaluate the n-th harmonic-oscillator eigenfunction psi_n(q) at each
        entry of ``q_points``.

        Uses the three-term recurrence

            psi_0(q) = pi^(-1/4) exp(-q^2 / 2)
            psi_1(q) = sqrt(2) q psi_0(q)
            psi_k(q) = sqrt(2 / k) q psi_{k-1}(q) - sqrt((k - 1) / k) psi_{k-2}(q)

        Args:
            q_points: positions at which to evaluate psi_n (transposed before use).
            n: Fock index; may be a traced integer (e.g. under ``vmap``).

        Returns:
            psi_n evaluated on ``q_points.T``.
        """
        q_points = q_points.T

        # The Gaussian envelope is folded into the starting values rather than
        # multiplied in at the end. The recurrence is linear, so the result is
        # the same, but the iterates stay bounded instead of growing like q**k.
        # Multiplying a polynomial that overflowed to inf (large q and k) by a
        # Gaussian that underflowed to 0.0 would otherwise produce NaN.
        psi_0 = jnp.pi ** (-0.25) * jnp.exp(-(q_points ** 2) / 2)
        psi_1 = jnp.sqrt(2.0) * q_points * psi_0

        def recurrence_step(k, carry):
            psi_km2, psi_km1 = carry
            psi_k = (jnp.sqrt(2.0 / k) * q_points * psi_km1
                     - jnp.sqrt((k - 1.0) / k) * psi_km2)
            return psi_km1, psi_k

        # Runs k = 2..n; the loop is empty when n < 2.
        _, psi_n = lax.fori_loop(2, jnp.maximum(n + 1, 2),
                                 recurrence_step, (psi_0, psi_1))

        return lax.select(n == 0, psi_0,
                          lax.select(n == 1, psi_1, psi_n))

    @staticmethod
    def _compute_gkp_basis_z(delta, dim, mu, series_trunc=100):
        """
        Build a finite-energy GKP logical Z basis state in the Fock basis.

        The amplitude of Fock state n is
        exp(-delta^2 n) * sum_{s in Z} psi_n(sqrt(pi) * (2 s + mu)),
        with the sum truncated to |s| < series_trunc.

        Args:
            delta: finite-energy envelope parameter.
            dim: number of Fock states. Must be static under ``jit`` because it
                fixes array shapes.
            mu: logical state index (0 or 1).
            series_trunc: number of non-negative comb points s = 0..series_trunc-1.

        Returns:
            Normalized Qarray holding ``dim`` Fock amplitudes.

        Adapted from code by Lev-Arcady Sellem <lev-arcady.sellem@inria.fr>
        """
        # The default truncation captures about 6 sigmas of the envelope for
        # delta = 0.02: delta * (series_trunc * 2 * sqrt(pi)) ~= 7.1 >= 6.
        q_points = jnp.sqrt(jnp.pi) * (2 * jnp.arange(series_trunc) + mu)

        def fock_amplitude(n):
            quadvals = GKPQubit._q_quadrature(q_points, n)
            # For even n, psi_n is even and the comb is symmetric about q = 0,
            # so the sum over s in Z is twice the sum over s >= 0. When mu == 0
            # the q = 0 point would be counted twice, so it is subtracted once.
            return jnp.exp(-(delta ** 2) * n) * (
                2 * jnp.sum(quadvals) - (1 - mu) * quadvals[0])

        # Odd Fock amplitudes vanish by parity, so only even n are evaluated.
        psi_even = vmap(fock_amplitude)(jnp.arange(0, dim, 2))

        # Size the vector by ``dim`` itself: 2 * psi_even.size would be dim + 1
        # for odd dim.
        psi = jnp.zeros(dim, dtype=psi_even.dtype).at[::2].set(psi_even)

        return jqt.Qarray.create(psi).unit()

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z basis states.

        Returns:
            (plus_z, minus_z): the states built with mu = 0 and mu = 1.

        Emits a warning when delta < 0.02, below the accuracy regime targeted by
        the default series truncation.
        """
        delta = self.params["delta"]
        dim = self.params["N"]

        if delta < 0.02:
            warnings.warn("State preparation with delta values lower than 0.02 might lead to loss of accuracy.")

        jitted_compute_gkp_basis_z = jit(self._compute_gkp_basis_z,
                                         static_argnames=("dim",))

        plus_z = jitted_compute_gkp_basis_z(delta, dim, 0)
        minus_z = jitted_compute_gkp_basis_z(delta, dim, 1)

        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return the (x_axis, z_axis) generators of the square lattice: x and -p."""
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: E @ op @ E_inv."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return (expm(op) + expm(-op)) / 2."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X operator."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y operator."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z operator."""
        return self.common_gates["Z"]

198 def z_U(self) -> jqt.Qarray: 

199 return self.common_gates["Z"] 

200 

201 

class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    The x axis is scaled by ``a`` and the momentum axis by ``1 / a``.
    ``a`` defaults to 0.8 when the caller does not supply it.
    """

    def _params_validation(self):
        super()._params_validation()
        # Respect a caller-provided aspect ratio; otherwise use the default.
        if "a" in self.params:
            return
        self.params["a"] = 0.8

    def _get_axis(self):
        scale = self.params["a"]
        gates = self.common_gates
        return scale * gates["x"], -1 / scale * gates["p"]

213 

214 

class SquareGKPQubit(GKPQubit):
    """
    GKP qubit on a square lattice.

    The aspect ratio ``a`` is always forced to 1.0. Any caller-supplied value is
    overwritten.
    """

    def _params_validation(self):
        super()._params_validation()
        unit_aspect_ratio = 1.0
        self.params["a"] = unit_aspect_ratio

219 

220 

class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    The x axis is tilted by pi/3 towards p. Both axes are scaled by
    sqrt(2 / sqrt(3)), so that scale**2 * sin(pi/3) = 1.
    """

    def _get_axis(self):
        gates = self.common_gates
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        angle = jnp.pi / 3.0
        tilted = jnp.sin(angle) * gates["x"] + jnp.cos(angle) * gates["p"]
        x_axis = scale * tilted
        z_axis = scale * (-gates["p"])
        return x_axis, z_axis

230 

231 

232## Citations 

233 

234# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States 

235# Baptiste Royer, Shraddha Singh, and S. M. Girvin 

236# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020 

237 

238# Quantum error correction of a qubit encoded in grid states of an oscillator. 

239# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al. 

240# Nature 584, 368–372 (2020).