Coverage for jaxquantum/circuits/circuits.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Circuits. 

2 

3Inspired by a mix of Cirq and Qiskit circuits. 

4""" 

5 

6from flax import struct 

7from jax import config 

8from typing import List, Optional, Union 

9from copy import deepcopy 

10from numpy import argsort 

11 

12from jaxquantum.core.operators import identity 

13from jaxquantum.circuits.gates import Gate 

14from jaxquantum.circuits.constants import SimulateMode 

15 

16config.update("jax_enable_x64", True) 

17 

18 

@struct.dataclass
class Register:
    """Container for the dimensions of the modes that make up a register.

    ``dims`` is stored as static (non-pytree) metadata.
    """

    # One entry per mode of the register.
    dims: List[int] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, dims: List[int]):
        """Build a ``Register`` from a list of mode dimensions."""
        return Register(dims=dims)

    def __eq__(self, other):
        """Compare by ``dims``; anything that is not a ``Register`` is unequal."""
        return isinstance(other, Register) and self.dims == other.dims

31 

32 

@struct.dataclass
class Operation:
    """A gate bound to specific mode indices of a register."""

    gate: Gate
    # Positions in ``register.dims`` that the gate acts on (static metadata).
    indices: List[int] = struct.field(pytree_node=False)
    register: Register

    @classmethod
    def create(cls, gate: Gate, indices: Union[int, List[int]], register: Register):
        """Validate the arguments and build an ``Operation``.

        Args:
            gate: Gate to apply; its ``num_modes`` and ``dims`` are checked
                against ``indices`` and ``register``.
            indices: A single index or a list of indices into ``register.dims``.
                A bare ``int`` is wrapped into a one-element list.
            register: Register the gate acts on.

        Returns:
            The new ``Operation``.

        Raises:
            AssertionError: If the number of indices differs from
                ``gate.num_modes``, or ``gate.dims`` does not equal the
                register dimensions at ``indices``.
            ValueError: If any index is not an ``int`` or lies outside
                ``[0, len(register.dims))``.
        """
        if isinstance(indices, int):
            indices = [indices]

        assert gate.num_modes == len(indices), (
            "Number of indices must match gate's num_modes."
        )

        # Validate the indices *before* using them to index ``register.dims``.
        # Otherwise a bad index surfaces as IndexError/TypeError, or, when
        # negative, silently wraps around to another mode. The three
        # conditions are joined with ``or``: an index can never be both
        # negative and too large at the same time.
        num_register_modes = len(register.dims)
        if any(
            not isinstance(ind, int) or ind < 0 or ind >= num_register_modes
            for ind in indices
        ):
            raise ValueError("Indices must be integers within the register.")

        assert gate.dims == [register.dims[ind] for ind in indices], (
            "Indices must match register dimensions."
        )

        return Operation(gate=gate, indices=indices, register=register)

58 

59 

@struct.dataclass
class Layer:
    """A group of operations whose register indices do not overlap."""

    operations: List[Operation] = struct.field(pytree_node=False)
    # Every index touched by ``operations``; used to reject overlaps in ``add``.
    _unique_indices: List[int] = struct.field(pytree_node=False)
    _default_simulate_mode: SimulateMode = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls, operations: List[Operation], default_simulate_mode=SimulateMode.UNITARY
    ):
        """Build a layer from ``operations``.

        Raises:
            ValueError: If two operations share an index (or one operation
                repeats an index).
        """
        touched = [ind for op in operations for ind in op.indices]
        distinct = list(set(touched))
        # Deduplication shrinking the list means some index was used twice.
        if len(distinct) != len(touched):
            raise ValueError("Operations must not have overlapping indices.")
        return Layer(
            operations=operations,
            _unique_indices=distinct,
            _default_simulate_mode=default_simulate_mode,
        )

    def add(self, operation: Operation):
        """Append ``operation`` in place.

        Raises:
            ValueError: If ``operation`` touches an index already used by
                this layer.
        """
        for ind in operation.indices:
            if ind in self._unique_indices:
                raise ValueError("Operations must not have overlapping indices.")
        self.operations.append(operation)
        self._unique_indices.extend(operation.indices)

    def _mode_layout(self):
        """Return ``(register, applied, idle)`` for a non-empty layer.

        ``register`` is taken from the first operation, ``applied`` is the
        concatenation of every operation's indices in operation order, and
        ``idle`` lists the register positions no operation touches.
        """
        applied = []
        for op in self.operations:
            applied.extend(op.indices)
        register = self.operations[0].register
        idle = [i for i in range(len(register.dims)) if i not in applied]
        return register, applied, idle

    def gen_U(self):
        """Combine the gates' ``U`` attributes into one object for the register.

        Gates are combined with ``^`` in operation order, each untouched
        register position contributes ``identity(dim)``, and the result is
        transposed with the argsort of the resulting index order.
        NOTE(review): ``^`` is presumably a tensor product; confirm against
        the operator type.

        Returns:
            ``None`` for an empty layer, otherwise the transposed result.
        """
        if len(self.operations) == 0:
            return None

        U = None
        for op in self.operations:
            gate_U = op.gate.U
            U = gate_U if U is None else U ^ gate_U

        register, applied, idle = self._mode_layout()
        for mode in idle:
            U = U ^ identity(register.dims[mode])

        permutation = list(argsort(applied + idle))
        return U.transpose(permutation)

    def gen_KM(self):
        """Combine the gates' ``KM`` attributes into one object for the register.

        The first gate's ``KM`` is deep-copied; later ones are merged through
        ``arraytensor``. Each untouched register position contributes
        ``identity(dim)`` via ``^``, and the result is transposed with the
        argsort of the resulting index order.

        Returns:
            An empty list for an empty layer, otherwise the transposed result.
        """
        if len(self.operations) == 0:
            return []

        KM = []
        for op in self.operations:
            if len(KM) == 0:
                KM = deepcopy(op.gate.KM)
            else:
                KM = KM.arraytensor(op.gate.KM)

        register, applied, idle = self._mode_layout()
        for mode in idle:
            KM = KM ^ identity(register.dims[mode])

        permutation = list(argsort(applied + idle))
        return KM.transpose(permutation)

149 

150 

@struct.dataclass
class Circuit:
    """An ordered list of layers acting on a single register."""

    register: Register
    layers: List[Layer] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, register: Register, layers: Optional[List[Layer]] = None):
        """Build a circuit; ``layers`` defaults to a fresh empty list."""
        return Circuit(
            register=register,
            layers=[] if layers is None else layers,
        )

    def append_layer(self, layer: Layer):
        """Append ``layer`` to the end of the circuit, in place."""
        self.layers.append(layer)

    def append_operation(
        self, operation: Operation, default_simulate_mode=SimulateMode.UNITARY
    ):
        """Wrap ``operation`` in its own single-operation layer and append it.

        Raises:
            AssertionError: If the operation's register differs from the
                circuit's register.
        """
        assert operation.register == self.register, (
            f"Mismatch in operation register {operation.register} and circuit register {self.register}."
        )
        single_op_layer = Layer.create(
            [operation], default_simulate_mode=default_simulate_mode
        )
        self.append_layer(single_op_layer)

    def append(
        self,
        gate: Gate,
        indices: Union[int, List[int]],
        default_simulate_mode=SimulateMode.UNITARY,
    ):
        """Bind ``gate`` to ``indices`` on this circuit's register and append it."""
        self.append_operation(
            Operation.create(gate, indices, self.register),
            default_simulate_mode=default_simulate_mode,
        )