Coverage for jaxquantum/circuits/circuits.py: 84%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Circuits. 

2 

3Inspired by a mix of Cirq and Qiskit circuits. 

4""" 

5 

6from flax import struct 

7from jax import config 

8from typing import List, Optional, Union 

9from copy import deepcopy 

10from numpy import argsort 

11import jax.numpy as jnp 

12 

13from jaxquantum.core.operators import identity 

14from jaxquantum.circuits.gates import Gate 

15from jaxquantum.circuits.constants import SimulateMode 

16from jaxquantum.core.qarray import Qarray, concatenate 

17 

18 

# Turn on JAX's 64-bit ("x64") mode at import time. This is a process-wide
# JAX configuration change made as a side effect of importing this module.
config.update("jax_enable_x64", True)

20 

21 

@struct.dataclass
class Register:
    """Mode layout of a circuit: the dimension of each mode, by position."""

    # Dimension of each mode; kept as static (non-pytree) metadata.
    dims: List[int] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, dims: List[int]):
        """Return a ``Register`` holding the given per-mode dimensions."""
        return Register(dims=dims)

    def __eq__(self, other):
        """Registers compare equal iff ``other`` is a Register with equal ``dims``."""
        return isinstance(other, Register) and self.dims == other.dims

34 

35 

@struct.dataclass
class Operation:
    """A gate applied to a chosen set of modes (indices) of a register."""

    # The gate being applied.
    gate: Gate
    # Register positions the gate acts on, in the gate's own mode order.
    indices: List[int] = struct.field(pytree_node=False)
    # The register that ``indices`` refer to.
    register: Register

    @classmethod
    def create(cls, gate: Gate, indices: Union[int, List[int]], register: Register):
        """Validate the inputs and build an ``Operation``.

        Args:
            gate: Gate to apply. Its ``num_modes`` must equal the number of
                indices and its ``dims`` must equal the register dims at
                those indices.
            indices: A single mode index, or a sequence of mode indices
                into ``register``. Any sequence (list, tuple, ...) is
                accepted and stored as a list.
            register: Register whose modes ``indices`` refer to.

        Returns:
            A new ``Operation``.

        Raises:
            AssertionError: If ``len(indices) != gate.num_modes`` or the
                gate dims do not match the register dims at ``indices``.
            ValueError: If an index is not an ``int``, lies outside
                ``[0, len(register.dims))``, or appears more than once.
        """
        if isinstance(indices, int):
            indices = [indices]
        else:
            # Normalize to a list: callers may pass a tuple, and promote()
            # concatenates the indices with a list.
            indices = list(indices)

        assert gate.num_modes == len(indices), (
            "Number of indices must match gate's num_modes."
        )

        # Range/type check must happen BEFORE indexing register.dims below.
        # Otherwise a negative index silently wraps around to the end of the
        # register, and an out-of-range one raises a bare IndexError.
        num_register_modes = len(register.dims)
        if any(
            not isinstance(ind, int) or not 0 <= ind < num_register_modes
            for ind in indices
        ):
            raise ValueError("Indices must be integers within the register.")

        # A repeated index would make the identity padding and argsort-based
        # permutation in promote() produce the wrong number of modes.
        if len(set(indices)) != len(indices):
            raise ValueError("Indices must not be repeated.")

        assert gate.dims == [register.dims[ind] for ind in indices], (
            "Indices must match register dimensions."
        )

        return cls(gate=gate, indices=indices, register=register)

    def promote(self, op: Qarray) -> Qarray:
        """Embed ``op``, which acts on ``self.indices``, into the full register.

        An identity factor is appended (with ``^``) for every register mode
        not in ``self.indices``. The modes are then transposed back into
        register order.

        Args:
            op: Operator on the modes listed in ``self.indices``, in that
                order.

        Returns:
            The operator acting on all modes of ``self.register``.
        """
        indices_order = list(self.indices)
        missing_indices = [
            i for i in range(len(self.register.dims)) if i not in indices_order
        ]
        for j in missing_indices:
            op = op ^ identity(self.register.dims[j])
        # combined_indices[k] is the register position of the op's current
        # mode k. argsort yields the mode permutation that restores register
        # order (assuming Qarray.transpose follows numpy-style axis order).
        combined_indices = indices_order + missing_indices
        sorted_ind = list(argsort(combined_indices))
        return op.transpose(sorted_ind)

74 

75 

@struct.dataclass
class Layer:
    """A group of operations simulated together as one circuit step.

    Unless the layer's mode is ``SimulateMode.HAMILTONIAN``, its operations
    must act on disjoint register modes.
    """

    # Operations in this layer, in insertion order.
    operations: List[Operation] = struct.field(pytree_node=False)
    # Register modes touched by ``operations``, each listed once.
    _unique_indices: List[int] = struct.field(pytree_node=False)
    # Simulation mode this layer was created with.
    _default_simulate_mode: SimulateMode = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls, operations: List[Operation], default_simulate_mode=SimulateMode.UNITARY
    ):
        """Build a layer from ``operations``.

        Args:
            operations: Any iterable of operations. It is copied into a new
                list, so later ``add`` calls do not mutate the caller's list.
            default_simulate_mode: Simulation mode of the layer.

        Raises:
            ValueError: If two operations share a mode and the mode is not
                ``SimulateMode.HAMILTONIAN``.
        """
        # Own the list: add() appends to it in place, and an iterator
        # argument would otherwise be exhausted by the comprehension below.
        operations = list(operations)
        all_indices = [ind for op in operations for ind in op.indices]
        unique_indices = list(set(all_indices))

        if default_simulate_mode != SimulateMode.HAMILTONIAN:
            if len(all_indices) != len(unique_indices):
                raise ValueError("Operations must not have overlapping indices.")

        return cls(
            operations=operations,
            _unique_indices=unique_indices,
            _default_simulate_mode=default_simulate_mode,
        )

    def add(self, operation: Operation):
        """Append ``operation`` to this layer in place.

        Raises:
            ValueError: If ``operation`` overlaps an existing mode and the
                layer is not in ``SimulateMode.HAMILTONIAN`` mode. In that
                case the layer is left unchanged.
        """
        if self._default_simulate_mode != SimulateMode.HAMILTONIAN:
            if any(ind in self._unique_indices for ind in operation.indices):
                raise ValueError("Operations must not have overlapping indices.")
        self.operations.append(operation)
        # Only record modes not seen yet. Overlaps are allowed in Hamiltonian
        # mode, and a plain extend() would then duplicate entries.
        for ind in operation.indices:
            if ind not in self._unique_indices:
                self._unique_indices.append(ind)

    def _embed_in_register(self, op: Qarray, indices_order: List[int]) -> Qarray:
        """Pad ``op`` (modes ``indices_order``) with identities and reorder to register order."""
        register = self.operations[0].register
        missing_indices = [
            i for i in range(len(register.dims)) if i not in indices_order
        ]
        for j in missing_indices:
            op = op ^ identity(register.dims[j])
        # argsort of the current mode -> register-position map gives the
        # permutation that restores register order.
        combined_indices = indices_order + missing_indices
        sorted_ind = list(argsort(combined_indices))
        return op.transpose(sorted_ind)

    def gen_U(self):
        """Return the layer's unitary on the full register.

        The result is every operation's ``gate.U`` combined with ``^``,
        padded with identities on untouched modes and reordered to register
        order. Returns ``None`` for an empty layer.
        """
        if len(self.operations) == 0:
            return None

        indices_order = []
        U = None
        for operation in self.operations:
            indices_order += operation.indices
            U = operation.gate.U if U is None else U ^ operation.gate.U

        return self._embed_in_register(U, indices_order)

    def gen_Ht(self):
        """Return a function ``t -> H(t)`` for the whole layer.

        The function sums each operation's ``gate.Ht(t)``, promoted to the
        full register. For an empty layer it returns ``0``.
        """

        def zero_Ht(t):
            return 0

        Ht = zero_Ht
        for operation in self.operations:
            # Default arguments bind the current Ht/operation, avoiding
            # Python's late-binding closure pitfall.
            def summed_Ht(t, prev_Ht=Ht, prev_operation=operation):
                return prev_Ht(t) + prev_operation.promote(prev_operation.gate.Ht(t))

            Ht = summed_Ht

        return Ht

    def gen_KM(self):
        """Combine the operations' ``gate.KM`` into one batch on the full register.

        Operations with an empty ``gate.KM`` are skipped; their modes are
        treated as identity. Returns an empty ``Qarray`` when no operation
        has a ``KM``. (``KM`` presumably holds Kraus operators; confirm
        against ``Gate``.)
        """
        KM = Qarray.from_list([])

        if len(self.operations) == 0:
            return KM

        indices_order = []
        for operation in self.operations:
            if len(operation.gate.KM) == 0:
                continue

            indices_order += operation.indices
            if len(KM) == 0:
                # Copy so the first gate's KM is never aliased by the result.
                KM = deepcopy(operation.gate.KM)
            else:
                KM = KM ^ operation.gate.KM

        if len(KM) == 0:
            return KM

        return self._embed_in_register(KM, indices_order)

    def gen_c_ops(self):
        """Concatenate every operation's ``gate.c_ops``, each promoted to the full register.

        Returns an empty ``Qarray`` if no operation defines any.
        """
        c_ops = Qarray.from_list([])

        if len(self.operations) == 0:
            return c_ops

        for operation in self.operations:
            if len(operation.gate.c_ops) == 0:
                continue
            promoted_c_ops = operation.promote(operation.gate.c_ops)
            c_ops = concatenate([c_ops, promoted_c_ops])

        return c_ops

    def gen_ts(self):
        """Return the time steps shared by the operations that specify them.

        Returns ``None`` if no operation specifies non-empty ``gate.ts``.

        Raises:
            AssertionError: If two operations specify different ``ts``.
        """
        ts = None

        for operation in self.operations:
            if operation.gate.ts is not None and len(operation.gate.ts) > 0:
                if ts is None:
                    ts = operation.gate.ts
                else:
                    assert jnp.array_equal(ts, operation.gate.ts), (
                        "All operations in a layer must have the same specified time steps, but not all operations need to have time steps."
                    )
        return ts

209 

@struct.dataclass
class Circuit:
    """An ordered list of layers acting on a single fixed register."""

    # Register that every operation in this circuit must use.
    register: Register
    # Layers in the order they were appended (static, non-pytree metadata).
    layers: List[Layer] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, register: Register, layers: Optional[List[Layer]] = None):
        """Build a circuit on ``register``; it starts empty when ``layers`` is None."""
        return Circuit(register=register, layers=[] if layers is None else layers)

    def append_layer(self, layer: Layer):
        """Append ``layer`` to the end of the circuit, in place."""
        self.layers.append(layer)

    def append_operation(
        self,
        operation: Operation,
        default_simulate_mode: Optional[SimulateMode] = None,
        new_layer: bool = True,
    ):
        """Add ``operation`` to the circuit.

        The operation goes into a fresh layer when ``new_layer`` is true or
        the circuit has no layers yet. A fresh layer uses
        ``SimulateMode.UNITARY`` unless a mode is given. Otherwise the
        operation is merged into the last layer, and an explicitly given
        mode must match that layer's mode.
        """
        assert operation.register == self.register, (
            f"Mismatch in operation register {operation.register} and circuit register {self.register}."
        )

        if not new_layer and self.layers:
            last_layer = self.layers[-1]
            if default_simulate_mode is not None:
                assert last_layer._default_simulate_mode == default_simulate_mode, (
                    "Cannot append operation to last layer with different default simulate mode."
                )
            last_layer.add(operation)
            return

        mode = SimulateMode.UNITARY if default_simulate_mode is None else default_simulate_mode
        self.append_layer(Layer.create([operation], default_simulate_mode=mode))

    def append(
        self,
        gate: Gate,
        indices: Union[int, List[int]],
        default_simulate_mode: Optional[SimulateMode] = None,
        new_layer: bool = True,
    ):
        """Wrap ``gate`` at ``indices`` in an ``Operation`` and append it (see ``append_operation``)."""
        self.append_operation(
            Operation.create(gate, indices, self.register),
            default_simulate_mode=default_simulate_mode,
            new_layer=new_layer,
        )