Coverage for jaxquantum/codes/gkp.py: 80%
79 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
"""
GKP (Gottesman-Kitaev-Preskill) Code Qubit
"""
5from typing import Tuple
7from jaxquantum.codes.base import BosonicQubit
8import jaxquantum as jqt
10import jax.numpy as jnp
class GKPQubit(BosonicQubit):
    """
    GKP (Gottesman-Kitaev-Preskill) Qubit Class.

    Finite-energy GKP qubit on a square lattice. Subclasses change the lattice
    geometry by overriding ``_get_axis``.

    Parameters read from / written to ``self.params``:
        delta: finite-energy envelope width Δ (defaults to 0.25).
        l: lattice constant, always overwritten with 2√π.
        epsilon: derived value sinh(Δ²) · l.
    """

    name = "gkp"

    def _params_validation(self):
        """Fill in the default ``delta`` and the derived ``l`` / ``epsilon``."""
        super()._params_validation()

        if "delta" not in self.params:
            self.params["delta"] = 0.25
        # The lattice constant is fixed; the lattice shape comes from _get_axis.
        self.params["l"] = 2.0 * jnp.sqrt(jnp.pi)
        s_delta = jnp.sinh(self.params["delta"] ** 2)
        self.params["epsilon"] = s_delta * self.params["l"]

    def _gen_common_gates(self) -> None:
        """
        Overriding this method to add additional common gates.

        On top of the gates produced by the base class (which must include
        ``a`` and ``a_dag``, both read below), this adds:
            x, p: phase-space quadratures built from a and a_dag.
            E, E_inv: finite-energy envelope exp(-Δ² a†a) and its inverse.
            X_0, Y_0, Z_0: ideal (infinite-energy) logical Paulis.
            X, Y, Z: finite-energy Paulis E · P_0 · E⁻¹.
            Z_s_0, S_x_0, S_y_0, S_z_0: ideal stabilizers symmetrized as
                (e^{A} + e^{-A}) / 2, used to build H_0 in _get_basis_z.
        """
        super()._gen_common_gates()

        a = self.common_gates["a"]
        a_dag = self.common_gates["a_dag"]
        spacing = self.params["l"]
        delta_sq = self.params["delta"] ** 2

        # phase space
        self.common_gates["x"] = (a_dag + a) / jnp.sqrt(2.0)
        self.common_gates["p"] = 1.0j * (a_dag - a) / jnp.sqrt(2.0)

        # finite energy: E = exp(-Δ² a†a), E_inv = exp(+Δ² a†a)
        n_op = a_dag @ a
        self.common_gates["E"] = jqt.expm(-delta_sq * n_op)
        self.common_gates["E_inv"] = jqt.expm(delta_sq * n_op)

        # axis (lattice geometry is defined by _get_axis)
        x_axis, z_axis = self._get_axis()
        y_axis = x_axis + z_axis

        # ideal gates: exponentials of the axes with half the lattice constant
        X_0 = jqt.expm(1.0j * spacing / 2.0 * z_axis)
        Z_0 = jqt.expm(1.0j * spacing / 2.0 * x_axis)
        Y_0 = 1.0j * X_0 @ Z_0  # Pauli identity Y = iXZ
        self.common_gates["X_0"] = X_0
        self.common_gates["Z_0"] = Z_0
        self.common_gates["Y_0"] = Y_0

        # finite-energy gates
        self.common_gates["X"] = self._make_op_finite_energy(X_0)
        self.common_gates["Z"] = self._make_op_finite_energy(Z_0)
        self.common_gates["Y"] = self._make_op_finite_energy(Y_0)

        # symmetric stabilizers and gates (Hermitian, so usable with eigh)
        self.common_gates["Z_s_0"] = self._symmetrized_expm(
            1.0j * spacing / 2.0 * x_axis
        )
        self.common_gates["S_x_0"] = self._symmetrized_expm(
            1.0j * spacing * z_axis
        )
        self.common_gates["S_z_0"] = self._symmetrized_expm(
            1.0j * spacing * x_axis
        )
        self.common_gates["S_y_0"] = self._symmetrized_expm(
            1.0j * spacing * y_axis
        )

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z basis states |+z> and |-z>.

        step 1: use ideal GKP stabilizers to find the ideal GKP |+z> state.
        step 2: make the ideal eigenvector finite energy.
            We want the ground state of H = E H_0 E⁻¹.
            So, we can begin by finding the ground state of H_0 -> |λ₀⟩.
            Then, E|λ₀⟩ = |λ⟩ is the ground state of H, because H and H_0
            are similar and therefore share eigenvalues:
            pf. H|λ⟩ = (E H_0 E⁻¹)(E|λ₀⟩) = E H_0 |λ₀⟩ = λ₀ (E|λ₀⟩) = λ₀|λ⟩
        |-z> is then obtained by applying the finite-energy logical X to |+z>.

        Returns:
            (plus_z, minus_z), both normalized with jqt.unit.

        TODO (if necessary):
            Alternatively, we could construct a Hamiltonian using
            finite-energy stabilizers S_x, S_y, S_z, Z_s. However,
            this would make H = - S_x - S_y - S_z - Z_s non-Hermitian.
            Currently, JAX does not support derivatives of jnp.linalg.eig,
            while it does support derivatives of jnp.linalg.eigh.
            Discussion: https://github.com/google/jax/issues/2748
        """
        # step 1: use ideal GKP stabilizers to find ideal GKP |+z> state
        H_0 = (
            -self.common_gates["S_x_0"]
            - self.common_gates["S_y_0"]
            - self.common_gates["S_z_0"]
            - self.common_gates["Z_s_0"]  # selects the bosonic |+z> state
        )

        # eigh returns eigenvalues in ascending order, so column 0 is the
        # ground state.
        _, vecs = jnp.linalg.eigh(H_0.data)
        gstate_ideal = jqt.Qarray.create(vecs[:, 0])

        # step 2: make ideal eigenvector finite energy
        gstate = self.common_gates["E"] @ gstate_ideal

        plus_z = jqt.unit(gstate)
        minus_z = jqt.unit(self.common_gates["X"] @ plus_z)
        return plus_z, minus_z

    # utils
    # ======================================================
    def _get_axis(self):
        """
        Return the (x_axis, z_axis) quadratures that define the lattice.

        Square lattice: x_axis = x, z_axis = -p. Subclasses override this to
        change the lattice geometry.
        """
        x_axis = self.common_gates["x"]
        z_axis = -self.common_gates["p"]
        return x_axis, z_axis

    def _make_op_finite_energy(self, op):
        """Conjugate ``op`` by the envelope: E · op · E⁻¹."""
        return self.common_gates["E"] @ op @ self.common_gates["E_inv"]

    def _symmetrized_expm(self, op):
        """Return (expm(op) + expm(-op)) / 2, i.e. cos(B) when op = iB."""
        return (jqt.expm(op) + jqt.expm(-1.0 * op)) / 2.0

    # gates
    # ======================================================
    @property
    def x_U(self) -> jqt.Qarray:
        """Finite-energy logical X gate."""
        return self.common_gates["X"]

    @property
    def y_U(self) -> jqt.Qarray:
        """Finite-energy logical Y gate."""
        return self.common_gates["Y"]

    @property
    def z_U(self) -> jqt.Qarray:
        """Finite-energy logical Z gate."""
        return self.common_gates["Z"]
class RectangularGKPQubit(GKPQubit):
    """
    GKP qubit on a rectangular lattice.

    The aspect ratio is ``self.params["a"]`` (default 0.8). The x axis is
    scaled by ``a`` and the z axis by ``1 / a``, so the two scalings multiply
    to one.
    """

    def _params_validation(self):
        """Run the parent validation, then default the aspect ratio to 0.8."""
        super()._params_validation()
        has_ratio = "a" in self.params
        if not has_ratio:
            self.params["a"] = 0.8

    def _get_axis(self):
        """Return the square-lattice axes stretched by ``a`` and ``1 / a``."""
        ratio = self.params["a"]
        return (
            ratio * self.common_gates["x"],
            -1 / ratio * self.common_gates["p"],
        )
class SquareGKPQubit(RectangularGKPQubit):
    """
    GKP qubit on a square lattice.

    Implemented as a rectangular GKP qubit with the aspect ratio pinned to
    ``a = 1.0``. The ``"a"`` parameter is therefore actually consumed by
    ``RectangularGKPQubit._get_axis``, rather than being written and never
    read, as it was when this class derived directly from ``GKPQubit``.
    With a = 1.0 the axes equal the plain square-lattice axes (x, -p).
    """

    def _params_validation(self):
        """Run the parent validation, then force the aspect ratio to 1."""
        super()._params_validation()
        # Override any user-supplied or default ratio: a square lattice has a = 1.
        self.params["a"] = 1.0
class HexagonalGKPQubit(GKPQubit):
    """
    GKP qubit on a hexagonal lattice.

    The x axis is tilted by π/3 into the (x, p) plane. Both axes are scaled by
    sqrt(2 / sqrt(3)); since scale² · sin(π/3) = 1, the lattice cell keeps
    the same area as the square lattice.
    """

    def _get_axis(self):
        """Return the hexagonal-lattice (x_axis, z_axis) quadratures."""
        x_op = self.common_gates["x"]
        p_op = self.common_gates["p"]
        angle = jnp.pi / 3.0
        scale = jnp.sqrt(2 / jnp.sqrt(3))
        tilted = jnp.sin(angle) * x_op + jnp.cos(angle) * p_op
        return scale * tilted, scale * (-p_op)
179## Citations
181# Stabilization of Finite-Energy Gottesman-Kitaev-Preskill States
182# Baptiste Royer, Shraddha Singh, and S. M. Girvin
183# Phys. Rev. Lett. 125, 260509 – Published 31 December 2020
185# Quantum error correction of a qubit encoded in grid states of an oscillator.
186# Campagne-Ibarcq, P., Eickbusch, A., Touzard, S. et al.
187# Nature 584, 368–372 (2020).