Skip to content

measurements

Measurements.

QuantumStateTomography

Source code in jaxquantum/core/measurements.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class QuantumStateTomography:
    """Reconstruct a density matrix from measurement data.

    Two reconstruction strategies are offered:

    * ``quantum_state_tomography_direct``: linear inversion in a complete
      operator basis.
    * ``quantum_state_tomography_mle``: maximum likelihood estimation via
      gradient descent (Adam); its result is cached on ``self.result``.
    """

    def __init__(
        self,
        rho_guess: Qarray,
        measurement_basis: Qarray,
        measurement_results: jnp.ndarray,
        complete_basis: Optional[Qarray] = None,
        true_rho: Optional[Qarray] = None,
    ):
        """
        Reconstruct a quantum state from measurement results using quantum
        state tomography. The tomography can be performed either by direct
        inversion or by maximum likelihood estimation.

        Args:
            rho_guess (Qarray): The initial guess for the quantum state
                (used as the MLE starting point).
            measurement_basis (Qarray): The batch of operators in which
                measurements are performed.
            measurement_results (jnp.ndarray): The results of the
                measurements, one per measurement operator.
            complete_basis (Optional[Qarray]): The complete operator basis
                used for state reconstruction by direct inversion.
                Defaults to the measurement basis if not provided.
            true_rho (Optional[Qarray]): The true quantum state, if known;
                enables infidelity tracking during MLE.
        """
        # Raw array payloads are stored; only ``true_rho`` stays a Qarray.
        self.rho_guess = rho_guess.data
        self.measurement_basis = measurement_basis.data
        self.measurement_results = measurement_results
        self.complete_basis = (
            complete_basis.data
            if (complete_basis is not None)
            else measurement_basis.data
        )
        self.true_rho = true_rho
        self._result = None

    @property
    def result(self) -> Optional[MLETomographyResult]:
        """Result of the last ``quantum_state_tomography_mle`` call, or None."""
        return self._result

    def quantum_state_tomography_mle(
        self, L1_reg_strength: float = 0.0, epochs: int = 10000, lr: float = 5e-3
    ) -> MLETomographyResult:
        """Perform quantum state tomography using maximum likelihood
        estimation (MLE).

        This method reconstructs the quantum state from measurement results
        by optimizing a likelihood function using gradient descent. The
        optimization ensures the resulting density matrix is positive
        semi-definite with trace 1.

        Args:
            L1_reg_strength (float, optional): Strength of L1
                regularization. Defaults to 0.0.
            epochs (int, optional): Number of optimization iterations.
                Defaults to 10000.
            lr (float, optional): Learning rate for the Adam optimizer.
                Defaults to 5e-3.

        Returns:
            MLETomographyResult: Named tuple containing:
                - rho: Reconstructed quantum state as Qarray
                - params_history: List of parameter values during optimization
                - loss_history: List of loss values during optimization
                - grads_history: List of gradient values during optimization
                - infidelity_history: List of infidelities if true_rho was
                  provided, None otherwise
        """

        dim = self.rho_guess.shape[0]
        optimizer = optax.adam(lr)

        # Initialize parameters from the initial guess for the density matrix
        params = _parametrize_density_matrix(self.rho_guess, dim)
        opt_state = optimizer.init(params)

        compute_infidelity_flag = self.true_rho is not None

        # Provide a dummy array if no true_rho is available. It won't be used.
        true_rho_data_or_dummy = (
            self.true_rho.data
            if compute_infidelity_flag
            else jnp.empty((dim, dim), dtype=jnp.complex64)
        )

        final_carry, history = _run_tomography_scan(
            initial_params=params,
            initial_opt_state=opt_state,
            true_rho_data=true_rho_data_or_dummy,
            measurement_basis=self.measurement_basis,
            measurement_results=self.measurement_results,
            dim=dim,
            epochs=epochs,
            optimizer=optimizer,
            compute_infidelity=compute_infidelity_flag,
            L1_reg_strength=L1_reg_strength,
        )

        final_params, _ = final_carry

        rho = Qarray.create(_reconstruct_density_matrix(final_params, dim))

        self._result = MLETomographyResult(
            rho=rho,
            params_history=history["params"],
            loss_history=history["loss"],
            grads_history=history["grads"],
            infidelity_history=history["infidelity"]
            if compute_infidelity_flag
            else None,
        )
        return self._result

    def quantum_state_tomography_direct(
        self,
    ) -> Qarray:
        """Perform quantum state tomography using direct inversion.

        The state is expanded in the complete operator basis,
        ``rho = sum_i c_i C_i``. Each measurement result is assumed to be the
        expectation value ``m_l = Tr(M_l rho)`` of measurement operator
        ``M_l``. This gives the linear system
        ``sum_i Tr(M_l C_i) c_i = m_l``, which is solved for ``c``.

        The method assumes that the measurement basis is complete (so the
        system is square and invertible) and that the measurement results
        are noise-free.

        Returns:
            Qarray: Reconstructed quantum state.
        """
        # A[l, i] = Tr(M_l @ C_i) = sum_{j,k} M_l[j, k] * C_i[k, j].
        # The swapped (k, j) index on C_i makes this a matrix-product trace.
        # An elementwise sum without the swap would compute Tr(C_i^T M_l).
        # That flips the sign for antisymmetric components (e.g. Pauli Y)
        # and would reconstruct rho^T instead of rho.
        A = jnp.einsum(
            "ljk,ikj->li", self.measurement_basis, self.complete_basis
        )
        # Solve for the expansion coefficients c.
        coefficients = jnp.linalg.solve(A, self.measurement_results)
        # Reassemble rho = sum_i c_i C_i.
        rho = jnp.einsum("i,ijk->jk", coefficients, self.complete_basis)

        return Qarray.create(rho)

    def plot_results(self):
        """Plot the MLE loss history, plus infidelity history if tracked.

        The loss is drawn on a log-scaled left axis. When the last MLE run
        tracked infidelity (``true_rho`` was given), it is drawn on a
        log-scaled twin right axis.

        Raises:
            ValueError: If ``quantum_state_tomography_mle`` has not been run.
        """
        if self._result is None:
            raise ValueError(
                "No results to plot. Run quantum_state_tomography_mle first."
            )

        _, ax = plt.subplots(1, figsize=(5, 4))

        ax.plot(self._result.loss_history, color="C0")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("$\\mathcal{L}$", color="C0")
        ax.set_yscale("log")

        if self._result.infidelity_history is not None:
            # The twin axis shares the epoch x-axis but has its own y scale.
            ax2 = ax.twinx()
            ax2.plot(self._result.infidelity_history, color="C1")
            ax2.set_yscale("log")
            ax2.set_ylabel("$1-\\mathcal{F}$", color="C1")
            # Disable the twin's grid so it does not overlay the primary grid.
            ax2.grid(False)

        plt.show()

__init__(rho_guess, measurement_basis, measurement_results, complete_basis=None, true_rho=None)

Reconstruct a quantum state from measurement results using quantum state tomography. The tomography can be performed either by direct inversion or by maximum likelihood estimation.

Parameters:

Name Type Description Default
rho_guess Qarray

The initial guess for the quantum state.

required
measurement_basis Qarray

The basis in which measurements are performed.

required
measurement_results ndarray

The results of the measurements.

required
complete_basis Optional[Qarray]

The complete basis for state reconstruction used by direct inversion. Defaults to the measurement basis if not provided.

None
true_rho Optional[Qarray]

The true quantum state, if known.

None
Source code in jaxquantum/core/measurements.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def __init__(
    self,
    rho_guess: Qarray,
    measurement_basis: Qarray,
    measurement_results: jnp.ndarray,
    complete_basis: Optional[Qarray] = None,
    true_rho: Optional[Qarray] = None,
):
    """
    Set up a tomography problem from measurement data.

    Reconstruction can later be done either by direct (linear) inversion
    or by maximum likelihood estimation.

    Args:
        rho_guess (Qarray): Starting guess for the quantum state.
        measurement_basis (Qarray): Operators in which measurements are
            performed.
        measurement_results (jnp.ndarray): Measured values, one per
            measurement operator.
        complete_basis (Optional[Qarray]): Complete operator basis used by
            direct inversion. Falls back to ``measurement_basis`` when None.
        true_rho (Optional[Qarray]): The true quantum state, if known.
    """
    # The array payloads are unwrapped from their Qarray containers up front.
    self.rho_guess = rho_guess.data
    self.measurement_basis = measurement_basis.data
    self.measurement_results = measurement_results

    if complete_basis is None:
        self.complete_basis = measurement_basis.data
    else:
        self.complete_basis = complete_basis.data

    # true_rho is kept as a Qarray (not unwrapped).
    self.true_rho = true_rho
    # No MLE result yet.
    self._result = None

quantum_state_tomography_direct()

Perform quantum state tomography using direct inversion.

This method reconstructs the quantum state from measurement results by directly solving a system of linear equations. The method assumes that the measurement basis is complete and the measurement results are noise-free.

Returns:

Name Type Description
Qarray Qarray

Reconstructed quantum state.

Source code in jaxquantum/core/measurements.py
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def quantum_state_tomography_direct(
    self,
) -> Qarray:
    """Perform quantum state tomography using direct inversion.

    The state is expanded in the complete operator basis,
    ``rho = sum_i c_i C_i``. Each measurement result is assumed to be the
    expectation value ``m_l = Tr(M_l rho)`` of measurement operator ``M_l``.
    This gives the linear system ``sum_i Tr(M_l C_i) c_i = m_l``, which is
    solved for ``c``.

    The method assumes that the measurement basis is complete (so the system
    is square and invertible) and that the measurement results are
    noise-free.

    Returns:
        Qarray: Reconstructed quantum state.
    """
    # A[l, i] = Tr(M_l @ C_i) = sum_{j,k} M_l[j, k] * C_i[k, j].
    # The swapped (k, j) index on C_i makes this a matrix-product trace.
    # An elementwise sum without the swap would compute Tr(C_i^T M_l).
    # That flips the sign for antisymmetric components (e.g. Pauli Y)
    # and would reconstruct rho^T instead of rho.
    A = jnp.einsum("ljk,ikj->li", self.measurement_basis, self.complete_basis)
    # Solve for the expansion coefficients c.
    coefficients = jnp.linalg.solve(A, self.measurement_results)
    # Reassemble rho = sum_i c_i C_i.
    rho = jnp.einsum("i,ijk->jk", coefficients, self.complete_basis)

    return Qarray.create(rho)

quantum_state_tomography_mle(L1_reg_strength=0.0, epochs=10000, lr=0.005)

Perform quantum state tomography using maximum likelihood estimation (MLE).

This method reconstructs the quantum state from measurement results by optimizing a likelihood function using gradient descent. The optimization ensures the resulting density matrix is positive semi-definite with trace 1.

Parameters:

Name Type Description Default
L1_reg_strength float

Strength of L1 regularization.

0.0
epochs int

Number of optimization iterations.

10000
lr float

Learning rate for the Adam optimizer.

0.005

Returns:

Name Type Description
MLETomographyResult MLETomographyResult

Named tuple containing:

- rho: Reconstructed quantum state as Qarray
- params_history: List of parameter values during optimization
- loss_history: List of loss values during optimization
- grads_history: List of gradient values during optimization
- infidelity_history: List of infidelities if true_rho was provided, None otherwise

Source code in jaxquantum/core/measurements.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
def quantum_state_tomography_mle(
    self, L1_reg_strength: float = 0.0, epochs: int = 10000, lr: float = 5e-3
) -> MLETomographyResult:
    """Reconstruct the state by maximum likelihood estimation (MLE).

    Starting from ``self.rho_guess``, the density-matrix parameters are
    optimized with Adam over ``epochs`` steps. The parametrization keeps
    the estimate positive semi-definite with unit trace. The result is
    cached on ``self._result`` and also returned.

    Args:
        L1_reg_strength (float, optional): Weight of the L1 regularization
            term. Defaults to 0.0.
        epochs (int, optional): Number of optimization iterations.
            Defaults to 10000.
        lr (float, optional): Adam learning rate. Defaults to 5e-3.

    Returns:
        MLETomographyResult: Named tuple with fields
            - rho: Reconstructed quantum state as Qarray
            - params_history: Parameter values over the optimization
            - loss_history: Loss values over the optimization
            - grads_history: Gradient values over the optimization
            - infidelity_history: Infidelities when true_rho was given,
              otherwise None
    """
    rho_guess = self.rho_guess
    n = rho_guess.shape[0]
    adam = optax.adam(lr)

    # Seed the optimizer from the parametrized initial guess.
    init_params = _parametrize_density_matrix(rho_guess, n)
    init_state = adam.init(init_params)

    have_true_rho = self.true_rho is not None
    if have_true_rho:
        reference_rho = self.true_rho.data
    else:
        # Shape-correct placeholder; not used when infidelity tracking is off.
        reference_rho = jnp.empty((n, n), dtype=jnp.complex64)

    (final_params, _), history = _run_tomography_scan(
        initial_params=init_params,
        initial_opt_state=init_state,
        true_rho_data=reference_rho,
        measurement_basis=self.measurement_basis,
        measurement_results=self.measurement_results,
        dim=n,
        epochs=epochs,
        optimizer=adam,
        compute_infidelity=have_true_rho,
        L1_reg_strength=L1_reg_strength,
    )

    estimate = Qarray.create(_reconstruct_density_matrix(final_params, n))
    infidelities = history["infidelity"] if have_true_rho else None

    self._result = MLETomographyResult(
        rho=estimate,
        params_history=history["params"],
        loss_history=history["loss"],
        grads_history=history["grads"],
        infidelity_history=infidelities,
    )
    return self._result

fidelity(rho, sigma, force_positivity=False)

Fidelity between two states.

Parameters:

Name Type Description Default
rho Qarray

First quantum state; converted to a density matrix with to_dm() before use.

required
sigma Qarray

Second quantum state; converted to a density matrix with to_dm() before use.

required
force_positivity bool

If True, the matrix square roots are computed with clip_eigvals=True, forcing the states to be treated as positive semidefinite.

False

Returns:

Type Description
ndarray

Fidelity between rho and sigma.

Source code in jaxquantum/core/measurements.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def fidelity(rho: Qarray, sigma: Qarray, force_positivity: bool = False) -> jnp.ndarray:
    """Fidelity between two states.

    Computes ``F = (Tr sqrt(sqrt(rho) @ sigma @ sqrt(rho)))**2``. Both
    inputs are first converted to density matrices.

    Args:
        rho: First state.
        sigma: Second state.
        force_positivity: When True, both matrix square roots are taken with
            ``clip_eigvals=True`` so the states are forced to be positive
            semidefinite.

    Returns:
        Real part of the fidelity between rho and sigma.
    """
    rho_dm = rho.to_dm()
    sigma_dm = sigma.to_dm()

    root_rho = powm(rho_dm, 0.5, clip_eigvals=force_positivity)
    sandwiched = root_rho @ sigma_dm @ root_rho
    trace_of_root = powm(sandwiched, 0.5, clip_eigvals=force_positivity).tr()

    # The square is taken before extracting the real part.
    return jnp.real(trace_of_root ** 2)

overlap(rho, sigma)

Overlap between two states or operators.

Parameters:

Name Type Description Default
rho Qarray

state/operator.

required
sigma Qarray

state/operator.

required

Returns:

Type Description
Array

Overlap between rho and sigma.

Source code in jaxquantum/core/measurements.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def overlap(rho: Qarray, sigma: Qarray) -> Array:
    """Overlap between two states or operators.

    * vector / vector: ``|<rho|sigma>|**2``
    * vector / operator (either order): ``<v| op |v>``
    * operator / operator: ``Tr(rho^dagger @ sigma)``

    Args:
        rho: state/operator.
        sigma: state/operator.

    Returns:
        Overlap between rho and sigma.
    """
    rho_is_vec = rho.is_vec()
    sigma_is_vec = sigma.is_vec()

    if rho_is_vec and sigma_is_vec:
        amplitude = (rho.to_ket().dag() @ sigma.to_ket()).trace()
        return jnp.abs(amplitude) ** 2

    if not (rho_is_vec or sigma_is_vec):
        return (rho.dag() @ sigma).trace()

    # Exactly one argument is a vector: sandwich the other between it.
    vec, op = (rho, sigma) if rho_is_vec else (sigma, rho)
    ket = vec.to_ket()
    expectation = (ket.dag() @ op @ ket).data
    # Drop the two trailing singleton dimensions of the 1x1 result.
    return expectation.squeeze(-1).squeeze(-1)

tensor_basis(single_basis, n)

Construct n-fold tensor product basis from a single-system basis.

Parameters:

Name Type Description Default
single_basis Qarray

The single-system operator basis as a Qarray.

required
n int

Number of tensor copies to construct.

required

Returns:

Type Description
Qarray

Qarray containing the n-fold tensor product basis operators. The resulting basis has $b^n$ elements, where $b$ is the number of operators in the single-system basis.

Source code in jaxquantum/core/measurements.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def tensor_basis(single_basis: Qarray, n: int) -> Qarray:
    """Build the n-fold tensor product basis from a single-system basis.

    Every ordered choice of ``n`` single-system operators (lexicographic in
    the operator indices) is combined with successive Kronecker products.

    Args:
        single_basis: The single-system operator basis as a Qarray.
        n: Number of tensor copies to construct.

    Returns:
        Qarray holding the ``b**n`` tensor-product operators, where ``b`` is
        the number of operators in ``single_basis``.
    """
    base_dims = single_basis.dims
    ops = single_basis.data
    num_ops, _, _ = ops.shape

    # All length-n multi-indices over range(num_ops): shape (num_ops**n, n).
    grids = jnp.meshgrid(*([jnp.arange(num_ops)] * n), indexing="ij")
    multi_indices = jnp.stack(grids, axis=-1).reshape(-1, n)

    # Gather the factors for each multi-index: (num_ops**n, n, d, d).
    factors = ops[multi_indices]

    def kron_chain(group):
        # Left fold of jnp.kron over the n factors of one product.
        return reduce(jnp.kron, group)

    products = vmap(kron_chain)(factors)

    # NOTE(review): each dims entry x becomes x**n. For example, dims
    # ((d,), (d,)) map to ((d**n,), (d**n,)), one fused subsystem rather
    # than n subsystems of size d. Confirm this is the intended convention.
    out_dims = tuple(tuple(x**n for x in row) for row in base_dims)

    return Qarray.create(products, dims=out_dims, bdims=(num_ops**n,))