Coverage for jaxquantum/circuits/circuits.py: 84%
152 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
"""Circuits.

Inspired by a mix of Cirq and Qiskit circuits.
"""
6from flax import struct
7from jax import config
8from typing import List, Optional, Union
9from copy import deepcopy
10from numpy import argsort
11import jax.numpy as jnp
13from jaxquantum.core.operators import identity
14from jaxquantum.circuits.gates import Gate
15from jaxquantum.circuits.constants import SimulateMode
16from jaxquantum.core.qarray import Qarray, concatenate
# Enable 64-bit (double-precision) dtypes in JAX globally. This runs as a
# side effect of importing this module and affects all JAX code in the process.
config.update("jax_enable_x64", True)
@struct.dataclass
class Register:
    """Description of the register a circuit acts on.

    ``dims[i]`` is the dimension of mode ``i``. Operations use it to check that
    a gate's dimensions match the modes it is applied to, and to pad the
    remaining modes with identities.
    """

    # Per-mode dimensions; static metadata (not a pytree leaf).
    dims: List[int] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, dims: List[int]):
        """Build a register from a sequence of per-mode dimensions.

        The sequence is copied into a fresh ``list``. Registers built from a
        tuple and from an equal list therefore compare equal, and later mutation
        of the caller's sequence cannot change this register.

        Args:
            dims: Dimension of each mode, in mode order.

        Returns:
            A new register instance (of ``cls``, so subclasses are preserved).
        """
        return cls(dims=list(dims))

    def __eq__(self, other):
        """Two registers are equal when their per-mode dimensions match."""
        if isinstance(other, Register):
            # Compare as lists so list- and tuple-valued ``dims`` (e.g. from
            # direct construction bypassing ``create``) are treated alike.
            return list(self.dims) == list(other.dims)
        return False
@struct.dataclass
class Operation:
    """A gate applied to specific modes of a register."""

    gate: Gate
    # Register modes the gate acts on, listed in the gate's own subsystem order.
    indices: List[int] = struct.field(pytree_node=False)
    register: Register

    @classmethod
    def create(cls, gate: Gate, indices: Union[int, List[int]], register: Register):
        """Validate and build an operation.

        Args:
            gate: Gate to apply. Its ``num_modes`` and ``dims`` are checked
                against ``indices`` and ``register``.
            indices: A single mode index or a sequence of mode indices.
            register: Register the operation acts on.

        Returns:
            A new operation with ``indices`` stored as a list.

        Raises:
            ValueError: If an index is not an ``int``, lies outside
                ``[0, len(register.dims))``, or is repeated.
            AssertionError: If the number of indices differs from
                ``gate.num_modes``, or the selected register dimensions differ
                from ``gate.dims``.
        """
        if isinstance(indices, int):
            indices = [indices]
        else:
            # Normalize to a list: ``promote`` concatenates it with a list.
            indices = list(indices)

        # Validate before indexing ``register.dims``. Out-of-range indices then
        # raise ValueError rather than IndexError, and negative indices are not
        # silently wrapped by Python's negative indexing.
        num_register_modes = len(register.dims)
        if any(
            not isinstance(ind, int) or not 0 <= ind < num_register_modes
            for ind in indices
        ):
            raise ValueError("Indices must be integers within the register.")
        if len(set(indices)) != len(indices):
            raise ValueError("Indices must not contain duplicates.")

        assert gate.num_modes == len(indices), (
            "Number of indices must match gate's num_modes."
        )
        assert gate.dims == [register.dims[ind] for ind in indices], (
            "Indices must match register dimensions."
        )

        return cls(gate=gate, indices=indices, register=register)

    def promote(self, op: Qarray) -> Qarray:
        """Embed ``op`` into the full register space.

        ``op`` is expected to act on ``self.indices``, in that order. It is
        combined via ``^`` with ``identity(dim)`` for every register mode not
        in ``self.indices``. The subsystems are then reordered by
        ``transpose`` into register order (mode 0, 1, ...).

        NOTE(review): assumes ``^`` is Qarray's tensor product and that
        ``transpose(perm)`` permutes subsystems. This is inferred from usage;
        confirm against Qarray.

        Args:
            op: Operator (or batch of operators) on the operation's modes.

        Returns:
            The operator acting on the whole register.
        """
        missing_indices = [
            i for i in range(len(self.register.dims)) if i not in self.indices
        ]
        for j in missing_indices:
            op = op ^ identity(self.register.dims[j])
        # Current subsystem order is indices followed by padded modes; argsort
        # of that order gives the permutation back to register order.
        combined_indices = list(self.indices) + missing_indices
        sorted_ind = list(argsort(combined_indices))
        return op.transpose(sorted_ind)
@struct.dataclass
class Layer:
    """A group of operations that form a single step of a circuit."""

    operations: List[Operation] = struct.field(pytree_node=False)
    # Register modes already used by ``operations``, each listed once.
    _unique_indices: List[int] = struct.field(pytree_node=False)
    _default_simulate_mode: SimulateMode = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls, operations: List[Operation], default_simulate_mode=SimulateMode.UNITARY
    ):
        """Build a layer from a list of operations.

        Unless ``default_simulate_mode`` is ``SimulateMode.HAMILTONIAN``, no
        two operations may share a register mode.

        The list is copied, so later calls to ``add`` do not mutate the
        caller's list.

        Raises:
            ValueError: If operations overlap in a non-HAMILTONIAN layer.
        """
        operations = list(operations)
        all_indices = [ind for op in operations for ind in op.indices]
        unique_indices = list(set(all_indices))

        if default_simulate_mode != SimulateMode.HAMILTONIAN:
            if len(all_indices) != len(unique_indices):
                raise ValueError("Operations must not have overlapping indices.")

        return cls(
            operations=operations,
            _unique_indices=unique_indices,
            _default_simulate_mode=default_simulate_mode,
        )

    def add(self, operation: Operation):
        """Append ``operation`` to this layer in place.

        Raises:
            ValueError: If the layer is not in HAMILTONIAN mode and
                ``operation`` touches a mode already used in this layer.
        """
        if self._default_simulate_mode != SimulateMode.HAMILTONIAN:
            if any(ind in self._unique_indices for ind in operation.indices):
                raise ValueError("Operations must not have overlapping indices.")
        self.operations.append(operation)
        # Keep ``_unique_indices`` duplicate-free even in HAMILTONIAN mode,
        # where overlapping operations are allowed.
        for ind in operation.indices:
            if ind not in self._unique_indices:
                self._unique_indices.append(ind)

    def _embed(self, op: Qarray, indices_order: List[int]) -> Qarray:
        """Pad ``op`` with identities on unused modes and reorder to register order.

        ``op`` is expected to act on the modes in ``indices_order``, in that
        order. Requires at least one operation (the register is taken from the
        first one).
        """
        register = self.operations[0].register
        missing_indices = [
            i for i in range(len(register.dims)) if i not in indices_order
        ]
        for j in missing_indices:
            op = op ^ identity(register.dims[j])
        # argsort of the current subsystem order gives the permutation back to
        # register order.
        sorted_ind = list(argsort(indices_order + missing_indices))
        return op.transpose(sorted_ind)

    def gen_U(self):
        """Return the layer's combined ``U`` on the full register, or ``None`` if empty.

        The gates' ``U`` values are combined with ``^`` in operation order.
        Unused modes are then padded with identities.
        """
        if len(self.operations) == 0:
            return None

        indices_order = []
        U = None
        for operation in self.operations:
            indices_order += operation.indices
            U = operation.gate.U if U is None else U ^ operation.gate.U

        return self._embed(U, indices_order)

    def gen_Ht(self):
        """Return ``Ht(t)``: the sum of every operation's promoted ``gate.Ht(t)``.

        The operation list is snapshotted now, so operations added to the
        layer afterwards are not included. For an empty layer the returned
        function yields ``0``.
        """
        operations = tuple(self.operations)

        def Ht(t):
            total = 0
            for operation in operations:
                total = total + operation.promote(operation.gate.Ht(t))
            return total

        return Ht

    def gen_KM(self):
        """Return the layer's combined ``KM`` on the full register.

        Operations whose gate has an empty ``KM`` are skipped. If no operation
        contributes, an empty ``Qarray`` is returned.
        """
        KM = Qarray.from_list([])
        if len(self.operations) == 0:
            return KM

        indices_order = []
        for operation in self.operations:
            if len(operation.gate.KM) == 0:
                continue

            indices_order += operation.indices
            if len(KM) == 0:
                # Copy so the gate's own KM is never aliased by the result.
                KM = deepcopy(operation.gate.KM)
            else:
                KM = KM ^ operation.gate.KM

        if len(KM) == 0:
            return KM

        return self._embed(KM, indices_order)

    def gen_c_ops(self):
        """Concatenate every operation's promoted ``gate.c_ops``.

        Operations with empty ``c_ops`` are skipped; an empty ``Qarray`` is
        returned if none contribute.
        """
        c_ops = Qarray.from_list([])

        for operation in self.operations:
            if len(operation.gate.c_ops) == 0:
                continue
            promoted_c_ops = operation.promote(operation.gate.c_ops)
            c_ops = concatenate([c_ops, promoted_c_ops])

        return c_ops

    def gen_ts(self):
        """Return the time steps shared by the operations that specify them.

        Returns ``None`` when no operation specifies non-empty ``ts``.

        Raises:
            AssertionError: If two operations specify different ``ts``.
        """
        ts = None

        for operation in self.operations:
            if operation.gate.ts is not None and len(operation.gate.ts) > 0:
                if ts is None:
                    ts = operation.gate.ts
                else:
                    assert jnp.array_equal(ts, operation.gate.ts), (
                        "All operations in a layer must have the same specified time steps, but not all operations need to have time steps."
                    )
        return ts
@struct.dataclass
class Circuit:
    """An ordered sequence of layers acting on a register."""

    register: Register
    layers: List[Layer] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, register: Register, layers: Optional[List[Layer]] = None):
        """Build a circuit on ``register`` with optional initial ``layers``.

        ``layers`` is copied, so circuits created from the same list do not
        share later ``append_layer`` mutations.
        """
        layers = [] if layers is None else list(layers)
        return cls(register=register, layers=layers)

    def append_layer(self, layer: Layer):
        """Append ``layer`` to the end of the circuit, in place."""
        self.layers.append(layer)

    def append_operation(
        self,
        operation: Operation,
        default_simulate_mode: Optional[SimulateMode] = None,
        new_layer: bool = True,
    ):
        """Append ``operation`` either as a new layer or to the last layer.

        Args:
            operation: Operation to append; must use this circuit's register.
            default_simulate_mode: Mode for a newly created layer (defaults to
                ``SimulateMode.UNITARY``). When extending the last layer, it must
                match that layer's mode if given.
            new_layer: If ``True`` (or the circuit has no layers yet), start a
                new layer; otherwise add to the last layer.

        Raises:
            AssertionError: On register mismatch, or on a mode mismatch with the
                last layer.
            ValueError: If the operation overlaps the last layer's operations in
                a non-HAMILTONIAN layer (raised by ``Layer.add``).
        """
        assert operation.register == self.register, (
            f"Mismatch in operation register {operation.register} and circuit register {self.register}."
        )

        if new_layer or len(self.layers) == 0:
            mode = (
                default_simulate_mode
                if default_simulate_mode is not None
                else SimulateMode.UNITARY
            )
            self.append_layer(Layer.create([operation], default_simulate_mode=mode))
        else:
            if default_simulate_mode is not None:
                assert (
                    self.layers[-1]._default_simulate_mode == default_simulate_mode
                ), "Cannot append operation to last layer with different default simulate mode."

            self.layers[-1].add(operation)

    def append(
        self,
        gate: Gate,
        indices: Union[int, List[int]],
        default_simulate_mode: Optional[SimulateMode] = None,
        new_layer: bool = True,
    ):
        """Create an operation for ``gate`` on ``indices`` and append it.

        See ``append_operation`` for the meaning of ``default_simulate_mode``
        and ``new_layer``.
        """
        operation = Operation.create(gate, indices, self.register)
        self.append_operation(
            operation, default_simulate_mode=default_simulate_mode, new_layer=new_layer
        )