Coverage for jaxquantum/core/measurements.py: 76%
159 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Measurements."""
3import optax
4import jax.numpy as jnp
6from collections.abc import Callable
7from matplotlib import pyplot as plt
8from tqdm import tqdm
9from typing import Optional, NamedTuple
10from functools import partial, reduce
12from jax import config, Array, jit, value_and_grad, lax, vmap
14from jaxquantum.core.qarray import Qarray, powm
15from jaxquantum.core.operators import identity
17from jax_tqdm import scan_tqdm
# Switch JAX to 64-bit precision for the whole process (a module-level side
# effect): the density-matrix helpers below explicitly request complex128.
config.update("jax_enable_x64", True)
22# Calculations ----------------------------------------------------------------
def overlap(rho: Qarray, sigma: Qarray) -> Array:
    """Overlap between two states or operators.

    * two vectors:          ``|<rho|sigma>|**2``
    * one vector ``psi``:   ``<psi| O |psi>`` with ``O`` the other argument
    * two operators:        ``Tr(rho^dagger sigma)``

    Args:
        rho: state/operator.
        sigma: state/operator.

    Returns:
        Overlap between rho and sigma.
    """
    rho_is_vec = rho.is_vec()
    sigma_is_vec = sigma.is_vec()

    if rho_is_vec and sigma_is_vec:
        amplitude = (rho.to_ket().dag() @ sigma.to_ket()).trace()
        return jnp.abs(amplitude) ** 2

    if not (rho_is_vec or sigma_is_vec):
        return (rho.dag() @ sigma).trace()

    # Exactly one argument is a vector: sandwich the operator with it and
    # drop the trailing 1x1 matrix axes.
    ket, op = (rho, sigma) if rho_is_vec else (sigma, rho)
    ket = ket.to_ket()
    sandwich = (ket.dag() @ op @ ket).data
    return sandwich.squeeze(-1).squeeze(-1)
def fidelity(rho: Qarray, sigma: Qarray, force_positivity: bool = False) -> (
        jnp.ndarray):
    """Uhlmann fidelity ``(Tr sqrt(sqrt(rho) sigma sqrt(rho)))**2``.

    Both inputs are first converted to density matrices, so kets are accepted.

    Args:
        rho: state.
        sigma: state.
        force_positivity: forwarded as ``clip_eigvals`` to both matrix square
            roots, forcing the states to be treated as positive semidefinite.

    Returns:
        Fidelity between rho and sigma (real part of the squared trace).
    """
    rho_dm = rho.to_dm()
    sigma_dm = sigma.to_dm()

    root_rho = powm(rho_dm, 0.5, clip_eigvals=force_positivity)
    inner = root_rho @ sigma_dm @ root_rho
    root_inner = powm(inner, 0.5, clip_eigvals=force_positivity)

    return jnp.real(root_inner.tr() ** 2)
def _reconstruct_density_matrix(params: jnp.ndarray, dim: int) -> jnp.ndarray:
    """Map a real parameter vector to a density matrix ``T T^dagger / Tr``.

    Layout of ``params`` (``dim**2`` entries in total):

    * the first ``dim*(dim+1)//2`` values are the real parts of the lower
      triangle of ``T`` (including the diagonal), in ``tril_indices`` order;
    * the remaining values are the imaginary parts of the strictly lower
      triangle.

    ``T T^dagger`` is positive semi-definite by construction; dividing by its
    trace gives unit trace. An all-zero ``T`` (trace 0) yields the zero
    matrix instead of dividing by zero.
    """
    n_real = dim * (dim + 1) // 2
    rows, cols = jnp.tril_indices(dim)
    off_rows, off_cols = jnp.tril_indices(dim, k=-1)

    chol = (
        jnp.zeros((dim, dim), dtype=jnp.complex128)
        .at[rows, cols].set(params[:n_real])
        .at[off_rows, off_cols].add(1j * params[n_real:])
    )

    unnormalized = chol @ jnp.conj(chol).T
    tr = jnp.trace(unnormalized)
    safe_tr = jnp.where(tr == 0, 1.0, tr)
    return unnormalized / safe_tr
def _parametrize_density_matrix(rho_data: jnp.ndarray, dim: int) -> (
        jnp.ndarray):
    """Encode a density matrix as the parameter vector used by the MLE fit.

    This inverts ``_reconstruct_density_matrix`` (up to the small diagonal
    shift below). ``rho_data`` is assumed Hermitian positive semi-definite.
    Its Cholesky factor ``L`` is lower triangular with a real diagonal. The
    output is the real parts of ``L``'s lower triangle (diagonal included),
    followed by the imaginary parts of its strictly lower triangle.
    """
    # Shift by 1e-9 * I so rank-deficient states (e.g. pure states) are
    # strictly positive definite and the factorisation is well defined.
    regularized = rho_data + 1e-9 * jnp.eye(dim)
    lower = jnp.linalg.cholesky(regularized)

    diag_and_below = lower[jnp.tril_indices(dim)]
    strictly_below = lower[jnp.tril_indices(dim, k=-1)]

    return jnp.concatenate([jnp.real(diag_and_below), jnp.imag(strictly_below)])
def _L1_reg(params: jnp.ndarray) -> jnp.ndarray:
    """L1 penalty: the sum of absolute values of ``params`` (pure function)."""
    magnitudes = jnp.abs(params)
    return magnitudes.sum()
def _likelihood(
    params: jnp.ndarray, dim: int, basis: jnp.ndarray, results: jnp.ndarray
) -> jnp.ndarray:
    """Negative sum of squared residuals between predicted and measured data.

    The state is rebuilt from ``params`` with ``_reconstruct_density_matrix``.
    The predicted record for operator ``i`` is
    ``Re(sum_jk basis[i, j, k] * rho[j, k])``. The returned value is
    ``-sum_i (predicted_i - results_i)**2``: larger is better, and it is the
    Gaussian log-likelihood up to constants.

    NOTE(review): ``sum_jk B_i[j, k] rho[j, k]`` equals ``Tr(B_i^T rho)``. That
    is the expectation value ``Tr(B_i rho)`` only for symmetric ``B_i`` (it
    flips the sign for e.g. sigma_y). ``_quantum_process_tomography``
    simulates data with the same einsum, so it is self-consistent there.
    Confirm that externally supplied results follow this convention.
    """
    density = _reconstruct_density_matrix(params, dim)
    predicted = jnp.einsum("ijk,jk->i", basis, density).real
    residuals = predicted - results
    return -(residuals ** 2).sum()
# Core training loop: a single jitted, side-effect-free lax.scan.
@partial(
    jit,
    static_argnames=[
        "dim",
        "epochs",
        "optimizer",
        "compute_infidelity",
        "L1_reg_strength",
    ],
)
def _run_tomography_scan(
    initial_params,
    initial_opt_state,
    true_rho_data,
    measurement_basis,
    measurement_results,
    dim,
    epochs,
    optimizer,
    compute_infidelity,
    L1_reg_strength,
):
    """Run ``epochs`` optimiser steps on the tomography objective.

    The objective is ``-_likelihood(...) + L1_reg_strength * _L1_reg(params)``.
    The arguments listed in ``static_argnames`` shape the traced graph, so a
    new value for any of them triggers a fresh compilation.

    Returns:
        ``((final_params, final_opt_state), history)``. ``history`` is a dict
        with keys ``"loss"``, ``"grads"``, ``"params"`` and ``"infidelity"``,
        each stacked along a leading epoch axis. Entry ``t`` is evaluated at
        the parameters *before* update ``t``. ``"infidelity"`` is NaN when
        ``compute_infidelity`` is False.
    """

    def objective(p):
        # Least-squares misfit (the negated "log-likelihood") plus L1 penalty.
        misfit = -_likelihood(p, dim, measurement_basis, measurement_results)
        return misfit + L1_reg_strength * _L1_reg(p)

    objective_and_grad = value_and_grad(objective)

    def infidelity_of(p):
        # ``compute_infidelity`` is static, so this Python branch is resolved
        # at trace time; each flag value compiles its own graph.
        if not compute_infidelity:
            return jnp.nan
        estimate = Qarray.create(_reconstruct_density_matrix(p, dim))
        reference = Qarray.create(true_rho_data)
        return 1.0 - fidelity(reference, estimate, force_positivity=True)

    @scan_tqdm(epochs)
    def step(state, _):
        p, opt_state = state
        loss_val, grads = objective_and_grad(p)
        updates, next_opt_state = optimizer.update(grads, opt_state, p)
        next_p = optax.apply_updates(p, updates)
        record = dict(
            loss=loss_val,
            grads=grads,
            params=p,
            infidelity=infidelity_of(p),
        )
        return (next_p, next_opt_state), record

    start = (initial_params, initial_opt_state)
    return lax.scan(step, start, jnp.arange(epochs), length=epochs)
class MLETomographyResult(NamedTuple):
    """Outcome of ``QuantumStateTomography.quantum_state_tomography_mle``.

    The ``*_history`` fields come straight out of ``lax.scan``, so they are
    arrays stacked along a leading epoch axis (not Python lists). Entry ``t``
    is evaluated at the parameters *before* the ``t``-th optimiser update.
    """

    # Final reconstructed density matrix.
    rho: Qarray
    # Cholesky parameter vectors, one row per epoch.
    params_history: Array
    # Objective value (misfit + L1 penalty), one entry per epoch.
    loss_history: Array
    # Gradients of the objective w.r.t. the parameters, one row per epoch.
    grads_history: Array
    # 1 - fidelity to the known true state per epoch; None when no true
    # state was supplied.
    infidelity_history: Optional[Array]
class QuantumStateTomography:
    """Reconstruct a quantum state from expectation-value data.

    ``measurement_results[i]`` is paired with ``measurement_basis[i]``. Two
    estimators are provided:

    * ``quantum_state_tomography_mle``: gradient-based (Adam) fit over a
      Cholesky parametrisation, so the estimate is always positive
      semi-definite with unit trace.
    * ``quantum_state_tomography_direct``: linear inversion onto
      ``complete_basis``. It is fast but not guaranteed to be physical.
    """

    def __init__(
        self,
        rho_guess: Qarray,
        measurement_basis: Qarray,
        measurement_results: jnp.ndarray,
        complete_basis: Optional[Qarray] = None,
        true_rho: Optional[Qarray] = None,
    ):
        """
        Reconstruct a quantum state from measurement results using quantum state tomography.
        The tomography can be performed either by direct inversion or by maximum likelihood estimation.

        Args:
            rho_guess (Qarray): The initial guess for the quantum state (used
                by MLE only). A state vector is converted to a density matrix.
            measurement_basis (Qarray): The basis in which measurements are performed.
            measurement_results (jnp.ndarray): The results of the measurements.
            complete_basis (Optional[Qarray]): The complete basis for state
                reconstruction used when using direct inversion.
                Defaults to the measurement basis if not provided.
            true_rho (Optional[Qarray]): The true quantum state, if known.
                When given, MLE additionally records the infidelity per epoch.
        """
        # The MLE initialisation Cholesky-factorises the guess, which needs a
        # square matrix, so kets/bras are promoted to density matrices here.
        self.rho_guess = rho_guess.to_dm().data
        self.measurement_basis = measurement_basis.data
        self.measurement_results = measurement_results
        self.complete_basis = (
            complete_basis.data
            if (complete_basis is not None)
            else measurement_basis.data
        )
        self.true_rho = true_rho
        self._result = None

    @property
    def result(self) -> Optional[MLETomographyResult]:
        """Result of the most recent MLE run, or None if none has run yet."""
        return self._result

    def quantum_state_tomography_mle(
        self, L1_reg_strength: float = 0.0, epochs: int = 10000, lr: float = 5e-3
    ) -> MLETomographyResult:
        """Perform quantum state tomography using maximum likelihood
        estimation (MLE).

        This method reconstructs the quantum state from measurement results
        by optimizing a likelihood function using gradient descent. The
        optimization ensures the resulting density matrix is positive
        semi-definite with trace 1. The result is also stored on
        ``self.result``.

        Args:
            L1_reg_strength (float, optional): Strength of L1
                regularization. Defaults to 0.0.
            epochs (int, optional): Number of optimization iterations.
                Defaults to 10000.
            lr (float, optional): Learning rate for the Adam optimizer.
                Defaults to 5e-3.

        Returns:
            MLETomographyResult: Named tuple containing:
                - rho: Reconstructed quantum state as Qarray
                - params_history: Parameter values per epoch (stacked array)
                - loss_history: Loss values per epoch (stacked array)
                - grads_history: Gradient values per epoch (stacked array)
                - infidelity_history: Infidelities per epoch if true_rho was
                  provided, None otherwise
        """
        dim = self.rho_guess.shape[0]
        # NOTE(review): ``optimizer`` is a static argument of the jitted scan,
        # so a freshly built optax.adam on every call likely forces a re-trace.
        optimizer = optax.adam(lr)

        # Initialize parameters from the initial guess for the density matrix
        params = _parametrize_density_matrix(self.rho_guess, dim)
        opt_state = optimizer.init(params)

        compute_infidelity_flag = self.true_rho is not None

        # The jitted scan needs a concrete array even without a reference
        # state; the static flag compiles the fidelity branch out, so this
        # placeholder is never read.
        true_rho_data_or_dummy = (
            self.true_rho.data
            if compute_infidelity_flag
            else jnp.empty((dim, dim), dtype=jnp.complex64)
        )

        final_carry, history = _run_tomography_scan(
            initial_params=params,
            initial_opt_state=opt_state,
            true_rho_data=true_rho_data_or_dummy,
            measurement_basis=self.measurement_basis,
            measurement_results=self.measurement_results,
            dim=dim,
            epochs=epochs,
            optimizer=optimizer,
            compute_infidelity=compute_infidelity_flag,
            L1_reg_strength=L1_reg_strength,
        )

        final_params, _ = final_carry

        rho = Qarray.create(_reconstruct_density_matrix(final_params, dim))

        self._result = MLETomographyResult(
            rho=rho,
            params_history=history["params"],
            loss_history=history["loss"],
            grads_history=history["grads"],
            infidelity_history=history["infidelity"]
            if compute_infidelity_flag
            else None,
        )
        return self._result

    def quantum_state_tomography_direct(
        self,
    ) -> Qarray:
        """Perform quantum state tomography using direct inversion.

        Writes ``rho = sum_i c_i C_i`` over ``complete_basis`` and solves the
        linear system
        ``measurement_results[l] = sum_i c_i <M_l, C_i>``. Here
        ``<M, C> = sum_jk M[j, k] C[j, k]`` is the same pairing used by the
        MLE likelihood. A square system is solved exactly. An over- or
        under-determined one (different numbers of measured and complete
        operators) is solved in the least-squares / minimum-norm sense. The
        measurement results are assumed to be noise-free.

        Returns:
            Qarray: Reconstructed quantum state.
        """
        # G[l, i] = <M_l, C_i>: rows index measurements, columns index the
        # complete basis, so that G @ c reproduces the measurement record.
        # (The transposed orientation is only equivalent when both bases are
        # identical, which is when G is symmetric.)
        gram = jnp.einsum(
            "ljk,ijk->li", self.measurement_basis, self.complete_basis
        )
        if gram.shape[0] == gram.shape[1]:
            coefficients = jnp.linalg.solve(gram, self.measurement_results)
        else:
            coefficients, *_ = jnp.linalg.lstsq(gram, self.measurement_results)
        # Reconstruct the density matrix
        rho = jnp.einsum("i,ijk->jk", coefficients, self.complete_basis)

        return Qarray.create(rho)

    def plot_results(self):
        """Plot the MLE loss (and infidelity, if tracked) versus epoch.

        Raises:
            ValueError: If ``quantum_state_tomography_mle`` has not been run.
        """
        if self._result is None:
            raise ValueError(
                "No results to plot. Run quantum_state_tomography_mle first."
            )

        _, ax = plt.subplots(1, figsize=(5, 4))
        if self._result.infidelity_history is not None:
            # Secondary y-axis sharing the epoch axis for the infidelity.
            ax2 = ax.twinx()

        ax.plot(self._result.loss_history, color="C0")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("$\\mathcal{L}$", color="C0")
        ax.set_yscale("log")

        if self._result.infidelity_history is not None:
            ax2.plot(self._result.infidelity_history, color="C1")
            ax2.set_yscale("log")
            ax2.set_ylabel("$1-\\mathcal{F}$", color="C1")
            plt.grid(False)

        plt.show()
def tensor_basis(single_basis: Qarray, n: int) -> Qarray:
    """Construct n-fold tensor product basis from a single-system basis.

    Element ``r`` of the result is ``B[i_0] (x) B[i_1] (x) ... (x) B[i_{n-1}]``.
    The multi-indices run in lexicographic order with the last index varying
    fastest.

    Args:
        single_basis: The single-system operator basis as a Qarray. Its data
            must be a stack of square matrices of shape ``(b, d, d)``.
        n: Number of tensor copies to construct.

    Returns:
        Qarray containing the n-fold tensor product basis operators.
        The resulting basis has b^n elements where b is the number
        of operators in the single-system basis. Every entry of the
        original ``dims`` is raised to the power ``n``.
    """
    base_dims = single_basis.dims
    ops = single_basis.data
    num_ops, _, _ = ops.shape

    # All length-n multi-indices over range(num_ops): shape (num_ops**n, n).
    multi_indices = jnp.indices((num_ops,) * n).reshape(n, -1).T

    # Gather the factors of every product: shape (num_ops**n, n, d, d).
    chosen = ops[multi_indices]

    def kron_chain(factors):
        return reduce(jnp.kron, factors)

    products = vmap(kron_chain)(chosen)

    tensor_dims = tuple(tuple(size ** n for size in row) for row in base_dims)
    return Qarray.create(products, dims=tensor_dims, bdims=(num_ops ** n,))
def _quantum_process_tomography(
    map: Callable[[Qarray], Qarray],
    physical_state_basis: Qarray,
    physical_operator_basis: Qarray,
    logical_state_basis: Optional[Qarray] = None,
    logical_operator_basis: Optional[Qarray] = None,
) -> Qarray:
    """Accumulate a process matrix from per-input MLE state tomography.

    For each index ``k``, the input ``s = physical_state_basis[k]`` enters as
    ``s @ s.dag()`` and is passed through ``map``. A noiseless measurement
    record of the output is simulated with ``physical_operator_basis``. The
    output is then reconstructed by MLE, interpreting that record against
    ``logical_operator_basis``. The returned operator is
    ``sum_k (l_k @ l_k.dag()) ^ rho_out_k``, with ``l_k = logical_state_basis[k]``
    (``^`` is assumed to be the Qarray tensor product).

    Args:
        map: The process to characterise, acting on density matrices. The
            name shadows the builtin but is kept for keyword callers.
        physical_state_basis: Indexable input states, one per tomography run.
        physical_operator_basis: Operators used to simulate the measurement
            record of each output state.
        logical_state_basis: States labelling the input side of the result.
            Defaults to ``physical_state_basis``.
        logical_operator_basis: Operators the record is interpreted against
            during reconstruction. Defaults to ``physical_operator_basis``.

    Returns:
        The accumulated operator on the doubled (input (x) output) space.

    NOTE(review): ``sum_k P_k (x) E(P_k)`` over input projectors ``P_k`` is the
    standard Choi matrix only after ``P_k`` is replaced by its dual frame.
    Confirm the intended definition before relying on this output.
    """
    if logical_state_basis is None:
        logical_state_basis = physical_state_basis
    if logical_operator_basis is None:
        logical_operator_basis = physical_operator_basis

    # Number of input states, i.e. the batch size of the state basis.
    num_inputs = logical_state_basis.bdims[-1]
    # Hilbert-space dimension, assuming num_inputs == d**2 (TODO confirm for
    # over-complete input sets). For num_inputs == d**2, sqrt(d**2 + 1) lies
    # strictly between d and d + 1, so truncation yields d without
    # float-rounding trouble.
    d = int(jnp.sqrt(num_inputs + 1))

    choi = Qarray.create(jnp.zeros((d, d))) ^ Qarray.create(jnp.zeros((d, d)))

    # One progress tick per tomography run.
    with tqdm(total=num_inputs) as pbar:
        for k in range(num_inputs):
            rho_k = physical_state_basis[k] @ physical_state_basis[k].dag()

            E_rho_k = map(rho_k)

            # Noiseless record, using the same einsum pairing as _likelihood.
            measurement_results = jnp.real(
                jnp.einsum(
                    "ijk,jk->i", physical_operator_basis.data, E_rho_k.data
                )
            )

            QST = QuantumStateTomography(
                rho_guess=identity(d) / d,
                measurement_results=measurement_results,
                measurement_basis=logical_operator_basis,
            )
            r = QST.quantum_state_tomography_mle().rho

            logical_projector = (
                logical_state_basis[k] @ logical_state_basis[k].dag()
            )
            choi += logical_projector ^ r
            pbar.update(1)
    return choi