Coverage for jaxquantum / circuits / simulate.py: 92%
64 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""Circuit simulation methods."""
3from flax import struct
4from jax import config
5from typing import List
7from jaxquantum.core.qarray import Qarray, ket2dm
8from jaxquantum.circuits.circuits import Circuit, Layer
9from jaxquantum.circuits.constants import SimulateMode
10from jaxquantum.core.solvers import mesolve, sesolve, SolverOptions
# Turn on JAX's 64-bit mode. Note this runs at import time, so importing this
# module changes the JAX configuration as a side effect.
config.update("jax_enable_x64", True)
@struct.dataclass
class Results:
    """Ordered container for the states recorded while simulating a circuit.

    The wrapped list is declared with ``pytree_node=False``. Although the
    dataclass itself is created through ``flax.struct``, the list is mutated
    in place by :meth:`append`.
    """

    # One entry per recorded simulation step.
    results: List[Qarray] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, results: List[Qarray]):
        """Build an instance wrapping ``results``.

        Args:
            results (List[Qarray]): List to wrap. It is stored as passed,
                not copied.

        Returns:
            An instance of ``cls`` holding ``results``.
        """
        # Construct through ``cls`` (not the hard-coded class name) so that a
        # subclass calling ``create`` gets an instance of its own type.
        return cls(results=results)

    def __getitem__(self, j: int):
        """Return the ``j``-th recorded result."""
        return self.results[j]

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return self.__repr__()

    def __repr__(self):
        """Return the string form of the underlying list."""
        return str(self.results)

    def append(self, result: Qarray):
        """Append ``result`` to the underlying list in place."""
        self.results.append(result)

    def __len__(self):
        """Return the number of recorded results."""
        return len(self.results)
def simulate(
    circuit: Circuit, initial_state: Qarray, mode: SimulateMode = SimulateMode.DEFAULT, **kwargs
) -> Results:
    """
    Evolve a quantum state through every layer of a circuit, recording each step.

    Args:
        circuit (Circuit): Circuit whose ``layers`` are applied in order.
        initial_state (Qarray): State fed into the first layer.
        mode (SimulateMode, optional): Simulation mode forwarded to every layer.
            Defaults to SimulateMode.DEFAULT.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            simulation of each layer.

    Returns:
        Results: The initial state (wrapped with ``Qarray.from_list``) followed
            by one entry per layer, in circuit order.
    """
    # The first recorded entry is always the untouched initial state.
    history = Results.create([])
    history.append(Qarray.from_list([initial_state]))

    current_state = initial_state
    elapsed = 0  # start time handed to the next layer

    for layer in circuit.layers:
        outcome = _simulate_layer(
            layer, current_state, mode=mode, start_time=elapsed, **kwargs
        )
        layer_result, elapsed = outcome["result"], outcome["start_time"]
        history.append(layer_result)
        # The last entry of a layer's result is the input to the next layer.
        current_state = layer_result[-1]

    return history
def _simulate_layer(
    layer: Layer, initial_state: Qarray, mode: SimulateMode = SimulateMode.UNITARY, start_time: float = 0, **kwargs
) -> dict:
    """
    Simulates the evolution of a quantum state through a single layer.

    Args:
        layer (Layer): The layer through which the quantum state evolves.
            Depending on the mode, its ``gen_U``, ``gen_KM`` or
            ``gen_Ht``/``gen_c_ops``/``gen_ts`` methods are called.
        initial_state (Qarray): The quantum state entering the layer.
        mode (SimulateMode, optional): One of
            SimulateMode.UNITARY (apply the layer's unitary),
            SimulateMode.HAMILTONIAN (integrate with ``mesolve``/``sesolve``),
            SimulateMode.KRAUS (apply the layer's Kraus map), or
            SimulateMode.DEFAULT (use ``layer._default_simulate_mode``).
            Defaults to SimulateMode.UNITARY.
        start_time (float, optional): Offset added to the layer's time grid in
            HAMILTONIAN mode. Unused by the other modes. Defaults to 0.
        **kwargs: Only consulted in HAMILTONIAN mode: ``solver_options`` is
            extracted (defaulting to options with the progress meter off) and
            the remaining keywords are forwarded to the solver.

    Returns:
        dict: ``{"result": ..., "start_time": ...}``. ``"result"`` holds the
            evolved state(s); ``"start_time"`` is the last time point of this
            layer in HAMILTONIAN mode and the incoming ``start_time`` otherwise.

    Raises:
        ValueError: If the (resolved) mode is not UNITARY, HAMILTONIAN or KRAUS.
    """
    state = initial_state

    # Resolve DEFAULT to whatever the layer itself prefers.
    if mode == SimulateMode.DEFAULT:
        mode = layer._default_simulate_mode

    if mode == SimulateMode.UNITARY:
        U = layer.gen_U()
        # Density matrices are conjugated by U; other states are left-multiplied.
        if state.is_dm():
            state = U @ state @ U.dag()
        else:
            state = U @ state

        result = Qarray.from_list([state])

    elif mode == SimulateMode.HAMILTONIAN:
        solver_options = kwargs.pop("solver_options", SolverOptions.create(progress_meter=False))

        Ht = layer.gen_Ht()
        c_ops = layer.gen_c_ops()
        # Shift the layer's time grid so it starts where the caller says.
        ts = layer.gen_ts() + start_time

        # Use mesolve when the state is a density matrix or collapse operators
        # are present; otherwise sesolve suffices.
        if state.is_dm() or (c_ops is not None and len(c_ops) > 0):
            result = mesolve(Ht, state, ts, c_ops=c_ops, solver_options=solver_options, **kwargs)
        else:
            result = sesolve(Ht, state, ts, solver_options=solver_options, **kwargs)

        # The next layer starts at the final time point of this one.
        start_time = ts[-1]

    elif mode == SimulateMode.KRAUS:
        KM = layer.gen_KM()

        state = ket2dm(state)
        # NOTE(review): collapse() presumably sums the per-operator terms of
        # the Kraus map into one density matrix - confirm against Qarray.
        state = (KM @ state @ KM.dag()).collapse()
        result = Qarray.from_list([state])

    else:
        # Without this guard an unknown mode would surface later as an
        # UnboundLocalError on ``result``, hiding the real cause.
        raise ValueError(f"Unsupported simulation mode: {mode!r}")

    return {
        "result": result,
        "start_time": start_time,
    }