Coverage for jaxquantum/circuits/simulate.py: 92%
65 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Circuit simulation methods."""
3from flax import struct
4from jax import config
5from typing import List
6from tqdm import tqdm
9from jaxquantum.core.qarray import Qarray, ket2dm
10from jaxquantum.circuits.circuits import Circuit, Layer
11from jaxquantum.circuits.constants import SimulateMode
12from jaxquantum.core.solvers import mesolve, sesolve, SolverOptions
# Enable 64-bit (double-precision) dtypes in JAX; by default JAX uses 32-bit
# floating/complex types. This is a process-wide setting, so importing this
# module changes JAX's default precision for the whole program.
config.update("jax_enable_x64", True)
@struct.dataclass
class Results:
    """Ordered, appendable collection of ``Qarray`` simulation results.

    ``simulate`` in this module fills it with the initial state first and
    then one entry per circuit layer. The list field is declared with
    ``pytree_node=False``, so flax/JAX treat it as static metadata rather
    than as pytree leaves.
    """

    # Recorded results, in the order they were appended.
    results: List[Qarray] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, results: List[Qarray]) -> "Results":
        """Build an instance wrapping ``results``.

        The list is stored as-is (not copied), so later ``append`` calls
        also mutate the list object the caller passed in.
        """
        # Use ``cls`` rather than the hard-coded class name so that
        # subclasses constructed through ``create`` keep their own type.
        return cls(results=results)

    def __getitem__(self, j: int):
        """Return the ``j``-th recorded result (any list index is accepted)."""
        return self.results[j]

    def __str__(self):
        """Same text as ``repr``."""
        return self.__repr__()

    def __repr__(self):
        """Show the underlying list of results."""
        return str(self.results)

    def append(self, result: Qarray):
        """Append ``result`` to the stored list in place.

        The dataclass is frozen, so the ``results`` field cannot be rebound,
        but the list it points to is still mutable.
        """
        self.results.append(result)

    def __len__(self):
        """Number of recorded results."""
        return len(self.results)
def simulate(
    circuit: Circuit,
    initial_state: Qarray,
    mode: SimulateMode = SimulateMode.DEFAULT,
    **kwargs,
) -> Results:
    """Simulate a quantum state evolving through every layer of a circuit.

    Args:
        circuit (Circuit): The circuit to simulate. Its ``layers`` are
            applied in order, each via ``_simulate_layer``.
        initial_state (Qarray): The state to evolve; a state vector or a
            density matrix.
        mode (SimulateMode, optional): How each layer is simulated.
            - ``SimulateMode.UNITARY``: apply the layer's unitary (``gen_U``).
            - ``SimulateMode.KRAUS``: apply the layer's Kraus operators
              (``gen_KM``).
            - ``SimulateMode.HAMILTONIAN``: evolve under the layer's
              Hamiltonian (``gen_Ht``) with ``sesolve``/``mesolve``.
            - ``SimulateMode.DEFAULT``: use each layer's own default mode.
            Defaults to ``SimulateMode.DEFAULT``.
        **kwargs: Forwarded unchanged to ``_simulate_layer`` for every layer.
            That function reads ``solver_options`` (a ``SolverOptions``)
            in HAMILTONIAN mode.

    Returns:
        Results: ``len(circuit.layers) + 1`` entries.
            - Entry 0 holds the initial state.
            - Entry ``k`` holds what layer ``k - 1`` produced: a single
              state for UNITARY/KRAUS, or the solver output over the layer's
              time grid for HAMILTONIAN.
    """
    results = Results.create([])
    state = initial_state
    results.append(Qarray.from_list([state]))

    # Running clock: in HAMILTONIAN mode each layer's time grid is shifted by
    # this offset, so consecutive layers follow one another in time instead of
    # each restarting from its own origin.
    start_time = 0

    for layer in circuit.layers:
        layer_output = _simulate_layer(
            layer, state, mode=mode, start_time=start_time, **kwargs
        )
        result = layer_output["result"]
        start_time = layer_output["start_time"]
        results.append(result)
        # The last state produced by this layer seeds the next layer.
        state = result[-1]

    return results
def _simulate_layer(
    layer: Layer,
    initial_state: Qarray,
    mode: SimulateMode = SimulateMode.UNITARY,
    start_time: float = 0,
    **kwargs,
) -> dict:
    """Simulate a quantum state evolving through a single circuit layer.

    Args:
        layer (Layer): The layer to apply. Which methods it needs depends on
            ``mode``:
            - ``gen_U`` for UNITARY.
            - ``gen_KM`` for KRAUS.
            - ``gen_Ht``, ``gen_c_ops`` and ``gen_ts`` for HAMILTONIAN.
            With ``SimulateMode.DEFAULT``, the layer's
            ``_default_simulate_mode`` is used instead.
        initial_state (Qarray): State vector or density matrix to evolve.
        mode (SimulateMode, optional): ``SimulateMode.UNITARY``,
            ``SimulateMode.KRAUS``, ``SimulateMode.HAMILTONIAN`` or
            ``SimulateMode.DEFAULT``. Defaults to ``SimulateMode.UNITARY``.
        start_time (float, optional): Offset added to the layer's time grid
            in HAMILTONIAN mode, so that layers can be chained in time.
            Unused by the other modes. Defaults to 0.
        **kwargs: In HAMILTONIAN mode, ``solver_options`` (a
            ``SolverOptions``) is passed to the solver. When it is absent or
            ``None``, ``SolverOptions.create(progress_meter=False)`` is used.

    Returns:
        dict: ``{"result": ..., "start_time": ...}``.
            - ``result`` is a ``Qarray`` of evolved state(s): one state for
              UNITARY/KRAUS, or the solver output over the shifted time grid
              for HAMILTONIAN.
            - ``start_time`` is the last point of that shifted grid in
              HAMILTONIAN mode, and the input ``start_time`` otherwise.

    Raises:
        ValueError: If the mode (after resolving ``DEFAULT``) is not one of
            UNITARY, HAMILTONIAN or KRAUS.
    """
    if mode == SimulateMode.DEFAULT:
        mode = layer._default_simulate_mode

    if mode == SimulateMode.UNITARY:
        U = layer.gen_U()
        if initial_state.is_dm():
            # Density matrices transform as U rho U^dagger.
            state = U @ initial_state @ U.dag()
        else:
            state = U @ initial_state
        result = Qarray.from_list([state])

    elif mode == SimulateMode.HAMILTONIAN:
        solver_options = kwargs.get("solver_options")
        if solver_options is None:
            # Build the default only when the caller did not supply options.
            solver_options = SolverOptions.create(progress_meter=False)

        Ht = layer.gen_Ht()
        c_ops = layer.gen_c_ops()
        # Shift the layer's own time grid so it continues from start_time.
        ts = layer.gen_ts() + start_time

        # A density-matrix input, or any collapse operators, requires the
        # master-equation solver; otherwise the Schrödinger solver is used.
        if initial_state.is_dm() or (c_ops is not None and len(c_ops) > 0):
            result = mesolve(
                Ht, initial_state, ts, c_ops=c_ops, solver_options=solver_options
            )
        else:
            result = sesolve(Ht, initial_state, ts, solver_options=solver_options)
        start_time = ts[-1]

    elif mode == SimulateMode.KRAUS:
        KM = layer.gen_KM()
        # Kraus evolution is applied to a density matrix.
        rho = ket2dm(initial_state)
        # NOTE(review): presumably KM is a batch of Kraus operators and
        # collapse() sums the per-operator K rho K^dagger terms -- confirm.
        state = (KM @ rho @ KM.dag()).collapse()
        result = Qarray.from_list([state])

    else:
        # Previously this fell through with `result` unbound, raising a
        # confusing UnboundLocalError at the return statement.
        raise ValueError(f"Unsupported simulate mode: {mode!r}")

    return {"result": result, "start_time": start_time}