Coverage for jaxquantum/codes/base.py: 51%

139 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2Base Bosonic Qubit Class 

3""" 

4 

5from typing import Dict, Optional, Tuple 

6from abc import abstractmethod, ABCMeta 

7 

8from jaxquantum.utils.utils import device_put_params 

9import jaxquantum as jqt 

10 

11from jax import config 

12import numpy as np 

13import jax.numpy as jnp 

14import matplotlib.pyplot as plt 

15 

# Turn on JAX's "jax_enable_x64" option (per the JAX docs, this enables
# 64-bit dtypes). NOTE(review): this runs as an import-time side effect on
# JAX's global config, so it presumably affects all JAX code in the process.
config.update("jax_enable_x64", True)

17 

18 

class BosonicQubit(metaclass=ABCMeta):
    """
    Base class for Bosonic Qubits.

    Subclasses implement ``_get_basis_z`` to return the logical |+z> and |-z>
    states. From those, this class builds the six cardinal basis states, the
    logical Pauli and Hadamard gates, the code-space projector and plotting
    helpers.
    """

    name = "bqubit"

    @property
    def _non_device_params(self):
        """
        Names of params passed to ``device_put_params`` as non-device params.

        Can be overridden in child classes. ``"N"`` is the dimension handed to
        ``jqt.create`` / ``jqt.destroy`` in ``_gen_common_gates``.
        """
        return ["N"]

    def __init__(
        self,
        params: Optional[Dict[str, float]] = None,
        name: Optional[str] = None,
    ):
        """
        Args:
            params: code parameters. The dict is copied, so defaults inserted
                by ``_params_validation`` (e.g. ``"N"``) do not leak back into
                the caller's dictionary.
            name: optional instance name overriding the class-level ``name``.

        Raises:
            AssertionError: if ``_get_basis_states`` does not return all of
                "+x", "-x", "+y", "-y", "+z", "-z".
        """
        if name is not None:
            self.name = name

        # Copy before validation: _params_validation mutates self.params.
        self.params = dict(params) if params else {}
        self._params_validation()

        self.params = device_put_params(self.params, self._non_device_params)

        self.common_gates: Dict[str, jqt.Qarray] = {}
        self._gen_common_gates()

        # Phase-space grid passed to jqt.plot_qp by _plot_single.
        self.wigner_pts = jnp.linspace(-4.5, 4.5, 61)

        self.basis = self._get_basis_states()

        for basis_state in ["+x", "-x", "+y", "-y", "+z", "-z"]:
            assert basis_state in self.basis, (
                f"Please set the {basis_state} basis state."
            )

    def _params_validation(self):
        """
        Validate ``self.params`` and fill in defaults.

        The base implementation only sets ``N`` to 50 when it is missing.
        Override this method to add additional validation to params, e.g.

            if "N" not in self.params:
                self.params["N"] = 50
        """
        if "N" not in self.params:
            self.params["N"] = 50

    def _gen_common_gates(self):
        """
        Populate ``self.common_gates`` with operators shared by all codes.

        The base implementation stores ``jqt.create(N)`` under ``"a_dag"`` and
        ``jqt.destroy(N)`` under ``"a"``, with ``N = self.params["N"]``.
        Override this method (calling ``super()._gen_common_gates()`` first if
        these operators are still needed) to add additional common gates, e.g.

            self.common_gates["n"] = (
                self.common_gates["a_dag"] @ self.common_gates["a"]
            )
        """
        N = self.params["N"]
        self.common_gates["a_dag"] = jqt.create(N)
        self.common_gates["a"] = jqt.destroy(N)

    @abstractmethod
    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Returns:
            plus_z (jqt.Qarray), minus_z (jqt.Qarray): z basis states
        """

    def _get_basis_states(self) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z> from ``_get_basis_z``.
        """
        plus_z, minus_z = self._get_basis_z()
        return self._gen_basis_states_from_z(plus_z, minus_z)

    def _gen_basis_states_from_z(
        self, plus_z: jqt.Qarray, minus_z: jqt.Qarray
    ) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z> from |+-z>.

        |+-x> = unit(|+z> +- |-z>) and |+-y> = unit(|+z> +- i|-z>).

        Returns:
            dict keyed by "+z", "-z", "+x", "-x", "+y", "-y".
        """
        basis: Dict[str, jqt.Qarray] = {}

        # Important: each basis state should be a column vector; otherwise
        # transposing a 1D vector (as .dag() does) would do nothing.

        basis["+z"] = plus_z
        basis["-z"] = minus_z

        basis["+x"] = jqt.unit(basis["+z"] + basis["-z"])
        basis["-x"] = jqt.unit(basis["+z"] - basis["-z"])
        basis["+y"] = jqt.unit(basis["+z"] + 1j * basis["-z"])
        basis["-y"] = jqt.unit(basis["+z"] - 1j * basis["-z"])
        return basis

    def jqt2qt(self, state):
        """
        Convert ``state`` with ``jqt.jqt2qt`` (used before plotting).
        """
        return jqt.jqt2qt(state)

    # gates
    # ======================================================
    # @abstractmethod
    # def stabilize(self) -> None:
    #     """
    #     Stabilizing/measuring syndromes.
    #     """

    @property
    def x_U(self) -> jqt.Qarray:
        """
        Logical X unitary gate.
        """
        return self._gen_pauli_U("x")

    @property
    def x_H(self) -> Optional[jqt.Qarray]:
        """
        Logical X hamiltonian (None unless overridden).
        """
        return None

    @property
    def y_U(self) -> jqt.Qarray:
        """
        Logical Y unitary gate.
        """
        return self._gen_pauli_U("y")

    @property
    def y_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Y hamiltonian (None unless overridden).
        """
        return None

    @property
    def z_U(self) -> jqt.Qarray:
        """
        Logical Z unitary gate.
        """
        return self._gen_pauli_U("z")

    @property
    def z_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Z hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Hadamard hamiltonian (None unless overridden).
        """
        return None

    @property
    def h_U(self) -> jqt.Qarray:
        """
        Logical Hadamard unitary gate: |+x><+z| + |-x><-z|.
        """
        return (
            self.basis["+x"] @ self.basis["+z"].dag()
            + self.basis["-x"] @ self.basis["-z"].dag()
        )

    def _gen_pauli_U(self, basis_state: str) -> jqt.Qarray:
        """
        Generates unitary for Pauli X, Y, Z.

        If the matching ``<basis_state>_H`` property is not None, returns
        ``jqt.expm(1j * H)``. Otherwise returns |+b><+b| - |-b><-b| built from
        the stored basis states.

        Args:
            basis_state (str): "x", "y", "z"

        Returns:
            U (jqt.Qarray): Pauli unitary
        """
        H = getattr(self, basis_state + "_H")
        if H is not None:
            return jqt.expm(1.0j * H)

        gate = (
            self.basis["+" + basis_state] @ self.basis["+" + basis_state].dag()
            - self.basis["-" + basis_state] @ self.basis["-" + basis_state].dag()
        )

        return gate

    @property
    def projector(self):
        """
        |+z><+z| + |-z><-z|: the code-space projector, assuming the two
        z states are orthonormal.
        """
        return (
            self.basis["+z"] @ self.basis["+z"].dag()
            + self.basis["-z"] @ self.basis["-z"].dag()
        )

    @property
    def maximally_mixed_state(self):
        """
        Half of ``projector``, i.e. the maximally mixed logical state when
        the z basis states are orthonormal.
        """
        # ``projector`` is a property: access it, do not call it.
        return (1 / 2.0) * self.projector

    # Plotting
    # ======================================================
    @staticmethod
    def _qp_value_range(qp_type: str) -> Tuple[int, int]:
        """
        Return the (vmin, vmax) colorbar range used for ``qp_type``.

        Raises:
            ValueError: if ``qp_type`` is neither ``jqt.WIGNER`` nor
                ``jqt.QFUNC``.
        """
        if qp_type == jqt.WIGNER:
            return -1, 1
        if qp_type == jqt.QFUNC:
            return 0, 1
        raise ValueError(
            f"Unsupported qp_type {qp_type!r}; expected jqt.WIGNER or jqt.QFUNC."
        )

    def _prepare_state_plot(self, state):
        """
        Hook applied to each code state before plotting in
        ``plot_code_states``. Identity by default; can be overridden.

        E.g. in the case of cavity x transmon system
            return qt.ptrace(state, 0)
        """
        return state

    def plot(self, state, ax=None, qp_type=jqt.WIGNER, **kwargs) -> None:
        """
        Plot the quasi-probability distribution of ``state`` with a colorbar.

        Args:
            state: state to plot (converted with ``jqt2qt`` in
                ``_plot_single``).
            ax: matplotlib axes to draw on; a new figure is created if None.
            qp_type: ``jqt.WIGNER`` or ``jqt.QFUNC``.
            **kwargs: forwarded to ``_plot_single`` (e.g. ``contour``).

        Raises:
            ValueError: if ``qp_type`` is neither ``jqt.WIGNER`` nor
                ``jqt.QFUNC``.
        """
        # Validate before creating any figure.
        vmin, vmax = self._qp_value_range(qp_type)

        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
        fig = ax.get_figure()

        # _plot_single returns jqt.plot_qp's (axes, mappable) pair; the
        # mappable feeds the colorbar, exactly as in plot_code_states.
        _, w_plt = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)

        ax.set_title(qp_type.capitalize() + " Quasi-Probability Dist.")
        ticks = np.linspace(vmin, vmax, 5)
        fig.colorbar(w_plt, ax=ax, ticks=ticks)
        ax.set_xlabel(r"Re$(\alpha)$")
        ax.set_ylabel(r"Im$(\alpha)$")
        fig.tight_layout()

        plt.show()

    def _plot_single(self, state, ax=None, contour=True, qp_type=jqt.WIGNER):
        """
        Draw ``state`` on ``ax`` with ``jqt.plot_qp`` over ``self.wigner_pts``.

        Assumes state has same dims as initial_state. The state is converted
        with ``jqt2qt`` first; a new axes is created when ``ax`` is None.

        Returns:
            whatever ``jqt.plot_qp`` returns (unpacked by the callers in this
            class as ``(axes, mappable)``).
        """
        state = self.jqt2qt(state)

        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)

        return jqt.plot_qp(
            state, self.wigner_pts, ax=ax, contour=contour, qp_type=qp_type
        )

    def plot_code_states(self, qp_type: str = jqt.WIGNER, **kwargs):
        """
        Plot |±x⟩, |±y⟩, |±z⟩ code states on a 2x3 grid with a shared colorbar.

        Each state is passed through ``_prepare_state_plot`` before plotting.

        Args:
            qp_type (str):
                WIGNER or QFUNC
            **kwargs: forwarded to ``_plot_single`` (e.g. ``contour``).

        Returns:
            None; the figure is displayed with ``plt.show()``.

        Raises:
            ValueError: if ``qp_type`` is neither ``jqt.WIGNER`` nor
                ``jqt.QFUNC``.
        """
        # Validate before creating any figure.
        vmin, vmax = self._qp_value_range(qp_type)
        cbar_title = (
            r"$\frac{\pi}{2} W(\alpha)$"
            if qp_type == jqt.WIGNER
            else r"$\pi Q(\alpha)$"
        )

        fig, axs = plt.subplots(2, 3, figsize=(12, 6), dpi=200)

        for i, label in enumerate(["+z", "+x", "+y", "-z", "-x", "-y"]):
            state = self._prepare_state_plot(self.basis[label])
            pos = (i // 3, i % 3)
            ax = axs[pos]
            _, w_plt = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)
            ax.set_title(f"|{label}" + r"$\rangle$")
            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")

        fig.suptitle(self.name)
        fig.tight_layout()
        fig.subplots_adjust(right=0.8, hspace=0.2, wspace=0.2)
        fig.align_xlabels(axs)
        fig.align_ylabels(axs)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])

        ticks = np.linspace(vmin, vmax, 5)
        fig.colorbar(w_plt, cax=cbar_ax, ticks=ticks)

        cbar_ax.set_title(cbar_title, pad=20)
        fig.tight_layout()
        plt.show()