Coverage for jaxquantum/circuits/library/qubit.py: 22%

87 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""qubit gates.""" 

2 

3from jaxquantum.core.operators import ( 

4 identity, 

5 sigmax, 

6 sigmay, 

7 sigmaz, 

8 basis, 

9 hadamard, 

10 qubit_rotation, 

11) 

12from jaxquantum.circuits.gates import Gate 

13from jaxquantum.core.qarray import Qarray 

14import jax.numpy as jnp 

15 

16 

17def X(): 

18 return Gate.create(2, name="X", gen_U=lambda params: sigmax(), num_modes=1) 

19 

20 

21def Y(): 

22 return Gate.create(2, name="Y", gen_U=lambda params: sigmay(), num_modes=1) 

23 

24 

25def Z(): 

26 return Gate.create(2, name="Z", gen_U=lambda params: sigmaz(), num_modes=1) 

27 

28 

29def H(): 

30 return Gate.create(2, name="H", gen_U=lambda params: hadamard(), num_modes=1) 

31 

32 

33def Rx(theta, ts=None): 

34 

35 gen_Ht = None 

36 if ts is not None: 

37 delta_t = ts[-1] - ts[0] 

38 amp = theta / delta_t 

39 gen_Ht = lambda params: ( 

40 lambda t: amp / 2 * sigmax()) 

41 

42 return Gate.create( 

43 2, 

44 name="Rx", 

45 params={"theta": theta}, 

46 gen_U=lambda params: qubit_rotation(params["theta"], 1, 0, 0), 

47 gen_Ht=gen_Ht, 

48 ts=ts, 

49 num_modes=1, 

50 ) 

51 

52 

53def Ry(theta, ts=None): 

54 gen_Ht = None 

55 if ts is not None: 

56 delta_t = ts[-1] - ts[0] 

57 amp = theta / delta_t 

58 gen_Ht = lambda params: ( 

59 lambda t: amp / 2 * sigmay()) 

60 return Gate.create( 

61 2, 

62 name="Ry", 

63 params={"theta": theta}, 

64 gen_U=lambda params: qubit_rotation(params["theta"], 0, 1, 0), 

65 gen_Ht=gen_Ht, 

66 ts=ts, 

67 num_modes=1, 

68 ) 

69 

70 

71def Rz(theta, ts=None): 

72 gen_Ht = None 

73 if ts is not None: 

74 delta_t = ts[-1] - ts[0] 

75 amp = theta / delta_t 

76 gen_Ht = lambda params: ( 

77 lambda t: amp / 2 * sigmaz()) 

78 return Gate.create( 

79 2, 

80 name="Rz", 

81 params={"theta": theta}, 

82 gen_U=lambda params: qubit_rotation(params["theta"], 0, 0, 1), 

83 gen_Ht=gen_Ht, 

84 ts=ts, 

85 num_modes=1, 

86 ) 

87 

88 

89def MZ(): 

90 g = basis(2, 0) 

91 e = basis(2, 1) 

92 

93 gg = g @ g.dag() 

94 ee = e @ e.dag() 

95 

96 kmap = Qarray.from_list([gg, ee]) 

97 

98 return Gate.create(2, name="MZ", gen_KM=lambda params: kmap, num_modes=1) 

99 

100 

101def MX(): 

102 g = basis(2, 0) 

103 e = basis(2, 1) 

104 

105 plus = (g + e).unit() 

106 minus = (g - e).unit() 

107 

108 pp = plus @ plus.dag() 

109 mm = minus @ minus.dag() 

110 

111 kmap = Qarray.from_list([pp, mm]) 

112 

113 return Gate.create(2, name="MX", gen_KM=lambda params: kmap, num_modes=1) 

114 

115 

116def MX_plus(): 

117 g = basis(2, 0) 

118 e = basis(2, 1) 

119 plus = (g + e).unit() 

120 pp = plus @ plus.dag() 

121 kmap = Qarray.from_list([pp]) 

122 

123 return Gate.create(2, name="MXplus", gen_KM=lambda params: kmap, num_modes=1) 

124 

125 

126def MZ_plus(): 

127 g = basis(2, 0) 

128 plus = g 

129 pp = plus @ plus.dag() 

130 kmap = Qarray.from_list([pp]) 

131 

132 return Gate.create(2, name="MZplus", gen_KM=lambda params: kmap, num_modes=1) 

133 

134 

135def Reset(): 

136 g = basis(2, 0) 

137 e = basis(2, 1) 

138 

139 gg = g @ g.dag() 

140 ge = g @ e.dag() 

141 

142 kmap = Qarray.from_list([gg, ge]) 

143 return Gate.create(2, name="Reset", gen_KM=lambda params: kmap, num_modes=1) 

144 

145 

146def IP_Reset(p_eg, p_ee): 

147 g = basis(2, 0) 

148 e = basis(2, 1) 

149 

150 gg = g @ g.dag() 

151 ge = g @ e.dag() 

152 eg = e @ g.dag() 

153 ee = e @ e.dag() 

154 

155 k_0 = jnp.sqrt(1 - p_eg) * gg + jnp.sqrt(p_eg) * eg 

156 k_1 = jnp.sqrt(p_ee) * ee + jnp.sqrt(1 - p_ee) * ge 

157 

158 kmap = Qarray.from_list([k_0, k_1]) 

159 

160 return Gate.create( 

161 2, 

162 name="IP_Reset", 

163 params={"p_eg": p_eg, "p_ge": p_ee}, 

164 gen_KM=lambda params: kmap, 

165 num_modes=1, 

166 ) 

167 

168 

169def CX(): 

170 g = basis(2, 0) 

171 e = basis(2, 1) 

172 

173 gg = g @ g.dag() 

174 ee = e @ e.dag() 

175 

176 op = (gg ^ identity(2)) + (ee ^ sigmax()) 

177 

178 return Gate.create([2, 2], name="CX", gen_U=lambda params: op, num_modes=2)