Coverage for jaxquantum/core/measurements.py: 35%

66 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

"""Measurement and state-analysis helpers.

Provides overlap and fidelity calculations between states/operators,
quantum state tomography onto a logical-qubit operator basis, and builders
for the physical (QEC-code) and logical operator bases that the tomography
routine consumes.
"""

from typing import List
from jax import config, Array

import jax.numpy as jnp
from tqdm import tqdm

from jaxquantum.core.qarray import Qarray, powm
from jaxquantum.core.operators import identity, sigmax, sigmay, sigmaz

# Module-level side effect: turns on 64-bit (double-precision) arrays in JAX
# for the whole process as soon as this module is imported, so that the
# matrix powers / traces below are computed in double precision.
config.update("jax_enable_x64", True)

13 

14 

15# Calculations ---------------------------------------------------------------- 

16 

17 

def overlap(rho: Qarray, sigma: Qarray) -> Array:
    """Overlap between two states or operators.

    What is computed depends on which arguments are vectors (``is_vec()``);
    vectors are normalised to kets with ``to_ket()`` before use:

    * both vectors: ``|<rho|sigma>|**2``;
    * exactly one vector ``v`` and one operator ``O``: ``<v|O|v>``;
    * two operators: ``Tr(rho^dagger @ sigma)``.

    Args:
        rho: state/operator.
        sigma: state/operator.

    Returns:
        Overlap between rho and sigma. In the vector/operator case the 1x1
        result has its trailing two axes squeezed away, leaving any leading
        (batch) axes in place.
    """
    rho_is_vec = rho.is_vec()
    sigma_is_vec = sigma.is_vec()

    if rho_is_vec and sigma_is_vec:
        amplitude = (rho.to_ket().dag() @ sigma.to_ket()).trace()
        return jnp.abs(amplitude) ** 2

    if rho_is_vec or sigma_is_vec:
        # Exactly one side is a vector: sandwich the other (operator) side
        # between it.
        # NOTE(review): when sigma is the vector this yields <s|rho|s>, not
        # <s|rho^dagger|s> as the operator/operator branch's Tr(rho^dagger
        # sigma) would suggest; the two agree only for Hermitian rho.
        vec, op = (rho, sigma) if rho_is_vec else (sigma, rho)
        ket = vec.to_ket()
        sandwiched = (ket.dag() @ op @ ket).data
        return sandwiched.squeeze(-1).squeeze(-1)

    return (rho.dag() @ sigma).trace()

41 

42 

def fidelity(rho: Qarray, sigma: Qarray) -> float:
    """Fidelity between two states.

    Evaluates the squared Uhlmann form
    ``F = (Tr sqrt(sqrt(rho) @ sigma @ sqrt(rho)))**2``, with every matrix
    square root taken via ``powm(..., 0.5)``. Both inputs go through
    ``to_dm()`` first, so vector states are presumably accepted as well.

    Args:
        rho: state.
        sigma: state.

    Returns:
        Fidelity between rho and sigma. NOTE(review): no real part is taken,
        so despite the ``float`` annotation the value may be complex-typed
        with a small imaginary residue; confirm against ``powm``.
    """
    rho_dm, sigma_dm = rho.to_dm(), sigma.to_dm()

    root_rho = powm(rho_dm, 0.5)
    sandwiched = root_rho @ sigma_dm @ root_rho
    trace_of_root = powm(sandwiched, 0.5).tr()

    return trace_of_root ** 2

59 

60 

def quantum_state_tomography(
    rho: Qarray, physical_basis: Qarray, logical_basis: Qarray
) -> Qarray:
    """Perform quantum state tomography to retrieve the density matrix in
    the logical basis.

    Computes ``sum_i Tr(rho @ P_i) * L_i`` with ``P_i = physical_basis[i]``
    and ``L_i = logical_basis[i]``. For the sum to reproduce the logical
    density matrix, the two bases must be dual to each other. Presumably the
    intended pairing is the gates from ``get_physical_basis`` with the halved
    Paulis from ``get_logical_basis``; confirm before using other bases.

    A tqdm progress bar is displayed while iterating over the basis.

    Args:
        rho: state expressed in the physical Hilbert space basis (converted
            with ``to_dm()`` before use).
        physical_basis: batched Qarray of logical operators expressed in the
            physical Hilbert space basis, forming a complete logical
            operator basis.
        logical_basis: batched Qarray of logical operators expressed in the
            logical Hilbert space basis, forming a complete operator basis.

    Returns:
        Density matrix of state rho expressed in the logical basis, with
        ``logical_basis.dims`` and the batch dims of ``physical_basis[0]``.

    Raises:
        ValueError: if the last batch dimensions of the two bases differ.
    """
    accumulated = jnp.zeros_like(logical_basis[0].data)
    rho_dm = rho.to_dm()

    n_elements = physical_basis.bdims[-1]
    if n_elements != logical_basis.bdims[-1]:
        raise ValueError(
            "The two bases should have the same size for the last batch "
            f"dimension. Received {physical_basis.bdims} and "
            f"{logical_basis.bdims} instead."
        )

    for idx in tqdm(range(n_elements), total=n_elements):
        # Expectation value of the idx-th physical operator weights the
        # matching logical operator.
        weight = (rho_dm @ physical_basis[idx]).trace()
        accumulated = accumulated + weight * logical_basis[idx].data

    return Qarray.create(
        accumulated, dims=logical_basis.dims, bdims=physical_basis[0].bdims
    )

96 

97 

def get_physical_basis(qubits: List) -> Qarray:
    """Compute a complete operator basis of a QEC code on a
    physical system specified by a number of qubits.

    For each qubit code the local basis is ``[I, X, Y, Z]``, where ``I`` is
    ``identity(qubit.params["N"])`` and ``X``, ``Y``, ``Z`` are taken from
    ``qubit.common_gates``. The register basis is built recursively as
    ``local[i] ^ rest[j]`` (``^`` presumably being the Qarray tensor
    product), so the first qubit's index varies slowest.

    Args:
        qubits: non-empty list of qubit codes. Each code must expose a
            ``params`` mapping containing ``"N"`` and a ``common_gates``
            mapping containing ``"X"``, ``"Y"`` and ``"Z"``.

    Returns:
        Batched Qarray (built with ``Qarray.from_list``) holding the
        4**len(qubits) basis operators.

    Raises:
        ValueError: if ``qubits`` is empty, or a code lacks ``params["N"]``
            or one of the X/Y/Z common gates.
        TypeError: if a code lacks the ``params`` or ``common_gates``
            attribute.
    """
    if len(qubits) == 0:
        raise ValueError("qubits must contain at least one qubit code.")

    qubit = qubits[0]
    rest = qubits[1:]

    # Keep the try body to the attribute/key lookups only, so errors raised
    # inside identity() or Qarray.from_list are not misreported as a
    # malformed QEC code. Previously these handlers only printed, leaving the
    # basis unbound and crashing later with an unrelated UnboundLocalError.
    try:
        dim = qubit.params["N"]
        gates = qubit.common_gates
        x_gate, y_gate, z_gate = gates["X"], gates["Y"], gates["Z"]
    except AttributeError as err:
        raise TypeError(
            "QEC code must have common_gates and params attribute."
        ) from err
    except KeyError as err:
        raise ValueError(
            'QEC code must define params["N"] and common_gates for all '
            "three axes (X, Y, Z)."
        ) from err

    local_basis = Qarray.from_list([identity(dim), x_gate, y_gate, z_gate])

    if len(rest) == 0:
        return local_basis

    sub_basis = get_physical_basis(rest)
    sub_basis_size = sub_basis.bdims[-1]

    return Qarray.from_list(
        [
            local_basis[i] ^ sub_basis[j]
            for i in range(4)
            for j in range(sub_basis_size)
        ]
    )

139 

140 

def get_logical_basis(n_qubits: int) -> Qarray:
    """Compute a complete operator basis of a system composed of logical
    qubits.

    The single-qubit basis is ``(I/2, X/2, Y/2, Z/2)``. For more qubits,
    every element is a product ``single[a] ^ rest[b]`` of a single-qubit
    element with an element of the ``(n_qubits - 1)``-qubit basis, so the
    first qubit's index varies slowest. Assuming ``^`` is the tensor product,
    each element therefore carries an overall ``1 / 2**n_qubits`` factor.

    Args:
        n_qubits: number of logical qubits; must be at least 1.

    Returns:
        Batched Qarray (built with ``Qarray.from_list``) holding the
        4**n_qubits basis operators.

    Raises:
        ValueError: if ``n_qubits`` is smaller than 1.
    """
    if n_qubits < 1:
        raise ValueError("n_qubits must be at least 1.")

    single = Qarray.from_list(
        [identity(2) / 2, sigmax() / 2, sigmay() / 2, sigmaz() / 2]
    )
    if n_qubits == 1:
        return single

    rest = get_logical_basis(n_qubits - 1)
    rest_size = rest.bdims[-1]
    products = [
        single[a] ^ rest[b] for a in range(4) for b in range(rest_size)
    ]
    return Qarray.from_list(products)