Coverage for jaxquantum/core/operators.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""States.""" 

2 

3from typing import List 

4from jax import config 

5from math import prod 

6 

7import jax.numpy as jnp 

8from jax.nn import one_hot 

9 

10from jaxquantum.core.qarray import Qarray, tensor 

11 

12config.update("jax_enable_x64", True) 

13 

14 

15def sigmax() -> Qarray: 

16 """σx 

17 

18 Returns: 

19 σx Pauli Operator 

20 """ 

21 return Qarray.create(jnp.array([[0.0, 1.0], [1.0, 0.0]])) 

22 

23 

24def sigmay() -> Qarray: 

25 """σy 

26 

27 Returns: 

28 σy Pauli Operator 

29 """ 

30 return Qarray.create(jnp.array([[0.0, -1.0j], [1.0j, 0.0]])) 

31 

32 

33def sigmaz() -> Qarray: 

34 """σz 

35 

36 Returns: 

37 σz Pauli Operator 

38 """ 

39 return Qarray.create(jnp.array([[1.0, 0.0], [0.0, -1.0]])) 

40 

41 

42def hadamard() -> Qarray: 

43 """H 

44 

45 Returns: 

46 H: Hadamard gate 

47 """ 

48 return Qarray.create(jnp.array([[1, 1], [1, -1]]) / jnp.sqrt(2)) 

49 

50 

51def sigmam() -> Qarray: 

52 """σ- 

53 

54 Returns: 

55 σ- Pauli Operator 

56 """ 

57 return Qarray.create(jnp.array([[0.0, 0.0], [1.0, 0.0]])) 

58 

59 

60def sigmap() -> Qarray: 

61 """σ+ 

62 

63 Returns: 

64 σ+ Pauli Operator 

65 """ 

66 return Qarray.create(jnp.array([[0.0, 1.0], [0.0, 0.0]])) 

67 

68 

69def qubit_rotation(theta: float, nx, ny, nz) -> Qarray: 

70 """Single qubit rotation. 

71 

72 Args: 

73 theta: rotation angle. 

74 nx: rotation axis x component. 

75 ny: rotation axis y component. 

76 nz: rotation axis z component. 

77 

78 Returns: 

79 Single qubit rotation operator. 

80 """ 

81 return jnp.cos(theta / 2) * identity(2) - 1j * jnp.sin(theta / 2) * ( 

82 nx * sigmax() + ny * sigmay() + nz * sigmaz() 

83 ) 

84 

85 

86def destroy(N) -> Qarray: 

87 """annihilation operator 

88 

89 Args: 

90 N: Hilbert space size 

91 

92 Returns: 

93 annilation operator in Hilber Space of size N 

94 """ 

95 return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=1)) 

96 

97 

98def create(N) -> Qarray: 

99 """creation operator 

100 

101 Args: 

102 N: Hilbert space size 

103 

104 Returns: 

105 creation operator in Hilber Space of size N 

106 """ 

107 return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=-1)) 

108 

109 

110def num(N) -> Qarray: 

111 """Number operator 

112 

113 Args: 

114 N: Hilbert Space size 

115 

116 Returns: 

117 number operator in Hilber Space of size N 

118 """ 

119 return Qarray.create(jnp.diag(jnp.arange(N))) 

120 

121 

122def identity(*args, **kwargs) -> Qarray: 

123 """Identity matrix. 

124 

125 Returns: 

126 Identity matrix. 

127 """ 

128 return Qarray.create(jnp.eye(*args, **kwargs)) 

129 

130 

131def identity_like(A) -> Qarray: 

132 """Identity matrix with the same shape as A. 

133 

134 Args: 

135 A: Matrix. 

136 

137 Returns: 

138 Identity matrix with the same shape as A. 

139 """ 

140 space_dims = A.space_dims 

141 total_dim = prod(space_dims) 

142 return Qarray.create(jnp.eye(total_dim, total_dim), dims=[space_dims, space_dims]) 

143 

144 

145def displace(N, α) -> Qarray: 

146 """Displacement operator 

147 

148 Args: 

149 N: Hilbert Space Size 

150 α: Phase space displacement 

151 

152 Returns: 

153 Displace operator D(α) 

154 """ 

155 a = destroy(N) 

156 return (α * a.dag() - jnp.conj(α) * a).expm() 

157 

158 

159# States --------------------------------------------------------------------- 

160 

161 

162def basis(N: int, k: int): 

163 """Creates a |k> (i.e. fock state) ket in a specified Hilbert Space. 

164 

165 Args: 

166 N: Hilbert space dimension 

167 k: fock number 

168 

169 Returns: 

170 Fock State |k> 

171 """ 

172 return Qarray.create(one_hot(k, N).reshape(N, 1)) 

173 

174 

175def coherent(N: int, α: complex) -> Qarray: 

176 """Coherent state. 

177 

178 Args: 

179 N: Hilbert Space Size. 

180 α: coherent state amplitude. 

181 

182 Return: 

183 Coherent state |α⟩. 

184 """ 

185 return displace(N, α) @ basis(N, 0) 

186 

187 

188def thermal(N: int, beta: float) -> Qarray: 

189 """Thermal state. 

190 

191 Args: 

192 N: Hilbert Space Size. 

193 beta: thermal state inverse temperature. 

194 

195 Return: 

196 Thermal state. 

197 """ 

198 

199 return Qarray.create( 

200 jnp.where( 

201 jnp.isposinf(beta), 

202 basis(N, 0).to_dm().data, 

203 jnp.diag(jnp.exp(-beta * jnp.linspace(0, N - 1, N))), 

204 ) 

205 ).unit() 

206 

207 

208def basis_like(A: Qarray, ks: List[int]) -> Qarray: 

209 """Creates a |k> (i.e. fock state) ket with the same space dims as A. 

210 

211 Args: 

212 A: state or operator. 

213 k: fock number. 

214 

215 Returns: 

216 Fock State |k> with the same space dims as A. 

217 """ 

218 space_dims = A.space_dims 

219 assert len(space_dims) == len(ks), "len(ks) must be equal to len(space_dims)" 

220 

221 kets = [] 

222 for j, k in enumerate(ks): 

223 kets.append(basis(space_dims[j], k)) 

224 return tensor(*kets)