Coverage for jaxquantum/circuits/library/qubit.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""qubit gates.""" 

2 

3from jaxquantum.core.operators import ( 

4 identity, 

5 sigmax, 

6 sigmay, 

7 sigmaz, 

8 basis, 

9 hadamard, 

10 qubit_rotation, 

11) 

12from jaxquantum.circuits.gates import Gate 

13from jaxquantum.core.qarray import Qarray 

14import jax.numpy as jnp 

15 

16 

17def X(): 

18 return Gate.create(2, name="X", gen_U=lambda params: sigmax(), num_modes=1) 

19 

20 

21def Y(): 

22 return Gate.create(2, name="Y", gen_U=lambda params: sigmay(), num_modes=1) 

23 

24 

25def Z(): 

26 return Gate.create(2, name="Z", gen_U=lambda params: sigmaz(), num_modes=1) 

27 

28 

29def H(): 

30 return Gate.create(2, name="H", gen_U=lambda params: hadamard(), num_modes=1) 

31 

32 

33def Rx(theta): 

34 return Gate.create( 

35 2, 

36 name="Rx", 

37 params={"theta": theta}, 

38 gen_U=lambda params: qubit_rotation(params["theta"], 1, 0, 0), 

39 num_modes=1, 

40 ) 

41 

42 

43def Ry(theta): 

44 return Gate.create( 

45 2, 

46 name="Ry", 

47 params={"theta": theta}, 

48 gen_U=lambda params: qubit_rotation(params["theta"], 0, 1, 0), 

49 num_modes=1, 

50 ) 

51 

52 

53def Rz(theta): 

54 return Gate.create( 

55 2, 

56 name="Rz", 

57 params={"theta": theta}, 

58 gen_U=lambda params: qubit_rotation(params["theta"], 0, 0, 1), 

59 num_modes=1, 

60 ) 

61 

62 

63def MZ(): 

64 g = basis(2, 0) 

65 e = basis(2, 1) 

66 

67 gg = g @ g.dag() 

68 ee = e @ e.dag() 

69 

70 kmap = Qarray.from_list([gg, ee]) 

71 

72 return Gate.create(2, name="MZ", gen_KM=lambda params: kmap, num_modes=1) 

73 

74 

75def MX(): 

76 g = basis(2, 0) 

77 e = basis(2, 1) 

78 

79 plus = (g + e).unit() 

80 minus = (g - e).unit() 

81 

82 pp = plus @ plus.dag() 

83 mm = minus @ minus.dag() 

84 

85 kmap = Qarray.from_list([pp, mm]) 

86 

87 return Gate.create(2, name="MX", gen_KM=lambda params: kmap, num_modes=1) 

88 

89 

90def MX_plus(): 

91 g = basis(2, 0) 

92 e = basis(2, 1) 

93 plus = (g + e).unit() 

94 pp = plus @ plus.dag() 

95 kmap = Qarray.from_list([2 * pp]) 

96 

97 return Gate.create(2, name="MXplus", gen_KM=lambda params: kmap, num_modes=1) 

98 

99 

100def MZ_plus(): 

101 g = basis(2, 0) 

102 plus = g 

103 pp = plus @ plus.dag() 

104 kmap = Qarray.from_list([2 * pp]) 

105 

106 return Gate.create(2, name="MZplus", gen_KM=lambda params: kmap, num_modes=1) 

107 

108 

109def Reset(): 

110 g = basis(2, 0) 

111 e = basis(2, 1) 

112 

113 gg = g @ g.dag() 

114 ge = g @ e.dag() 

115 

116 kmap = Qarray.from_list([gg, ge]) 

117 return Gate.create(2, name="Reset", gen_KM=lambda params: kmap, num_modes=1) 

118 

119 

120def IP_Reset(p_eg, p_ee): 

121 g = basis(2, 0) 

122 e = basis(2, 1) 

123 

124 gg = g @ g.dag() 

125 ge = g @ e.dag() 

126 eg = e @ g.dag() 

127 ee = e @ e.dag() 

128 

129 k_0 = jnp.sqrt(1 - p_eg) * gg + jnp.sqrt(p_eg) * eg 

130 k_1 = jnp.sqrt(p_ee) * ee + jnp.sqrt(1 - p_ee) * ge 

131 

132 kmap = Qarray.from_list([k_0, k_1]) 

133 

134 return Gate.create( 

135 2, 

136 name="IP_Reset", 

137 params={"p_eg": p_eg, "p_ge": p_ee}, 

138 gen_KM=lambda params: kmap, 

139 num_modes=1, 

140 ) 

141 

142 

143def CX(): 

144 g = basis(2, 0) 

145 e = basis(2, 1) 

146 

147 gg = g @ g.dag() 

148 ee = e @ e.dag() 

149 

150 op = (gg ^ identity(2)) + (ee ^ sigmax()) 

151 

152 return Gate.create([2, 2], name="CX", gen_U=lambda params: op, num_modes=2)