Coverage for jaxquantum/circuits/circuits.py: 0%
103 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Circuits.
3Inspired by a mix of Cirq and Qiskit circuits.
4"""
6from flax import struct
7from jax import config
8from typing import List, Optional, Union
9from copy import deepcopy
10from numpy import argsort
12from jaxquantum.core.operators import identity
13from jaxquantum.circuits.gates import Gate
14from jaxquantum.circuits.constants import SimulateMode
# Enable 64-bit (double-precision) mode in JAX. This is a top-level statement,
# so it takes effect as a side effect of importing this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class Register:
    """Description of a register as a list of integer dimensions.

    ``dims`` is declared with ``pytree_node=False``, i.e. it is not a pytree
    node of the dataclass.
    """

    # One integer dimension per register slot; operations index into this list.
    dims: List[int] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, dims: List[int]):
        """Build a ``Register`` holding the given list of dimensions."""
        return Register(dims=dims)

    def __eq__(self, other):
        """Compare by ``dims``; anything that is not a ``Register`` is unequal."""
        return isinstance(other, Register) and self.dims == other.dims
@struct.dataclass
class Operation:
    """A gate bound to specific slots of a register.

    ``indices`` lists the register slots the gate acts on and is declared
    with ``pytree_node=False``.
    """

    gate: Gate
    indices: List[int] = struct.field(pytree_node=False)
    register: Register

    @classmethod
    def create(cls, gate: Gate, indices: Union[int, List[int]], register: Register):
        """Validate and build an ``Operation``.

        Args:
            gate: Gate to apply; its ``num_modes`` and ``dims`` are checked
                against ``indices`` and ``register``.
            indices: A single slot index or a list of slot indices. A bare
                ``int`` is wrapped into a one-element list.
            register: Register the indices refer to.

        Returns:
            The new ``Operation``.

        Raises:
            AssertionError: If the number of indices differs from
                ``gate.num_modes``, or ``gate.dims`` does not equal the
                register dimensions at ``indices``.
            ValueError: If any index is not an ``int`` or lies outside
                ``0 <= index < len(register.dims)``.
        """
        if isinstance(indices, int):
            indices = [indices]

        assert gate.num_modes == len(indices), (
            "Number of indices must match gate's num_modes."
        )

        # Validate type and range BEFORE indexing ``register.dims`` below.
        # The type test comes first so a non-int never reaches the comparisons.
        # Negative indices are rejected: Python would otherwise wrap them
        # around and silently address a different slot.
        num_slots = len(register.dims)
        if any(
            not isinstance(ind, int) or ind < 0 or ind >= num_slots
            for ind in indices
        ):
            raise ValueError("Indices must be integers within the register.")

        assert gate.dims == [register.dims[ind] for ind in indices], (
            "Indices must match register dimensions."
        )

        return Operation(gate=gate, indices=indices, register=register)
@struct.dataclass
class Layer:
    """A collection of operations whose register indices do not overlap.

    All three fields are declared with ``pytree_node=False``.
    """

    operations: List[Operation] = struct.field(pytree_node=False)
    # Register indices already claimed by ``operations``.
    _unique_indices: List[int] = struct.field(pytree_node=False)
    _default_simulate_mode: SimulateMode = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls, operations: List[Operation], default_simulate_mode=SimulateMode.UNITARY
    ):
        """Build a layer from ``operations``.

        Raises:
            ValueError: If two operations share a register index (or one
                operation repeats an index).
        """
        touched = []
        for op in operations:
            touched.extend(op.indices)

        # A size mismatch after de-duplication means some index repeats.
        distinct = list(set(touched))
        if len(distinct) != len(touched):
            raise ValueError("Operations must not have overlapping indices.")

        return Layer(
            operations=operations,
            _unique_indices=distinct,
            _default_simulate_mode=default_simulate_mode,
        )

    def add(self, operation: Operation):
        """Append ``operation`` in place.

        Raises:
            ValueError: If any of its indices is already used in this layer.
        """
        for ind in operation.indices:
            if ind in self._unique_indices:
                raise ValueError("Operations must not have overlapping indices.")

        self.operations.append(operation)
        self._unique_indices.extend(operation.indices)

    def gen_U(self):
        """Combine the gates' ``U`` into one operator over the whole register.

        Returns ``None`` for an empty layer. Otherwise the gates' ``U`` are
        joined with ``^`` in operation order, ``identity`` factors are added
        for register slots no operation touches, and the result is transposed
        so that the factors follow ascending register index.
        """
        if not self.operations:
            return None

        unitary = None
        placed = []  # register index of each factor, in the order joined
        for op in self.operations:
            placed.extend(op.indices)
            gate_U = op.gate.U
            unitary = gate_U if unitary is None else unitary ^ gate_U

        # The register of the first operation is taken as the layer's register.
        reg = self.operations[0].register
        untouched = [slot for slot in range(len(reg.dims)) if slot not in placed]
        for slot in untouched:
            unitary = unitary ^ identity(reg.dims[slot])

        # argsort of the current factor order gives the permutation that
        # restores ascending register order.
        permutation = list(argsort(placed + untouched))
        return unitary.transpose(permutation)

    def gen_KM(self):
        """Combine the gates' ``KM`` into one object over the whole register.

        Returns an empty list for an empty layer. Mirrors ``gen_U``, except
        that the gates' ``KM`` are merged with ``arraytensor`` instead of
        ``^``.
        """
        if not self.operations:
            return []

        kraus = []
        placed = []  # register index of each factor, in the order joined
        for op in self.operations:
            placed.extend(op.indices)
            # NOTE(review): ``arraytensor`` presumably forms every pairwise
            # product ``op1 ^ op2`` between the accumulated operators and this
            # gate's operators (it replaced an explicit double loop doing
            # exactly that) -- confirm against the array type's API.
            kraus = (
                deepcopy(op.gate.KM)
                if len(kraus) == 0
                else kraus.arraytensor(op.gate.KM)
            )

        # The register of the first operation is taken as the layer's register.
        reg = self.operations[0].register
        untouched = [slot for slot in range(len(reg.dims)) if slot not in placed]
        for slot in untouched:
            kraus = kraus ^ identity(reg.dims[slot])

        # argsort of the current factor order gives the permutation that
        # restores ascending register order.
        permutation = list(argsort(placed + untouched))
        return kraus.transpose(permutation)
@struct.dataclass
class Circuit:
    """An ordered list of layers acting on one register.

    ``layers`` is declared with ``pytree_node=False``.
    """

    register: Register
    layers: List[Layer] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, register: Register, layers: Optional[List[Layer]] = None):
        """Build a circuit on ``register``; ``layers`` defaults to a new empty list."""
        return Circuit(register=register, layers=[] if layers is None else layers)

    def append_layer(self, layer: Layer):
        """Append ``layer`` to the end of the circuit, in place."""
        self.layers.append(layer)

    def append_operation(
        self, operation: Operation, default_simulate_mode=SimulateMode.UNITARY
    ):
        """Append ``operation`` as a new layer containing only that operation.

        Raises:
            AssertionError: If the operation's register differs from the
                circuit's register.
        """
        assert operation.register == self.register, (
            f"Mismatch in operation register {operation.register} and circuit register {self.register}."
        )
        single_op_layer = Layer.create(
            [operation], default_simulate_mode=default_simulate_mode
        )
        self.append_layer(single_op_layer)

    def append(
        self,
        gate: Gate,
        indices: Union[int, List[int]],
        default_simulate_mode=SimulateMode.UNITARY,
    ):
        """Bind ``gate`` to ``indices`` on this circuit's register and append it."""
        self.append_operation(
            Operation.create(gate, indices, self.register),
            default_simulate_mode=default_simulate_mode,
        )