Coverage for jaxquantum/codes/gkp.py: 80%

79 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2GKP Code Qubit 

3""" 

4 

5from typing import Tuple 

6 

7from jaxquantum.codes.base import BosonicQubit 

8import jaxquantum as jqt 

9 

10import jax.numpy as jnp 

11 

12 

class GKPQubit(BosonicQubit):
    """
    Gottesman-Kitaev-Preskill (GKP) qubit.

    Extends BosonicQubit with phase-space quadratures, a finite-energy
    envelope, ideal and finite-energy Pauli gates, and symmetrized
    stabilizers. Subclasses change the lattice by overriding ``_get_axis``.
    """

    name = "gkp"

    def _params_validation(self):
        """
        Fill in the parameters used by this code.

        "delta" falls back to 0.25 when the caller did not supply it.
        "l" and "epsilon" are always (re)computed here:
        l = 2 * sqrt(pi) and epsilon = sinh(delta ** 2) * l.
        """
        super()._params_validation()

        params = self.params
        if "delta" not in params:
            params["delta"] = 0.25
        params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        params["epsilon"] = jnp.sinh(params["delta"] ** 2) * params["l"]

    def _gen_common_gates(self) -> None:
        """
        Extend the parent's common gates with the GKP-specific operators.

        Adds, in this order: the quadratures "x" and "p", the envelope "E"
        and its inverse "E_inv", the ideal gates "X_0"/"Z_0"/"Y_0", their
        finite-energy versions "X"/"Z"/"Y", and the symmetrized operators
        "Z_s_0", "S_x_0", "S_z_0", "S_y_0".
        """
        super()._gen_common_gates()

        gates = self.common_gates
        lower = gates["a"]
        lift = gates["a_dag"]

        # phase space
        root_two = jnp.sqrt(2.0)
        gates["x"] = (lift + lower) / root_two
        gates["p"] = 1.0j * (lift - lower) / root_two

        # finite energy: exp(-/+ delta^2 * a_dag a); the scalar multiplies
        # a_dag before the matrix product, as `*` and `@` associate leftwards
        delta_sq = self.params["delta"] ** 2
        gates["E"] = jqt.expm(-delta_sq * lift @ lower)
        gates["E_inv"] = jqt.expm(delta_sq * lift @ lower)

        # axis (E and E_inv must already be stored; the helpers read them)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # prefactors for half-period and full-period exponents
        half_step = 1.0j * self.params["l"] / 2.0
        full_step = 1.0j * self.params["l"]

        # gates
        ideal_x = jqt.expm(half_step * z_axis)
        ideal_z = jqt.expm(half_step * x_axis)
        ideal_y = 1.0j * ideal_x @ ideal_z
        for key, ideal_op in (("X_0", ideal_x), ("Z_0", ideal_z), ("Y_0", ideal_y)):
            gates[key] = ideal_op
        for key, ideal_op in (("X", ideal_x), ("Z", ideal_z), ("Y", ideal_y)):
            gates[key] = self._make_op_finite_energy(ideal_op)

        # symmetric stabilizers and gates
        for key, exponent in (
            ("Z_s_0", half_step * x_axis),
            ("S_x_0", full_step * z_axis),
            ("S_z_0", full_step * x_axis),
            ("S_y_0", full_step * y_axis),
        ):
            gates[key] = self._symmetrized_expm(exponent)

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Build the finite-energy |+z> and |-z> states.

        step 1: take the eigenvector in the first column returned by
                jnp.linalg.eigh for H_0 = -S_x_0 - S_y_0 - S_z_0 - Z_s_0,
                i.e. the ideal GKP |+z> state.
        step 2: apply the envelope E to make that eigenvector finite energy.
            We want the groundstate of H = E H_0 E⁻¹.
            So, we begin by finding the groundstate of H_0 -> |λ₀⟩.
            Then, we know that E|λ₀⟩ = |λ⟩ is the groundstate of H.
            pf. H|λ⟩ = (E H_0 E⁻¹)(E|λ₀⟩) = E H_0 |λ₀⟩ = λ₀ (E|λ₀⟩) = λ₀|λ⟩

        TODO (if necessary):
            Alternatively, we could construct a hamiltonian using
            finite energy stabilizers S_x, S_y, S_z, Z_s. However,
            this would make H = - S_x - S_y - S_z - Z_s non-hermitian.
            Currently, JAX does not support derivatives of jnp.linalg.eig,
            while it does support derivatives of jnp.linalg.eigh.
            Discussion: https://github.com/google/jax/issues/2748

        Returns:
            (plus_z, minus_z), both passed through jqt.unit; minus_z is the
            finite-energy X gate applied to plus_z.
        """
        gates = self.common_gates

        # step 1: use ideal GKP stabilizers to find ideal GKP |+z> state
        H_0 = -gates["S_x_0"]
        for key in ("S_y_0", "S_z_0", "Z_s_0"):  # Z_s_0 -> bosonic |+z> state
            H_0 = H_0 - gates[key]

        eigvecs = jnp.linalg.eigh(H_0.data)[1]
        ideal_ground = jqt.Qarray.create(eigvecs[:, 0])

        # step 2: make ideal eigenvector finite energy
        plus_z = jqt.unit(gates["E"] @ ideal_ground)
        minus_z = jqt.unit(gates["X"] @ plus_z)
        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """Return (x_axis, z_axis) = (x, -p) taken from the common gates."""
        return self.common_gates["x"], -self.common_gates["p"]

    def _make_op_finite_energy(self, op):
        """Return E @ op @ E_inv using the stored envelope operators."""
        envelope = self.common_gates["E"]
        envelope_inv = self.common_gates["E_inv"]
        return envelope @ op @ envelope_inv

    def _symmetrized_expm(self, op):
        """Return (expm(op) + expm(-op)) / 2."""
        forward = jqt.expm(op)
        backward = jqt.expm(-1.0 * op)
        return (forward + backward) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy X gate stored under "X"."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy Y gate stored under "Y"."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy Z gate stored under "Z"."""
        return self.common_gates["Z"]

147 

148 

class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit whose axes are rescaled by the parameter "a".

    The x axis is multiplied by a and the z axis (-p) by 1 / a.
    """

    def _params_validation(self):
        """Run the parent validation, then default "a" to 0.8 if missing."""
        super()._params_validation()
        params = self.params
        if "a" not in params:
            params["a"] = 0.8

    def _get_axis(self):
        """Return (a * x, (-1 / a) * p) built from the common gates."""
        ratio = self.params["a"]
        gates = self.common_gates
        stretched_x = ratio * gates["x"]
        squeezed_z = -1 / ratio * gates["p"]
        return stretched_x, squeezed_z

160 

161 

class SquareGKPQubit(GKPQubit):
    """
    GKP qubit using the axes inherited unchanged from GKPQubit.

    The parameter "a" is always overwritten with 1.0; the inherited
    ``_get_axis`` does not read it.
    """

    def _params_validation(self):
        """Run the parent validation, then force "a" to 1.0."""
        super()._params_validation()
        params = self.params
        params["a"] = 1.0

166 

167 

class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit whose x axis is tilted by pi / 3.

    Both axes are scaled by sqrt(2 / sqrt(3)).
    """

    def _get_axis(self):
        """
        Return (x_axis, z_axis) for this lattice.

        x_axis = scale * (sin(pi / 3) * x + cos(pi / 3) * p) and
        z_axis = scale * (-p), with scale = sqrt(2 / sqrt(3)).
        """
        gates = self.common_gates
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        tilted = (
            jnp.sin(jnp.pi / 3.0) * gates["x"]
            + jnp.cos(jnp.pi / 3.0) * gates["p"]
        )
        return scale * tilted, scale * (-gates["p"])

177 

178 

179## Citations 

180 

181# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States 

182# Baptiste Royer, Shraddha Singh, and S. M. Girvin 

183# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020 

184 

185# Quantum error correction of a qubit encoded in grid states of an oscillator. 

186# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al. 

187# Nature 584, 368–372 (2020).