Coverage for jaxquantum/core/measurements.py: 76%

159 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Measurements.""" 

2 

3import optax 

4import jax.numpy as jnp 

5 

6from collections.abc import Callable 

7from matplotlib import pyplot as plt 

8from tqdm import tqdm 

9from typing import Optional, NamedTuple 

10from functools import partial, reduce 

11 

12from jax import config, Array, jit, value_and_grad, lax, vmap 

13 

14from jaxquantum.core.qarray import Qarray, powm 

15from jaxquantum.core.operators import identity 

16 

17from jax_tqdm import scan_tqdm 

18 

19config.update("jax_enable_x64", True) 

20 

21 

22# Calculations ---------------------------------------------------------------- 

23 

24 

def overlap(rho: Qarray, sigma: Qarray) -> Array:
    """Overlap between two states or operators.

    * Two vectors: squared modulus of their inner product, |<rho|sigma>|^2.
    * One vector and one operator: expectation value <psi|O|psi> of the
      operator in the (ket form of the) vector.
    * Two operators: Hilbert-Schmidt inner product Tr(rho^dagger sigma).

    Args:
        rho: state/operator.
        sigma: state/operator.

    Returns:
        Overlap between rho and sigma.
    """
    rho_is_vec = rho.is_vec()
    sigma_is_vec = sigma.is_vec()

    # Neither argument is a vector: Hilbert-Schmidt inner product.
    if not (rho_is_vec or sigma_is_vec):
        return (rho.dag() @ sigma).trace()

    # Both vectors: squared magnitude of the inner product.
    if rho_is_vec and sigma_is_vec:
        inner = (rho.to_ket().dag() @ sigma.to_ket()).trace()
        return jnp.abs(inner) ** 2

    # Exactly one vector: sandwich the operator between that vector.
    vec, op = (rho, sigma) if rho_is_vec else (sigma, rho)
    ket = vec.to_ket()
    sandwiched = (ket.dag() @ op @ ket).data
    # Drop the trailing 1x1 matrix axes, keeping any batch dimensions.
    return sandwiched.squeeze(-1).squeeze(-1)

48 

49 

def fidelity(rho: Qarray, sigma: Qarray, force_positivity: bool = False) -> (
        jnp.ndarray):
    """Fidelity between two states.

    Computes F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2 after promoting both
    inputs to density matrices, and returns its real part.

    Args:
        rho: state.
        sigma: state.
        force_positivity: force the states to be positive semidefinite
            (forwarded to ``powm`` as ``clip_eigvals`` for both square roots).

    Returns:
        Fidelity between rho and sigma.
    """
    rho_dm = rho.to_dm()
    sigma_dm = sigma.to_dm()

    root_rho = powm(rho_dm, 0.5, clip_eigvals=force_positivity)
    inner = root_rho @ sigma_dm @ root_rho
    root_inner = powm(inner, 0.5, clip_eigvals=force_positivity)

    return jnp.real(root_inner.tr() ** 2)

70 

71 

def _reconstruct_density_matrix(params: jnp.ndarray, dim: int) -> jnp.ndarray:
    """Map a real parameter vector to a density matrix.

    The first ``dim*(dim+1)//2`` entries fill the real parts of the lower
    triangle (diagonal included, in ``jnp.tril_indices`` order). The remaining
    entries fill the imaginary parts of the strictly lower triangle. With
    ``T`` that lower-triangular factor, the result is ``T T^dagger / Tr``,
    which is positive semidefinite with unit trace. An all-zero factor yields
    the zero matrix instead of dividing by zero.
    """
    n_real = dim * (dim + 1) // 2
    lower = jnp.tril_indices(dim)
    strictly_lower = jnp.tril_indices(dim, k=-1)

    factor = (
        jnp.zeros((dim, dim), dtype=jnp.complex128)
        .at[lower].set(params[:n_real])
        .at[strictly_lower].add(1j * params[n_real:])
    )

    factor_dagger = factor.conj().T
    unnormalized = factor @ factor_dagger
    tr = jnp.trace(unnormalized)
    # Guard the normalization against a zero trace.
    safe_tr = jnp.where(tr == 0, 1.0, tr)
    return unnormalized / safe_tr

96 

97 

def _parametrize_density_matrix(rho_data: jnp.ndarray, dim: int) -> (
        jnp.ndarray):
    """Compute the parameter vector of a density matrix via Cholesky.

    This is the inverse of ``_reconstruct_density_matrix`` up to the trace
    normalization. The output holds the real parts of the lower-triangular
    Cholesky factor (diagonal included), followed by the imaginary parts of
    its strictly lower triangle.
    """
    # A tiny diagonal shift keeps the Cholesky factorization defined for
    # rank-deficient (e.g. pure-state) inputs.
    shifted = rho_data + 1e-9 * jnp.eye(dim)
    chol = jnp.linalg.cholesky(shifted)

    # The lower-triangular factor has a real diagonal, so only the strictly
    # lower entries carry imaginary parameters.
    lower_entries = chol[jnp.tril_indices(dim)]
    strictly_lower_entries = chol[jnp.tril_indices(dim, k=-1)]

    return jnp.concatenate([lower_entries.real, strictly_lower_entries.imag])

119 

120 

def _L1_reg(params: jnp.ndarray) -> jnp.ndarray:
    """Return the L1 norm of ``params`` (sum of absolute values)."""
    magnitudes = jnp.abs(params)
    return magnitudes.sum()

124 

125 

def _likelihood(
    params: jnp.ndarray, dim: int, basis: jnp.ndarray, results: jnp.ndarray
) -> jnp.ndarray:
    """Score a parameter vector against measured data (higher is better).

    Returns the negative sum of squared residuals between the predicted and
    measured outcomes. This is a Gaussian log-likelihood with equal variances,
    up to scale and an additive constant.

    The prediction for operator ``basis[i]`` is
    ``Re sum_jk basis[i, j, k] * rho[j, k]``, i.e. ``Re Tr(basis[i]^T rho)``.
    NOTE(review): this equals ``Tr(basis[i] rho)`` only for real-symmetric
    operators; confirm the intended convention for complex ones.
    """
    rho = _reconstruct_density_matrix(params, dim)
    predicted = jnp.einsum("ijk,jk->i", basis, rho).real
    residuals = predicted - results
    return -jnp.sum(residuals ** 2)

133 

134 

# Core optimization loop: a pure function, compiled once per distinct
# combination of the static arguments listed below.
@partial(
    jit,
    static_argnames=(
        "dim",
        "epochs",
        "optimizer",
        "compute_infidelity",
        "L1_reg_strength",
    ),
)
def _run_tomography_scan(
    initial_params,
    initial_opt_state,
    true_rho_data,
    measurement_basis,
    measurement_results,
    dim,
    epochs,
    optimizer,
    compute_infidelity,
    L1_reg_strength,
):
    """Run ``epochs`` optimizer steps inside a single ``lax.scan``.

    Objective: ``-_likelihood(params) + L1_reg_strength * ||params||_1``.

    Returns:
        ``((final_params, final_opt_state), history)``. ``history`` is a dict
        with keys "loss", "grads", "params" and "infidelity", each stacked
        along a leading epoch axis. Every entry refers to the parameters at
        the start of that epoch, before the update is applied. "infidelity"
        is NaN when ``compute_infidelity`` is False.
    """

    def objective(p):
        data_term = -_likelihood(p, dim, measurement_basis, measurement_results)
        return data_term + L1_reg_strength * _L1_reg(p)

    objective_and_grad = value_and_grad(objective)

    def infidelity_of(p):
        # `compute_infidelity` is static, so this Python branch is resolved
        # at trace time; each flag value gets its own compiled program.
        if not compute_infidelity:
            return jnp.nan
        estimate = Qarray.create(_reconstruct_density_matrix(p, dim))
        target = Qarray.create(true_rho_data)
        return 1.0 - fidelity(target, estimate, force_positivity=True)

    @scan_tqdm(epochs)
    def step(state, _):
        current_params, current_opt_state = state
        loss_value, gradients = objective_and_grad(current_params)
        updates, next_opt_state = optimizer.update(
            gradients, current_opt_state, current_params
        )
        next_params = optax.apply_updates(current_params, updates)

        record = {
            "loss": loss_value,
            "grads": gradients,
            "params": current_params,
            "infidelity": infidelity_of(current_params),
        }
        return (next_params, next_opt_state), record

    return lax.scan(
        step,
        (initial_params, initial_opt_state),
        jnp.arange(epochs),
        length=epochs,
    )

204 

205 

class MLETomographyResult(NamedTuple):
    """Outcome of ``QuantumStateTomography.quantum_state_tomography_mle``.

    The ``*_history`` fields hold per-epoch values stacked by ``lax.scan``
    (leading axis = epoch) despite their ``list`` annotations.
    """

    # Density matrix rebuilt from the final optimizer parameters.
    rho: Qarray
    # Parameter vector at the start of each epoch (before that epoch's update).
    params_history: list
    # Objective value (negative likelihood + L1 penalty) at each epoch.
    loss_history: list
    # Gradient of the objective at each epoch.
    grads_history: list
    # 1 - fidelity to the true state per epoch; None when no true state was given.
    infidelity_history: Optional[list]

212 

213 

class QuantumStateTomography:
    """Reconstruct a quantum state (density matrix) from measurement data.

    Two estimators are provided:

    * ``quantum_state_tomography_mle``: gradient-based fit over a
      parametrization that keeps the estimate positive semidefinite with
      unit trace.
    * ``quantum_state_tomography_direct``: linear inversion of the
      measurement model in a (complete) operator basis.

    NOTE(review): both estimators model the outcome of measuring operator
    ``B`` on state ``rho`` as ``sum_jk B[j, k] * rho[j, k]`` (i.e.
    ``Tr(B^T rho)``; the MLE path takes its real part) rather than
    ``Tr(B rho)``. These agree for real-symmetric operators but differ for
    complex ones such as Pauli Y. Confirm that ``measurement_results``
    follow the same convention.
    """

    def __init__(
        self,
        rho_guess: Qarray,
        measurement_basis: Qarray,
        measurement_results: jnp.ndarray,
        complete_basis: Optional[Qarray] = None,
        true_rho: Optional[Qarray] = None,
    ):
        """
        Reconstruct a quantum state from measurement results using quantum state tomography.
        The tomography can be performed either by direct inversion or by maximum likelihood estimation.

        Args:
            rho_guess (Qarray): The initial guess for the quantum state; used
                to seed the MLE parameters.
            measurement_basis (Qarray): Batched operators in which the
                measurements are performed, one operator per result.
            measurement_results (jnp.ndarray): The results of the measurements.
            complete_basis (Optional[Qarray]): The operator basis used to
                expand the state for direct inversion. Defaults to the
                measurement basis if not provided.
            true_rho (Optional[Qarray]): The true quantum state, if known.
                When given, the MLE run also records the infidelity per epoch.
        """
        # Raw arrays are kept for the JAX computations; true_rho stays a Qarray.
        self.rho_guess = rho_guess.data
        self.measurement_basis = measurement_basis.data
        self.measurement_results = measurement_results
        self.complete_basis = (
            complete_basis.data
            if (complete_basis is not None)
            else measurement_basis.data
        )
        self.true_rho = true_rho
        self._result = None

    @property
    def result(self) -> Optional[MLETomographyResult]:
        """Result of the last MLE run, or None if it has not been run."""
        return self._result

    def quantum_state_tomography_mle(
        self, L1_reg_strength: float = 0.0, epochs: int = 10000, lr: float = 5e-3
    ) -> MLETomographyResult:
        """Perform quantum state tomography using maximum likelihood
        estimation (MLE).

        Reconstructs the quantum state by minimizing, with the Adam
        optimizer, the squared error between predicted and measured outcomes
        plus an optional L1 penalty. The parametrization guarantees that the
        resulting density matrix is positive semidefinite with trace 1.

        Args:
            L1_reg_strength (float, optional): Strength of L1
                regularization. Defaults to 0.0.
            epochs (int, optional): Number of optimization iterations.
                Defaults to 10000.
            lr (float, optional): Learning rate for the Adam optimizer.
                Defaults to 5e-3.

        Returns:
            MLETomographyResult: Named tuple containing:
                - rho: Reconstructed quantum state as Qarray
                - params_history: parameter values per epoch
                - loss_history: loss values per epoch
                - grads_history: gradient values per epoch
                - infidelity_history: infidelities per epoch if true_rho was
                  provided, None otherwise
            The result is also stored on ``self.result``.
        """
        dim = self.rho_guess.shape[0]
        optimizer = optax.adam(lr)

        # Initialize parameters from the initial guess for the density matrix.
        params = _parametrize_density_matrix(self.rho_guess, dim)
        opt_state = optimizer.init(params)

        compute_infidelity_flag = self.true_rho is not None

        # Provide a dummy array if no true_rho is available. It won't be used.
        true_rho_data_or_dummy = (
            self.true_rho.data
            if compute_infidelity_flag
            else jnp.empty((dim, dim), dtype=jnp.complex64)
        )

        final_carry, history = _run_tomography_scan(
            initial_params=params,
            initial_opt_state=opt_state,
            true_rho_data=true_rho_data_or_dummy,
            measurement_basis=self.measurement_basis,
            measurement_results=self.measurement_results,
            dim=dim,
            epochs=epochs,
            optimizer=optimizer,
            compute_infidelity=compute_infidelity_flag,
            L1_reg_strength=L1_reg_strength,
        )

        final_params, _ = final_carry

        rho = Qarray.create(_reconstruct_density_matrix(final_params, dim))

        self._result = MLETomographyResult(
            rho=rho,
            params_history=history["params"],
            loss_history=history["loss"],
            grads_history=history["grads"],
            infidelity_history=history["infidelity"]
            if compute_infidelity_flag
            else None,
        )
        return self._result

    def quantum_state_tomography_direct(
        self,
    ) -> Qarray:
        """Perform quantum state tomography using direct inversion.

        Expands the state as ``rho = sum_i c_i C_i`` over the complete basis
        ``C`` and solves the linear measurement model for the coefficients
        ``c``. The method assumes the measurement results are noise-free.
        When there are as many measurement operators as complete-basis
        operators, the system is solved exactly. Otherwise a least-squares
        solution is returned.

        Returns:
            Qarray: Reconstructed quantum state.
        """
        # A[i, l] = sum_jk C_i[j, k] * M_l[j, k]: element-wise overlap of the
        # i-th complete-basis operator with the l-th measurement operator.
        A = jnp.einsum("ijk,ljk->il", self.complete_basis, self.measurement_basis)

        # Substituting rho = sum_i c_i C_i into the measurement model gives
        # results[l] = sum_i A[i, l] c_i, i.e. A.T @ c = results. A is
        # symmetric when the complete basis equals the measurement basis, so
        # the transpose only matters when a distinct complete basis is given.
        design = A.T
        if design.shape[0] == design.shape[1]:
            coefficients = jnp.linalg.solve(design, self.measurement_results)
        else:
            # Over-/under-determined system: least-squares coefficients.
            coefficients = jnp.linalg.lstsq(design, self.measurement_results)[0]

        # Reconstruct the density matrix from the basis coefficients.
        rho = jnp.einsum("i, ijk->jk", coefficients, self.complete_basis)

        return Qarray.create(rho)

    def plot_results(self):
        """Plot the MLE loss history and, if available, the infidelity.

        The loss is drawn on a log-scale left axis. The infidelity, if
        recorded, is drawn on a log-scale twin right axis.

        Raises:
            ValueError: If ``quantum_state_tomography_mle`` has not been run.
        """
        if self._result is None:
            raise ValueError(
                "No results to plot. Run quantum_state_tomography_mle first."
            )

        _, ax = plt.subplots(1, figsize=(5, 4))
        if self._result.infidelity_history is not None:
            ax2 = ax.twinx()

        ax.plot(self._result.loss_history, color="C0")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("$\\mathcal{L}$", color="C0")
        ax.set_yscale("log")

        if self._result.infidelity_history is not None:
            ax2.plot(self._result.infidelity_history, color="C1")
            ax2.set_yscale("log")
            ax2.set_ylabel("$1-\\mathcal{F}$", color="C1")
            plt.grid(False)

        plt.show()

373 

def tensor_basis(single_basis: Qarray, n: int) -> Qarray:
    """Build the n-fold tensor-product basis from a single-system basis.

    Every ordered n-tuple of single-system operators is Kronecker-multiplied
    together. Tuples are enumerated in row-major order of their index
    tuples.

    Args:
        single_basis: The single-system operator basis as a batched Qarray.
        n: Number of tensor copies to construct.

    Returns:
        Qarray containing the n-fold tensor product basis operators. It has
        b**n elements, where b is the number of operators in the
        single-system basis.

    NOTE(review): ``dims`` is built by raising every subsystem dimension to
    the n-th power. For a single-subsystem basis this matches the total
    dimension. For a multi-subsystem basis (e.g. dims ((2, 3), (2, 3))),
    however, the Kronecker data is ordered (2, 3, 2, 3) while the metadata
    reads (4, 9). Confirm whether repeating the subsystem tuple n times was
    intended.
    """
    operators = single_basis.data
    num_ops, _, _ = operators.shape

    # All index tuples (i_0, ..., i_{n-1}), one per row, row-major order:
    # shape (num_ops**n, n).
    index_tuples = jnp.indices((num_ops,) * n).reshape(n, -1).T

    # Gather the operators for each tuple: shape (num_ops**n, n, d, d).
    chosen = operators[index_tuples]

    def kron_chain(ops):
        return reduce(jnp.kron, ops)

    products = vmap(kron_chain)(chosen)

    powered_dims = tuple(
        tuple(size ** n for size in row) for row in single_basis.dims
    )
    return Qarray.create(products, dims=powered_dims, bdims=(num_ops ** n,))

403 

404 

def _quantum_process_tomography(
    map: Callable[[Qarray], Qarray],
    physical_state_basis: Qarray,
    physical_operator_basis: Qarray,
    logical_state_basis: Optional[Qarray] = None,
    logical_operator_basis: Optional[Qarray] = None,
) -> Qarray:
    """Estimate a process by state tomography of its outputs on basis states.

    For each input ket ``|psi_k>`` in ``physical_state_basis``:

    1. Apply the channel ``map`` to ``rho_k = |psi_k><psi_k|``.
    2. Simulate the results ``Re sum_jk O_i[j, k] * E(rho_k)[j, k]`` for
       each operator ``O_i`` of ``physical_operator_basis``.
    3. Reconstruct the output state by MLE tomography in
       ``logical_operator_basis``.

    The estimates are accumulated as
    ``sum_k |phi_k><phi_k| ^ E(rho_k)_est``, with ``|phi_k>`` taken from
    ``logical_state_basis``.

    Args:
        map: The channel, acting on density-matrix Qarrays. (The name shadows
            the builtin ``map``; it is kept for keyword-argument
            compatibility.)
        physical_state_basis: Batched kets fed into the channel.
        physical_operator_basis: Batched operators used to simulate the
            measurements.
        logical_state_basis: Batched kets labelling each input. Defaults to
            ``physical_state_basis``.
        logical_operator_basis: Operator basis for the state reconstruction.
            Defaults to ``physical_operator_basis``.

    Returns:
        The accumulated Qarray on the doubled space, of dimension d**2, where
        d is the logical ket dimension.

    NOTE(review): the accumulated sum equals the Choi matrix only if the
    projectors ``|phi_k><phi_k|`` form an orthonormal operator basis. For a
    non-orthogonal tomographically complete set (e.g. |0>, |1>, |+>, |+i>),
    a dual-frame correction is required. TODO confirm the intended output.
    """
    if logical_state_basis is None:
        logical_state_basis = physical_state_basis
    if logical_operator_basis is None:
        logical_operator_basis = physical_operator_basis

    num_states = logical_state_basis.bdims[-1]
    # Hilbert-space dimension read from the ket shape (..., d, 1). Deriving
    # it from the number of states via a square root is wrong for
    # non-square counts.
    d = logical_state_basis.data.shape[-2]

    physical_ops = physical_operator_basis.data
    rho_guess = identity(d) / d
    choi = Qarray.create(jnp.zeros((d, d))) ^ Qarray.create(jnp.zeros((d, d)))

    with tqdm(total=num_states) as pbar:
        for k in range(num_states):
            psi_k = physical_state_basis[k]
            rho_k = psi_k @ psi_k.dag()

            E_rho_k = map(rho_k)

            measurement_results = jnp.real(
                jnp.einsum("ijk,jk->i", physical_ops, E_rho_k.data)
            )

            qst = QuantumStateTomography(
                rho_guess=rho_guess,
                measurement_results=measurement_results,
                measurement_basis=logical_operator_basis,
            )
            reconstructed = qst.quantum_state_tomography_mle().rho

            phi_k = logical_state_basis[k]
            choi += (phi_k @ phi_k.dag()) ^ reconstructed
            pbar.update(1)
    return choi