Coverage for jaxquantum / codes / base.py: 85%

130 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 21:51 +0000

1""" 

2Base Bosonic Qubit Class 

3""" 

4 

5from typing import Dict, Optional, Tuple 

6from abc import abstractmethod, ABCMeta 

7 

8from jaxquantum.utils.utils import device_put_params 

9import jaxquantum as jqt 

10 

11from jax import config 

12import numpy as np 

13import jax.numpy as jnp 

14import matplotlib.pyplot as plt 

15 

# Switch JAX to 64-bit (double precision) mode. This runs as a side effect of
# importing this module, so it applies process-wide, not just to this class.
config.update("jax_enable_x64", True)

17 

18 

class BosonicQubit(metaclass=ABCMeta):
    """
    Base class for Bosonic Qubits.

    Subclasses implement ``_get_basis_z`` to provide the logical |+z> and
    |-z> states. The remaining Pauli eigenstates, the logical gates and the
    plotting helpers are derived from those two states in this class.

    Attributes:
        name (str): label of the code; used as the figure title in
            ``plot_code_states``.
        params (dict): validated parameters, as returned by
            ``device_put_params``.
        common_gates (Dict[str, jqt.Qarray]): operators filled in by
            ``_gen_common_gates``.
        wigner_pts: grid points handed to ``jqt.plot_qp`` when plotting.
        basis (Dict[str, jqt.Qarray]): states keyed by "+x", "-x", "+y",
            "-y", "+z", "-z".
    """

    # Parameter names every code accepts, plus the code-specific ones that a
    # subclass lists in PARAMETERS.
    BASE_PARAMETERS = ["N"]
    PARAMETERS = []

    name = "bqubit"

    @property
    def _non_device_params(self):
        """
        Parameter names forwarded to ``device_put_params`` as its second
        argument.

        Can be overridden in child classes.
        """
        return ["N"]

    def __init__(
        self,
        params: Optional[Dict[str, float]] = None,
        name: Optional[str] = None,
    ):
        """
        Validate the parameters, then build the common gates and the basis.

        Args:
            params: code parameters. With the validation implemented in this
                class, every key must appear in BASE_PARAMETERS or PARAMETERS
                and "N" defaults to 50 when omitted.
            name: optional label that overrides the class-level ``name``.

        Raises:
            ValueError: if the validation in this class finds an unknown key.
            AssertionError: if one of the six basis states is missing.
        """
        if name is not None:
            self.name = name

        # Work on a copy: _params_validation fills in defaults in place, and
        # that must not leak into the dictionary owned by the caller.
        self.params = dict(params) if params else {}
        self._params_validation()

        self.params = device_put_params(self.params, self._non_device_params)

        self.common_gates: Dict[str, jqt.Qarray] = {}
        self._gen_common_gates()

        self.wigner_pts = jnp.linspace(-4.5, 4.5, 61)

        self.basis = self._get_basis_states()

        # Guards against subclasses that override _get_basis_states and
        # forget one of the states.
        for basis_state in ["+x", "-x", "+y", "-y", "+z", "-z"]:
            assert basis_state in self.basis, (
                f"Please set the {basis_state} basis state."
            )

    def _params_validation(self):
        """
        Reject unknown parameter names and fill in the default for "N".

        Override this method to add additional validation to params.

        E.g.
            if "N" not in self.params:
                self.params["N"] = 50

        Raises:
            ValueError: if a key of ``self.params`` is neither in
                BASE_PARAMETERS nor in PARAMETERS.
        """
        allowed = self.BASE_PARAMETERS + self.PARAMETERS

        for key in self.params:
            if key not in allowed:
                raise ValueError(
                    f"Invalid parameter {key}. Allowed parameters are {allowed}"
                )

        self.params.setdefault("N", 50)

    def _gen_common_gates(self):
        """
        Fill ``self.common_gates`` with operators shared by all codes.

        This implementation stores ``jqt.create(N)`` under "a_dag" and
        ``jqt.destroy(N)`` under "a", with N taken from ``self.params``.

        Override this method to add additional common gates.
        """
        N = self.params["N"]
        self.common_gates["a_dag"] = jqt.create(N)
        self.common_gates["a"] = jqt.destroy(N)

    @abstractmethod
    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Returns:
            plus_z (jqt.Qarray), minus_z (jqt.Qarray): z basis states
        """

    def _get_basis_states(self) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z>
        """
        plus_z, minus_z = self._get_basis_z()
        return self._gen_basis_states_from_z(plus_z, minus_z)

    def _gen_basis_states_from_z(
        self, plus_z: jqt.Qarray, minus_z: jqt.Qarray
    ) -> Dict[str, jqt.Qarray]:
        """
        Construct basis states |+-x>, |+-y>, |+-z> from |+-z>

        Args:
            plus_z: the |+z> state.
            minus_z: the |-z> state.

        Returns:
            Dictionary keyed by "+z", "-z", "+x", "-x", "+y", "-y". The x and
            y states are the sums/differences of the z states (with a factor
            of 1j for y), each passed through ``jqt.unit``.
        """
        # important to make sure that each basis state is a column vec
        # otherwise, transposing a 1D vector will do nothing
        basis: Dict[str, jqt.Qarray] = {"+z": plus_z, "-z": minus_z}

        basis["+x"] = jqt.unit(plus_z + minus_z)
        basis["-x"] = jqt.unit(plus_z - minus_z)
        basis["+y"] = jqt.unit(plus_z + 1j * minus_z)
        basis["-y"] = jqt.unit(plus_z - 1j * minus_z)
        return basis

    def jqt2qt(self, state):
        """
        Thin wrapper that forwards ``state`` to ``jqt.jqt2qt``.
        """
        return jqt.jqt2qt(state)

    # gates
    # ======================================================
    # @abstractmethod
    # def stabilize(self) -> None:
    #     """
    #     Stabilizing/measuring syndromes.
    #     """

    @property
    def x_U(self) -> jqt.Qarray:
        """
        Logical X unitary gate.
        """
        return self._gen_pauli_U("x")

    @property
    def x_H(self) -> Optional[jqt.Qarray]:
        """
        Logical X hamiltonian.

        None in this class; when a subclass returns a value, ``x_U`` is built
        from it as ``jqt.expm(1.0j * H)``.
        """
        return None

    @property
    def y_U(self) -> jqt.Qarray:
        """
        Logical Y unitary gate.
        """
        return self._gen_pauli_U("y")

    @property
    def y_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Y hamiltonian.

        None in this class; when a subclass returns a value, ``y_U`` is built
        from it as ``jqt.expm(1.0j * H)``.
        """
        return None

    @property
    def z_U(self) -> jqt.Qarray:
        """
        Logical Z unitary gate.
        """
        return self._gen_pauli_U("z")

    @property
    def z_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Z hamiltonian.

        None in this class; when a subclass returns a value, ``z_U`` is built
        from it as ``jqt.expm(1.0j * H)``.
        """
        return None

    @property
    def h_H(self) -> Optional[jqt.Qarray]:
        """
        Logical Hadamard hamiltonian.

        None in this class. Note that ``h_U`` as implemented here does not
        read this property.
        """
        return None

    @property
    def h_U(self) -> jqt.Qarray:
        """
        Logical Hadamard unitary gate.

        Built as |+x><+z| + |-x><-z| from the stored basis states.
        """
        plus_term = self.basis["+x"] @ self.basis["+z"].dag()
        minus_term = self.basis["-x"] @ self.basis["-z"].dag()
        return plus_term + minus_term

    def _gen_pauli_U(self, basis_state: str) -> jqt.Qarray:
        """
        Generates unitary for Pauli X, Y, Z.

        If the matching ``<basis_state>_H`` property is not None, the gate is
        ``jqt.expm(1.0j * H)``. Otherwise it is assembled from the basis
        states as |+s><+s| - |-s><-s|.

        Args:
            basis_state (str): "x", "y", "z"

        Returns:
            U (jqt.Qarray): Pauli unitary
        """
        H = getattr(self, basis_state + "_H")
        if H is not None:
            return jqt.expm(1.0j * H)

        plus = self.basis["+" + basis_state]
        minus = self.basis["-" + basis_state]
        return plus @ plus.dag() - minus @ minus.dag()

    @property
    def projector(self):
        """
        Sum of the outer products |+z><+z| + |-z><-z| of the z basis states.
        """
        plus_z = self.basis["+z"]
        minus_z = self.basis["-z"]
        return plus_z @ plus_z.dag() + minus_z @ minus_z.dag()

    @property
    def maximally_mixed_state(self):
        """
        One half of ``projector``.
        """
        return (1 / 2.0) * self.projector

    # Plotting
    # ======================================================
    def _prepare_state_plot(self, state):
        """
        Hook applied to each code state before it is drawn by
        ``plot_code_states``; the base implementation returns it unchanged.

        Can be overridden.

        E.g. in the case of cavity x transmon system
            return qt.ptrace(state, 0)
        """
        return state

    def plot(self, state, ax=None, qp_type=jqt.WIGNER, **kwargs) -> None:
        """
        Plot the quasi-probability distribution of a single state.

        Args:
            state: state to draw; forwarded to ``_plot_single``.
            ax: matplotlib axes to draw on; a new figure is created when None.
            qp_type (str): jqt.WIGNER or jqt.QFUNC. For any other value the
                colorbar ticks are left to matplotlib.
            **kwargs: forwarded to ``_plot_single``.
        """
        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
        fig = ax.get_figure()

        # Fixed colorbar ticks for the two known plot types; None lets
        # matplotlib pick its own for anything else instead of failing on an
        # unbound range.
        if qp_type == jqt.WIGNER:
            ticks = np.linspace(-1, 1, 5)
        elif qp_type == jqt.QFUNC:
            ticks = np.linspace(0, 1, 5)
        else:
            ticks = None

        plotted = self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)
        # plot_code_states unpacks this result as a 2-tuple whose second item
        # is the plotted artist, so accept both a tuple and a bare mappable.
        # NOTE(review): confirm the exact return shape of jqt.plot_qp.
        mappable = plotted[-1] if isinstance(plotted, tuple) else plotted

        ax.set_title(qp_type.capitalize() + " Quasi-Probability Dist.")
        fig.colorbar(mappable, ax=ax, ticks=ticks)
        ax.set_xlabel(r"Re$(\alpha)$")
        ax.set_ylabel(r"Im$(\alpha)$")
        fig.tight_layout()

        plt.show()

    def _plot_single(self, state, ax=None, contour=True, qp_type=jqt.WIGNER):
        """
        Draw ``state`` on ``ax`` through ``jqt.plot_qp`` using
        ``self.wigner_pts`` as the grid.

        Assumes state has same dims as initial_state.

        Args:
            state: state to draw.
            ax: matplotlib axes; a new figure is created when None.
            contour: forwarded to ``jqt.plot_qp``.
            qp_type: forwarded to ``jqt.plot_qp``.

        Returns:
            Whatever ``jqt.plot_qp`` returns.
        """
        if ax is None:
            _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)

        return jqt.plot_qp(
            state, self.wigner_pts, axs=ax, contour=contour, qp_type=qp_type
        )

    def plot_code_states(self, qp_type: str = jqt.WIGNER, **kwargs):
        """
        Plot |±x⟩, |±y⟩, |±z⟩ code states on a 2 x 3 grid and show the figure.

        The top row holds the "+" states and the bottom row the "-" states,
        in the column order z, x, y.

        Args:
            qp_type (str):
                WIGNER or QFUNC
            **kwargs: forwarded to ``_plot_single``.

        Returns:
            None. The figure is displayed with ``plt.show()``.
        """
        fig, axs = plt.subplots(2, 3, figsize=(12, 6), dpi=200)

        labels = ["+z", "+x", "+y", "-z", "-x", "-y"]
        # axs.flat walks the grid row by row, matching the label order above.
        for label, ax in zip(labels, axs.flat):
            state = self._prepare_state_plot(self.basis[label])
            self._plot_single(state, ax=ax, qp_type=qp_type, **kwargs)
            ax.set_title(f"|{label}" + r"$\rangle$")
            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")

        fig.suptitle(self.name)
        fig.tight_layout()
        fig.align_xlabels(axs)
        fig.align_ylabels(axs)

        fig.tight_layout()
        plt.show()