Coverage for jaxquantum/circuits/library/qubit.py: 20%

105 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1"""qubit gates.""" 

2 

3from jaxquantum.core.operators import ( 

4 identity, 

5 sigmax, 

6 sigmay, 

7 sigmaz, 

8 basis, 

9 hadamard, 

10 qubit_rotation, 

11) 

12from jaxquantum.circuits.gates import Gate 

13from jaxquantum.core.qarray import Qarray 

14import jax.numpy as jnp 

15 

16 

def X():
    """Return the single-qubit gate whose unitary is ``sigmax()`` (Pauli X)."""

    def _pauli_x_unitary(params):
        # Fixed gate: nothing to read from ``params``.
        return sigmax()

    return Gate.create(2, name="X", gen_U=_pauli_x_unitary, num_modes=1)

19 

20 

def Y():
    """Return the single-qubit gate whose unitary is ``sigmay()`` (Pauli Y)."""

    def _pauli_y_unitary(params):
        # Fixed gate: nothing to read from ``params``.
        return sigmay()

    return Gate.create(2, name="Y", gen_U=_pauli_y_unitary, num_modes=1)

23 

24 

def Z():
    """Return the single-qubit gate whose unitary is ``sigmaz()`` (Pauli Z)."""

    def _pauli_z_unitary(params):
        # Fixed gate: nothing to read from ``params``.
        return sigmaz()

    return Gate.create(2, name="Z", gen_U=_pauli_z_unitary, num_modes=1)

27 

28 

def H():
    """Return the single-qubit gate whose unitary is ``hadamard()``."""

    def _hadamard_unitary(params):
        # Fixed gate: nothing to read from ``params``.
        return hadamard()

    return Gate.create(2, name="H", gen_U=_hadamard_unitary, num_modes=1)

31 

32 

def Rx(theta, ts=None):
    """Return a single-qubit rotation gate about the x axis.

    Args:
        theta: Rotation angle. Stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 1, 0, 0)`` to build the unitary.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``(theta / (ts[-1] - ts[0])) / 2 * sigmax()`` is
            attached as well. NOTE(review): this presumably mirrors a
            ``qubit_rotation`` convention of ``exp(-i theta/2 sigma_x)``, so
            that evolving for ``ts[-1] - ts[0]`` reproduces the unitary. Confirm.

    Returns:
        A 1-mode, dimension-2 ``Gate`` named ``"Rx"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as gen_U does) rather than the value
        # captured here, so the Hamiltonian and the unitary stay consistent
        # when the gate is evaluated with different parameters.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmax()
        )

    return Gate.create(
        2,
        name="Rx",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 1, 0, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )

49 

50 

def Ry(theta, ts=None):
    """Return a single-qubit rotation gate about the y axis.

    Args:
        theta: Rotation angle. Stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 0, 1, 0)`` to build the unitary.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``(theta / (ts[-1] - ts[0])) / 2 * sigmay()`` is
            attached as well. NOTE(review): this presumably mirrors a
            ``qubit_rotation`` convention of ``exp(-i theta/2 sigma_y)``. Confirm.

    Returns:
        A 1-mode, dimension-2 ``Gate`` named ``"Ry"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as gen_U does) rather than the value
        # captured here, so the Hamiltonian and the unitary stay consistent
        # when the gate is evaluated with different parameters.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmay()
        )

    return Gate.create(
        2,
        name="Ry",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 1, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )

66 

67 

def Rz(theta, ts=None):
    """Return a single-qubit rotation gate about the z axis.

    Args:
        theta: Rotation angle. Stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 0, 0, 1)`` to build the unitary.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``(theta / (ts[-1] - ts[0])) / 2 * sigmaz()`` is
            attached as well. NOTE(review): this presumably mirrors a
            ``qubit_rotation`` convention of ``exp(-i theta/2 sigma_z)``. Confirm.

    Returns:
        A 1-mode, dimension-2 ``Gate`` named ``"Rz"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as gen_U does) rather than the value
        # captured here, so the Hamiltonian and the unitary stay consistent
        # when the gate is evaluated with different parameters.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmaz()
        )

    return Gate.create(
        2,
        name="Rz",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 0, 1),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )

83 

84 

def MZ(measure=None):
    """Measure a qubit in the Z basis, or post-select on one outcome.

    The Kraus operators are the projectors ``|g><g|`` and ``|e><e|``, with
    ``|g> = basis(2, 0)`` and ``|e> = basis(2, 1)``.

    Args:
        measure: ``None`` keeps both projectors (gate ``"MZ"``). ``+1`` keeps
            only ``|g><g|`` (``"MZ_plus"``). ``-1`` keeps only ``|e><e|``
            (``"MZ_minus"``).

    Raises:
        ValueError: if ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    proj_g = ground @ ground.dag()
    proj_e = excited @ excited.dag()

    if measure is None:
        selected, gate_name = [proj_g, proj_e], "MZ"
    elif measure == +1:
        selected, gate_name = [proj_g], "MZ_plus"
    elif measure == -1:
        selected, gate_name = [proj_e], "MZ_minus"
    else:
        raise ValueError("measure should be None, +1 or -1")

    kmap = Qarray.from_list(selected)
    return Gate.create(2, name=gate_name, gen_KM=lambda params: kmap, num_modes=1)

105 

106 

def MX(measure=None):
    """Measure a qubit in the X basis, or post-select on one outcome.

    The Kraus operators are projectors onto ``(|g> + |e>).unit()`` and
    ``(|g> - |e>).unit()``, with ``|g> = basis(2, 0)`` and
    ``|e> = basis(2, 1)``.

    Args:
        measure: ``None`` keeps both projectors (gate ``"MX"``). ``+1`` keeps
            only the ``|+>`` projector (``"MX_plus"``). ``-1`` keeps only
            the ``|->`` projector (``"MX_minus"``).

    Raises:
        ValueError: if ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    ket_plus = (ground + excited).unit()
    ket_minus = (ground - excited).unit()

    proj_plus = ket_plus @ ket_plus.dag()
    proj_minus = ket_minus @ ket_minus.dag()

    if measure is None:
        selected, gate_name = [proj_plus, proj_minus], "MX"
    elif measure == +1:
        selected, gate_name = [proj_plus], "MX_plus"
    elif measure == -1:
        selected, gate_name = [proj_minus], "MX_minus"
    else:
        raise ValueError("measure should be None, +1 or -1")

    kmap = Qarray.from_list(selected)
    return Gate.create(2, name=gate_name, gen_KM=lambda params: kmap, num_modes=1)

130 

131 

def Reset():
    """Return a channel that maps both qubit basis states to ``|g>``.

    Kraus operators: ``|g><g|`` and ``|g><e|``, with ``|g> = basis(2, 0)``
    and ``|e> = basis(2, 1)``.
    """
    ground, excited = basis(2, 0), basis(2, 1)

    kraus_ops = Qarray.from_list(
        [
            ground @ ground.dag(),  # leave |g> in place
            ground @ excited.dag(),  # send |e> to |g>
        ]
    )

    def _kraus_map(params):
        return kraus_ops

    return Gate.create(2, name="Reset", gen_KM=_kraus_map, num_modes=1)

141 

142 

def IP_Reset(p_eg, p_ee):
    """Return an imperfect qubit reset channel.

    Kraus operators, with ``|g> = basis(2, 0)`` and ``|e> = basis(2, 1)``:

        K0 = sqrt(1 - p_eg) |g><g|
        K1 = sqrt(p_ee)     |e><e|
        K2 = sqrt(p_eg)     |e><g|
        K3 = sqrt(1 - p_ee) |g><e|

    A qubit prepared in ``|g>`` ends in ``|e>`` with probability ``p_eg``.
    A qubit prepared in ``|e>`` remains in ``|e>`` with probability ``p_ee``.

    Args:
        p_eg: Probability of ending in ``|e>`` when starting in ``|g>``.
        p_ee: Probability of ending in ``|e>`` when starting in ``|e>``.

    Returns:
        A 1-mode, dimension-2 ``Gate`` named ``"IP_Reset"``.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ge = g @ e.dag()
    eg = e @ g.dag()
    ee = e @ e.dag()

    k_0 = jnp.sqrt(1 - p_eg) * gg
    k_1 = jnp.sqrt(p_ee) * ee
    k_2 = jnp.sqrt(p_eg) * eg
    k_3 = jnp.sqrt(1 - p_ee) * ge

    kmap = Qarray.from_list([k_0, k_1, k_2, k_3])

    return Gate.create(
        2,
        name="IP_Reset",
        params={
            "p_eg": p_eg,
            "p_ee": p_ee,
            # Legacy alias: p_ee was previously stored under the mislabelled
            # key "p_ge". Kept so existing readers of params["p_ge"] work.
            "p_ge": p_ee,
        },
        gen_KM=lambda params: kmap,
        num_modes=1,
    )

166 

167 

def CX():
    """Return the two-qubit controlled-X gate.

    The unitary is ``|g><g| ^ I + |e><e| ^ sigmax()``. The first mode is the
    control and the second is the target. NOTE(review): ``^`` is assumed to
    be the Qarray tensor product.
    """
    ground, excited = basis(2, 0), basis(2, 1)
    proj_g = ground @ ground.dag()
    proj_e = excited @ excited.dag()

    # Identity on the target when the control is |g>; X when it is |e>.
    cx_op = (proj_g ^ identity(2)) + (proj_e ^ sigmax())

    def _cx_unitary(params):
        return cx_op

    return Gate.create([2, 2], name="CX", gen_U=_cx_unitary, num_modes=2)

178 

179 

def _Thermal_Kraus_Ops_Qb(err_prob, n_bar):
    """Return the four Kraus operators of a qubit thermal channel.

    With ``p = n_bar / (n_bar + 1)``, the operators (as 2x2 matrices in the
    ``basis(2, 0)``, ``basis(2, 1)`` ordering) are::

        sqrt(1 - p) * [[1, 0], [0, sqrt(1 - err_prob)]]
        sqrt(1 - p) * [[0, sqrt(err_prob)], [0, 0]]     # 1 -> 0 jump
        sqrt(p)     * [[0, 0], [sqrt(err_prob), 0]]     # 0 -> 1 jump
        sqrt(p)     * [[sqrt(1 - err_prob), 0], [0, 1]]

    Args:
        err_prob: Probability of a jump in the channel.
        n_bar: Average thermal photon number; sets the weight ``p`` of
            upward versus downward jumps.

    Returns:
        A list of four ``Qarray`` Kraus operators.
    """
    p = n_bar / (n_bar + 1)

    # Weights of the decay branch and the excitation branch.
    decay_weight = jnp.sqrt(1 - p)
    excite_weight = jnp.sqrt(p)
    # Amplitudes for "no jump" and "jump".
    stay = jnp.sqrt(1 - err_prob)
    jump = jnp.sqrt(err_prob)

    matrices = [
        decay_weight * jnp.array([[1, 0], [0, stay]]),
        decay_weight * jnp.array([[0, jump], [0, 0]]),
        excite_weight * jnp.array([[0, 0], [jump, 0]]),
        excite_weight * jnp.array([[stay, 0], [0, 1]]),
    ]
    return [Qarray.create(m) for m in matrices]

192 

193 

def Thermal_Ch_Qb(err_prob, n_bar):
    """Return a single-qubit thermal-noise channel as a Kraus-map gate.

    The Kraus operators come from ``_Thermal_Kraus_Ops_Qb(err_prob, n_bar)``.
    They are rebuilt each time the gate's ``gen_KM`` is called, using the
    arguments captured here.

    Args:
        err_prob: Jump probability of the channel.
        n_bar: Average thermal photon number.
    """

    def _kraus_map(params):
        # Uses the captured arguments; ``params`` is not consulted.
        return Qarray.from_list(_Thermal_Kraus_Ops_Qb(err_prob, n_bar))

    gate_params = {"err_prob": err_prob, "n_bar": n_bar}
    return Gate.create(
        2,
        name="Thermal_Ch_Qb",
        params=gate_params,
        gen_KM=_kraus_map,
        num_modes=1,
    )

204 

205 

def _Pure_Dephasing_Ops_Qb(err_prob):
    """Return the two Kraus operators of a qubit pure-dephasing channel.

    The channel leaves the state unchanged with probability
    ``1 - err_prob`` and applies ``sigmaz()`` with probability
    ``err_prob``::

        K0 = sqrt(1 - err_prob) * identity(2)
        K1 = sqrt(err_prob)     * sigmaz()

    Args:
        err_prob: Probability that ``sigmaz()`` is applied.

    Returns:
        A list of the two Kraus operators ``[K0, K1]``.
    """
    return [
        jnp.sqrt(1 - err_prob) * identity(2),
        jnp.sqrt(err_prob) * sigmaz(),
    ]

213 

214 

def Dephasing_Ch_Qb(err_prob):
    """Return a single-qubit pure-dephasing channel as a Kraus-map gate.

    The Kraus operators come from ``_Pure_Dephasing_Ops_Qb(err_prob)``.
    They are rebuilt each time the gate's ``gen_KM`` is called, using the
    argument captured here.

    Args:
        err_prob: Dephasing (``sigmaz()``) probability.
    """

    def _kraus_map(params):
        # Uses the captured ``err_prob``; ``params`` is not consulted.
        return Qarray.from_list(_Pure_Dephasing_Ops_Qb(err_prob))

    gate_params = {"err_prob": err_prob}
    return Gate.create(
        2,
        name="Dephasing_Ch_Qb",
        params=gate_params,
        gen_KM=_kraus_map,
        num_modes=1,
    )