Coverage for jaxquantum / circuits / library / qubit.py: 20%

106 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 20:38 +0000

1"""qubit gates.""" 

2 

3from jaxquantum.core.operators import ( 

4 identity, 

5 sigmax, 

6 sigmay, 

7 sigmaz, 

8 basis, 

9 hadamard, 

10 qubit_rotation, 

11) 

12from jaxquantum.circuits.gates import Gate 

13from jaxquantum.core.qarray import Qarray 

14import jax.numpy as jnp 

15 

16 

def X():
    """Create the single-mode qubit gate named ``"X"``.

    Returns:
        The object produced by ``Gate.create`` for a dimension-2 gate whose
        unitary generator returns ``sigmax()``.
    """

    def _unitary(params):
        # The gate has no tunable parameters, so ``params`` is ignored.
        return sigmax()

    return Gate.create(2, name="X", gen_U=_unitary, num_modes=1)

19 

20 

def Y():
    """Create the single-mode qubit gate named ``"Y"``.

    Returns:
        The object produced by ``Gate.create`` for a dimension-2 gate whose
        unitary generator returns ``sigmay()``.
    """

    def _unitary(params):
        # The gate has no tunable parameters, so ``params`` is ignored.
        return sigmay()

    return Gate.create(2, name="Y", gen_U=_unitary, num_modes=1)

23 

24 

def Z():
    """Create the single-mode qubit gate named ``"Z"``.

    Returns:
        The object produced by ``Gate.create`` for a dimension-2 gate whose
        unitary generator returns ``sigmaz()``.
    """

    def _unitary(params):
        # The gate has no tunable parameters, so ``params`` is ignored.
        return sigmaz()

    return Gate.create(2, name="Z", gen_U=_unitary, num_modes=1)

27 

28 

def H():
    """Create the single-mode qubit gate named ``"H"``.

    Returns:
        The object produced by ``Gate.create`` for a dimension-2 gate whose
        unitary generator returns ``hadamard()``.
    """

    def _unitary(params):
        # The gate has no tunable parameters, so ``params`` is ignored.
        return hadamard()

    return Gate.create(2, name="H", gen_U=_unitary, num_modes=1)

31 

32 

def Rx(theta, ts=None):
    """Create the single-mode qubit gate named ``"Rx"``.

    Args:
        theta: Rotation parameter. It is stored in the gate params under
            ``"theta"`` and forwarded to ``qubit_rotation(theta, 1, 0, 0)``.
        ts: Optional sequence of times. When given, a time-independent
            Hamiltonian generator ``theta / (ts[-1] - ts[0]) / 2 * sigmax()``
            is supplied as well; when ``None`` no Hamiltonian generator is
            passed.

    Returns:
        The object produced by ``Gate.create``.
    """
    if ts is None:
        hamiltonian_gen = None
    else:
        # Spread the full rotation uniformly over the span of ``ts``.
        drive = theta / (ts[-1] - ts[0])

        def hamiltonian_gen(params):
            # Uses the ``theta`` captured at construction, not ``params``.
            return lambda t: drive / 2 * sigmax()

    def unitary_gen(params):
        return qubit_rotation(params["theta"], 1, 0, 0)

    return Gate.create(
        2,
        name="Rx",
        params={"theta": theta},
        gen_U=unitary_gen,
        gen_Ht=hamiltonian_gen,
        ts=ts,
        num_modes=1,
    )

49 

50 

def Ry(theta, ts=None):
    """Create the single-mode qubit gate named ``"Ry"``.

    Args:
        theta: Rotation parameter. It is stored in the gate params under
            ``"theta"`` and forwarded to ``qubit_rotation(theta, 0, 1, 0)``.
        ts: Optional sequence of times. When given, a time-independent
            Hamiltonian generator ``theta / (ts[-1] - ts[0]) / 2 * sigmay()``
            is supplied as well; when ``None`` no Hamiltonian generator is
            passed.

    Returns:
        The object produced by ``Gate.create``.
    """
    if ts is None:
        hamiltonian_gen = None
    else:
        # Spread the full rotation uniformly over the span of ``ts``.
        drive = theta / (ts[-1] - ts[0])

        def hamiltonian_gen(params):
            # Uses the ``theta`` captured at construction, not ``params``.
            return lambda t: drive / 2 * sigmay()

    def unitary_gen(params):
        return qubit_rotation(params["theta"], 0, 1, 0)

    return Gate.create(
        2,
        name="Ry",
        params={"theta": theta},
        gen_U=unitary_gen,
        gen_Ht=hamiltonian_gen,
        ts=ts,
        num_modes=1,
    )

66 

67 

def Rz(theta, ts=None):
    """Create the single-mode qubit gate named ``"Rz"``.

    Args:
        theta: Rotation parameter. It is stored in the gate params under
            ``"theta"`` and forwarded to ``qubit_rotation(theta, 0, 0, 1)``.
        ts: Optional sequence of times. When given, a time-independent
            Hamiltonian generator ``theta / (ts[-1] - ts[0]) / 2 * sigmaz()``
            is supplied as well; when ``None`` no Hamiltonian generator is
            passed.

    Returns:
        The object produced by ``Gate.create``.
    """
    if ts is None:
        hamiltonian_gen = None
    else:
        # Spread the full rotation uniformly over the span of ``ts``.
        drive = theta / (ts[-1] - ts[0])

        def hamiltonian_gen(params):
            # Uses the ``theta`` captured at construction, not ``params``.
            return lambda t: drive / 2 * sigmaz()

    def unitary_gen(params):
        return qubit_rotation(params["theta"], 0, 0, 1)

    return Gate.create(
        2,
        name="Rz",
        params={"theta": theta},
        gen_U=unitary_gen,
        gen_Ht=hamiltonian_gen,
        ts=ts,
        num_modes=1,
    )

83 

84 

def MZ(measure=None):
    """Create a measurement gate built from the ``basis(2, k)`` projectors.

    Args:
        measure: Selects which projectors form the Kraus map.
            ``None`` keeps both projectors (gate name ``"MZ"``), ``+1`` keeps
            only the ``basis(2, 0)`` projector (``"MZ_plus"``) and ``-1``
            keeps only the ``basis(2, 1)`` projector (``"MZ_minus"``).

    Returns:
        The object produced by ``Gate.create`` with the chosen Kraus map.

    Raises:
        ValueError: If ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ket0 = basis(2, 0)
    ket1 = basis(2, 1)

    proj0 = ket0 @ ket0.dag()
    proj1 = ket1 @ ket1.dag()

    if measure is None:
        operators, label = [proj0, proj1], "MZ"
    elif measure == +1:
        operators, label = [proj0], "MZ_plus"
    elif measure == -1:
        operators, label = [proj1], "MZ_minus"
    else:
        raise ValueError("measure should be None, +1 or -1")

    kraus_map = Qarray.from_list(operators)

    def _kraus(params):
        # The map is fixed at construction time; ``params`` is ignored.
        return kraus_map

    return Gate.create(2, name=label, gen_KM=_kraus, num_modes=1)

105 

106 

def MX(measure=None):
    """Create a measurement gate built from the normalized ``g ± e`` states.

    Here ``g = basis(2, 0)`` and ``e = basis(2, 1)``; the projectors are onto
    ``(g + e).unit()`` and ``(g - e).unit()``.

    Args:
        measure: Selects which projectors form the Kraus map.
            ``None`` keeps both projectors (gate name ``"MX"``), ``+1`` keeps
            only the ``(g + e)`` projector (``"MX_plus"``) and ``-1`` keeps
            only the ``(g - e)`` projector (``"MX_minus"``).

    Returns:
        The object produced by ``Gate.create`` with the chosen Kraus map.

    Raises:
        ValueError: If ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ket0 = basis(2, 0)
    ket1 = basis(2, 1)

    ket_plus = (ket0 + ket1).unit()
    ket_minus = (ket0 - ket1).unit()

    proj_plus = ket_plus @ ket_plus.dag()
    proj_minus = ket_minus @ ket_minus.dag()

    if measure is None:
        operators, label = [proj_plus, proj_minus], "MX"
    elif measure == +1:
        operators, label = [proj_plus], "MX_plus"
    elif measure == -1:
        operators, label = [proj_minus], "MX_minus"
    else:
        raise ValueError("measure should be None, +1 or -1")

    kraus_map = Qarray.from_list(operators)

    def _kraus(params):
        # The map is fixed at construction time; ``params`` is ignored.
        return kraus_map

    return Gate.create(2, name=label, gen_KM=_kraus, num_modes=1)

130 

131 

def Reset():
    """Create the single-mode gate named ``"Reset"``.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the Kraus map consists
    of ``g @ g.dag()`` and ``g @ e.dag()``, i.e. both operators output ``g``.

    Returns:
        The object produced by ``Gate.create`` with that Kraus map.
    """
    ket0 = basis(2, 0)
    ket1 = basis(2, 1)

    stay_in_0 = ket0 @ ket0.dag()
    move_1_to_0 = ket0 @ ket1.dag()

    kraus_map = Qarray.from_list([stay_in_0, move_1_to_0])

    def _kraus(params):
        # The map is fixed at construction time; ``params`` is ignored.
        return kraus_map

    return Gate.create(2, name="Reset", gen_KM=_kraus, num_modes=1)

141 

142 

def IP_Reset(p_eg, p_ee):
    """Create the single-mode gate named ``"IP_Reset"``.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the Kraus map holds four
    operators:

    * ``sqrt(1 - p_eg) * g @ g.dag()``
    * ``sqrt(p_ee)     * e @ e.dag()``
    * ``sqrt(p_eg)     * e @ g.dag()``
    * ``sqrt(1 - p_ee) * g @ e.dag()``

    so ``p_eg`` weights the ``g -> e`` operator and ``p_ee`` weights the
    ``e -> e`` operator.

    Args:
        p_eg: Weight of the ``g -> e`` transition; stored in the gate params
            under ``"p_eg"``.
        p_ee: Weight of the ``e -> e`` operator; stored in the gate params
            under ``"p_ee"``.

    Returns:
        The object produced by ``Gate.create`` with that Kraus map.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ge = g @ e.dag()
    eg = e @ g.dag()
    ee = e @ e.dag()

    k_0 = jnp.sqrt(1 - p_eg) * gg
    k_1 = jnp.sqrt(p_ee) * ee
    k_2 = jnp.sqrt(p_eg) * eg
    k_3 = jnp.sqrt(1 - p_ee) * ge

    kmap = Qarray.from_list([k_0, k_1, k_2, k_3])

    return Gate.create(
        2,
        name="IP_Reset",
        # The second key was previously "p_ge", which did not match the
        # argument it stores; it now carries the argument's own name.
        params={"p_eg": p_eg, "p_ee": p_ee},
        gen_KM=lambda params: kmap,
        num_modes=1,
    )

166 

167 

def CX():
    """Create the two-mode gate named ``"CX"`` on dimensions ``[2, 2]``.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the unitary is
    ``(g @ g.dag() ^ identity(2)) + (e @ e.dag() ^ sigmax())``.

    Returns:
        The object produced by ``Gate.create`` with that unitary.
    """
    ket0 = basis(2, 0)
    ket1 = basis(2, 1)

    proj0 = ket0 @ ket0.dag()
    proj1 = ket1 @ ket1.dag()

    # NOTE(review): ``^`` is presumably the tensor product between operators
    # on the two modes -- confirm against the Qarray implementation.
    first_mode_in_0 = proj0 ^ identity(2)
    first_mode_in_1 = proj1 ^ sigmax()
    unitary = first_mode_in_0 + first_mode_in_1

    def _unitary(params):
        # The operator is fixed at construction time; ``params`` is ignored.
        return unitary

    return Gate.create([2, 2], name="CX", gen_U=_unitary, num_modes=2)

178 

179 

def _Thermal_Kraus_Ops_Qb(err_prob, n_bar):
    """Return the Kraus operators of a thermal channel on a 2-level system.

    Args:
        err_prob: Error probability of the channel.
        n_bar: Average photon number; it sets the weights
            ``(n_bar + 1) / (2 * n_bar + 1)`` and ``n_bar / (2 * n_bar + 1)``
            applied to the two pairs of operators.

    Returns:
        A list of four ``Qarray`` objects, each wrapping a 2x2 matrix.
    """
    denom = 2 * n_bar + 1
    weight_0 = jnp.sqrt((n_bar + 1) / denom)
    weight_1 = jnp.sqrt(n_bar / denom)

    jump = jnp.sqrt(err_prob)
    no_jump = jnp.sqrt(1 - err_prob)

    matrices = [
        weight_0 * jnp.array([[1, 0], [0, no_jump]]),
        weight_0 * jnp.array([[0, jump], [0, 0]]),
        weight_1 * jnp.array([[0, 0], [jump, 0]]),
        weight_1 * jnp.array([[no_jump, 0], [0, 1]]),
    ]
    return [Qarray.create(matrix) for matrix in matrices]

199 

200 

def Thermal_Ch_Qb(err_prob, n_bar):
    """Create the single-mode gate named ``"Thermal_Ch_Qb"``.

    Args:
        err_prob: Error probability, stored in the gate params under
            ``"err_prob"`` and passed to ``_Thermal_Kraus_Ops_Qb``.
        n_bar: Average photon number, stored in the gate params under
            ``"n_bar"`` and passed to ``_Thermal_Kraus_Ops_Qb``.

    Returns:
        The object produced by ``Gate.create`` whose Kraus-map generator
        wraps the operators from ``_Thermal_Kraus_Ops_Qb``.
    """

    def _kraus(params):
        # Uses the arguments captured at construction, not ``params``.
        operators = _Thermal_Kraus_Ops_Qb(err_prob, n_bar)
        return Qarray.from_list(operators)

    return Gate.create(
        2,
        name="Thermal_Ch_Qb",
        params={"err_prob": err_prob, "n_bar": n_bar},
        gen_KM=_kraus,
        num_modes=1,
    )

211 

212 

def _Pure_Dephasing_Ops_Qb(err_prob):
    """Return the Kraus operators of a pure dephasing channel on a qubit.

    Args:
        err_prob: Error probability of the channel.

    Returns:
        A two-element list: ``sqrt(1 - err_prob) * identity(2)`` followed by
        ``sqrt(err_prob) * sigmaz()``.
    """
    no_error = jnp.sqrt(1 - err_prob) * identity(2)
    phase_error = jnp.sqrt(err_prob) * sigmaz()
    return [no_error, phase_error]

220 

221 

def Dephasing_Ch_Qb(err_prob):
    """Create the single-mode gate named ``"Dephasing_Ch_Qb"``.

    Args:
        err_prob: Error probability, stored in the gate params under
            ``"err_prob"`` and passed to ``_Pure_Dephasing_Ops_Qb``.

    Returns:
        The object produced by ``Gate.create`` whose Kraus-map generator
        wraps the operators from ``_Pure_Dephasing_Ops_Qb``.
    """

    def _kraus(params):
        # Uses the argument captured at construction, not ``params``.
        operators = _Pure_Dephasing_Ops_Qb(err_prob)
        return Qarray.from_list(operators)

    return Gate.create(
        2,
        name="Dephasing_Ch_Qb",
        params={"err_prob": err_prob},
        gen_KM=_kraus,
        num_modes=1,
    )