Coverage for jaxquantum/core/measurements.py: 35%
66 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Helpers."""
3from typing import List
4from jax import config, Array
6import jax.numpy as jnp
7from tqdm import tqdm
9from jaxquantum.core.qarray import Qarray, powm
10from jaxquantum.core.operators import identity, sigmax, sigmay, sigmaz
# Turn on 64-bit floats/complex numbers in JAX. This is a global JAX
# configuration switch executed when this module is imported, so it also
# affects every other piece of JAX code running in the same process.
config.update("jax_enable_x64", True)
15# Calculations ----------------------------------------------------------------
def overlap(rho: Qarray, sigma: Qarray) -> Array:
    """Overlap between two states or operators.

    Every branch evaluates the Hilbert-Schmidt inner product
    ``Tr(rho^dagger @ sigma)``. A vector argument ``|psi>`` (ket or bra,
    normalised through ``to_ket()``) stands for the projector
    ``|psi><psi|``:

    * vector / vector:      ``|<rho|sigma>|**2``
    * vector rho, op sigma: ``<rho|sigma|rho>``
    * op rho, vector sigma: ``<sigma|rho^dagger|sigma>``
    * op / op:              ``Tr(rho^dagger @ sigma)``

    Args:
        rho: state/operator.
        sigma: state/operator.

    Returns:
        Overlap between rho and sigma.
    """
    if rho.is_vec() and sigma.is_vec():
        return jnp.abs((rho.to_ket().dag() @ sigma.to_ket()).trace()) ** 2

    if rho.is_vec():
        ket = rho.to_ket()
        res = (ket.dag() @ sigma @ ket).data
        # Drop the trailing 1x1 matrix axes; any leading axes (presumably
        # batch dimensions) are kept.
        return res.squeeze(-1).squeeze(-1)

    if sigma.is_vec():
        ket = sigma.to_ket()
        # Use rho^dagger (not rho) so this branch agrees with the
        # operator/operator branch below; the two differ by a complex
        # conjugate whenever rho is not Hermitian.
        res = (ket.dag() @ rho.dag() @ ket).data
        return res.squeeze(-1).squeeze(-1)

    return (rho.dag() @ sigma).trace()
def fidelity(rho: Qarray, sigma: Qarray) -> float:
    """Fidelity between two states.

    Both inputs are first converted with ``to_dm()``, then the squared
    Uhlmann form ``(Tr sqrt(sqrt(rho) @ sigma @ sqrt(rho)))**2`` is
    evaluated, with every matrix square root taken via ``powm(..., 0.5)``.

    Args:
        rho: state (any form accepted by ``to_dm()``).
        sigma: state (any form accepted by ``to_dm()``).

    Returns:
        Fidelity between rho and sigma.
        NOTE(review): the trace may come back with a complex dtype despite
        the ``float`` annotation — confirm against ``powm``/``tr``.
    """
    rho_dm = rho.to_dm()
    sigma_dm = sigma.to_dm()

    root_rho = powm(rho_dm, 0.5)
    sandwiched = root_rho @ sigma_dm @ root_rho
    trace_of_root = powm(sandwiched, 0.5).tr()
    return trace_of_root ** 2
def quantum_state_tomography(
    rho: Qarray,
    physical_basis: Qarray,
    logical_basis: Qarray,
    *,
    show_progress: bool = True,
) -> Qarray:
    """Perform quantum state tomography to retrieve the density matrix in
    the logical basis.

    Computes ``sum_i Tr(rho @ physical_basis[i]) * logical_basis[i]``. Each
    physical-space operator is measured on ``rho``, and its expectation
    value weights the matching logical-space operator.

    Args:
        rho: state expressed in the physical Hilbert space basis; it is
            converted with ``to_dm()`` before use.
        physical_basis: batched Qarray of logical operators expressed in the
            physical Hilbert space basis, forming a complete logical operator
            basis.
        logical_basis: batched Qarray of logical operators expressed in the
            logical Hilbert space basis, forming a complete operator basis.
            Element ``i`` must correspond to ``physical_basis[i]``.
        show_progress: show a tqdm progress bar over the basis elements.
            Defaults to True, which preserves the previous always-on
            behaviour.

    Returns:
        Density matrix of state rho expressed in the logical basis.

    Raises:
        ValueError: if the last batch dimensions of the two bases differ.
    """
    # Validate the inputs before doing any work.
    if physical_basis.bdims[-1] != logical_basis.bdims[-1]:
        raise ValueError(
            f"The two bases should have the same size for the "
            f"last batch dimension. Received "
            f"{physical_basis.bdims} and {logical_basis.bdims} "
            f"instead."
        )

    rho = rho.to_dm()
    space_size = physical_basis.bdims[-1]

    # The accumulator takes its shape from a single logical basis element,
    # so the returned Qarray takes its batch dims from that same element.
    template = logical_basis[0]
    dm = jnp.zeros_like(template.data)

    for i in tqdm(range(space_size), total=space_size, disable=not show_progress):
        # NOTE(review): assumes rho is unbatched so p_i is a scalar; a
        # batched rho would not broadcast against a single basis element.
        p_i = (rho @ physical_basis[i]).trace()
        dm = dm + p_i * logical_basis[i].data

    return Qarray.create(dm, dims=logical_basis.dims, bdims=template.bdims)
def get_physical_basis(qubits: List) -> Qarray:
    """Compute a complete operator basis of a QEC code on a
    physical system specified by a number of qubits.

    For each qubit code, the single-qubit operators are
    ``identity(qubit.params["N"])`` followed by
    ``qubit.common_gates["X"]``, ``["Y"]`` and ``["Z"]``. For several
    codes, every element combines one operator of the first code with one
    element of the basis of the remaining codes using ``^`` (assumed to be
    Qarray's tensor product). The first code's index varies slowest.

    Args:
        qubits: non-empty list of qubit codes. Each code must have a
            ``params`` mapping with an ``"N"`` entry and a ``common_gates``
            mapping with ``"X"``, ``"Y"`` and ``"Z"`` entries.

    Returns:
        Batched Qarray containing the complete operator basis
        (``4 ** len(qubits)`` operators).

    Raises:
        ValueError: if ``qubits`` is empty, or a code lacks a required
            ``params``/``common_gates`` entry.
        AttributeError: if a code has no ``params`` or ``common_gates``
            attribute.
    """
    if len(qubits) == 0:
        raise ValueError("qubits must contain at least one qubit code.")

    qubit, rest = qubits[0], qubits[1:]

    # Read everything needed from the code object before building any
    # Qarray. A malformed code then fails here with a clear message,
    # instead of printing and later crashing on an unbound `operators`.
    try:
        n_levels = qubit.params["N"]
        gates = [qubit.common_gates[axis] for axis in ("X", "Y", "Z")]
    except KeyError as err:
        raise ValueError(
            "QEC code must have common_gates for all three axes "
            "and a params['N'] entry."
        ) from err
    except AttributeError as err:
        raise AttributeError(
            "QEC code must have common_gates and params attribute."
        ) from err

    operators = Qarray.from_list([identity(n_levels), *gates])

    if len(rest) == 0:
        return operators

    sub_basis = get_physical_basis(rest)
    sub_basis_size = sub_basis.bdims[-1]
    return Qarray.from_list(
        [
            operators[i] ^ sub_basis[j]
            for i in range(4)
            for j in range(sub_basis_size)
        ]
    )
def get_logical_basis(n_qubits: int) -> Qarray:
    """Build a complete operator basis for a register of logical qubits.

    The single-qubit basis is ``identity(2)``, ``sigmax()``, ``sigmay()``
    and ``sigmaz()``, each divided by 2. For more qubits, every element
    pairs one single-qubit element with one element of the basis for the
    remaining ``n_qubits - 1`` qubits via ``^`` (assumed to be Qarray's
    tensor product). The leading qubit's index varies slowest.

    Args:
        n_qubits: number of qubits; must be at least 1.

    Returns:
        Batched Qarray holding the complete operator basis.

    Raises:
        ValueError: if ``n_qubits`` is less than 1.
    """
    if n_qubits < 1:
        raise ValueError("n_qubits must be at least 1.")

    single = Qarray.from_list(
        [identity(2) / 2, sigmax() / 2, sigmay() / 2, sigmaz() / 2]
    )

    remaining = n_qubits - 1
    if remaining == 0:
        return single

    tail = get_logical_basis(remaining)
    tail_size = tail.bdims[-1]
    products = [
        single[a] ^ tail[b]
        for a in range(4)
        for b in range(tail_size)
    ]
    return Qarray.from_list(products)