Coverage for jaxquantum/circuits/simulate.py: 92%

65 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Circuit simulation methods.""" 

2 

3from flax import struct 

4from jax import config 

5from typing import List 

6from tqdm import tqdm 

7 

8 

9from jaxquantum.core.qarray import Qarray, ket2dm 

10from jaxquantum.circuits.circuits import Circuit, Layer 

11from jaxquantum.circuits.constants import SimulateMode 

12from jaxquantum.core.solvers import mesolve, sesolve, SolverOptions 

13 

14 

# Turn on 64-bit (double-precision) types in JAX for the whole process.
# NOTE(review): this is a global, import-time side effect — importing this
# module changes JAX's precision setting for every other module as well.
config.update("jax_enable_x64", True)

16 

17 

@struct.dataclass
class Results:
    """Ordered collection of simulation outputs.

    When produced by ``simulate``, the first entry wraps the initial state and
    each following entry holds the output of one circuit layer.

    The ``results`` field is declared with ``pytree_node=False``, so it is kept
    as static metadata rather than treated as a pytree leaf.
    """

    # List of Qarray entries, one per recorded simulation step.
    results: List[Qarray] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, results: List[Qarray]) -> "Results":
        """Build an instance wrapping ``results``.

        ``cls`` is used (rather than a hard-coded class name) so that
        subclasses construct instances of themselves.
        """
        return cls(results=results)

    def __getitem__(self, j: int):
        """Return the entry at index ``j`` using ordinary list indexing."""
        return self.results[j]

    def __str__(self):
        """Same text as ``repr``: the string form of the underlying list."""
        return self.__repr__()

    def __repr__(self):
        """Return the string form of the underlying results list."""
        return str(self.results)

    def append(self, result: Qarray):
        """Append ``result`` to the collection.

        The list object is mutated in place; the dataclass field itself is
        never reassigned.
        """
        self.results.append(result)

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.results)

40 

41 

def simulate(
    circuit: Circuit,
    initial_state: Qarray,
    mode: SimulateMode = SimulateMode.DEFAULT,
    **kwargs,
) -> Results:
    """Simulate the evolution of a quantum state through a quantum circuit.

    The circuit's layers are applied in order. The last state produced by each
    layer becomes the input state of the next layer.

    Args:
        circuit (Circuit): The circuit to simulate; its ``layers`` are iterated
            in order.
        initial_state (Qarray): The initial state, either a state vector or a
            density matrix.
        mode (SimulateMode, optional): How every layer is simulated:
            ``SimulateMode.UNITARY`` applies the layer unitary,
            ``SimulateMode.HAMILTONIAN`` evolves the state under the layer's
            Hamiltonian (and collapse operators, if any) with ``mesolve`` or
            ``sesolve``, ``SimulateMode.KRAUS`` applies the layer's Kraus
            operators to the density matrix of the state, and
            ``SimulateMode.DEFAULT`` uses each layer's own default mode.
            Defaults to ``SimulateMode.DEFAULT``.
        **kwargs: Forwarded to every per-layer simulation. In Hamiltonian mode
            an optional ``solver_options`` (SolverOptions) is passed to the
            solver.

    Returns:
        Results: The initial state (as a one-element Qarray) followed by one
        entry per layer. In Hamiltonian mode a layer's entry contains the
        states at every time point of that layer; in the other modes it
        contains only the layer's output state.
    """
    results = Results.create([])
    state = initial_state
    results.append(Qarray.from_list([state]))

    # Hamiltonian layers share one continuous time axis: each layer starts
    # at the time where the previous layer ended.
    start_time = 0

    for layer in circuit.layers:
        layer_output = _simulate_layer(
            layer, state, mode=mode, start_time=start_time, **kwargs
        )
        result = layer_output["result"]
        start_time = layer_output["start_time"]
        results.append(result)
        # The final state of this layer is the input of the next one.
        state = result[-1]

    return results

76 

77 

def _simulate_layer(
    layer: Layer,
    initial_state: Qarray,
    mode: SimulateMode = SimulateMode.UNITARY,
    start_time: float = 0,
    **kwargs,
) -> dict:
    """Simulate the evolution of a quantum state through a single layer.

    Args:
        layer (Layer): The layer to apply. Depending on ``mode`` it must
            provide ``gen_U`` (unitary), ``gen_Ht``/``gen_c_ops``/``gen_ts``
            (Hamiltonian), or ``gen_KM`` (Kraus), plus
            ``_default_simulate_mode`` when ``mode`` is DEFAULT.
        initial_state (Qarray): The input state, either a state vector or a
            density matrix.
        mode (SimulateMode, optional): ``SimulateMode.UNITARY``,
            ``SimulateMode.HAMILTONIAN``, ``SimulateMode.KRAUS``, or
            ``SimulateMode.DEFAULT`` to use the layer's own default mode.
            Defaults to ``SimulateMode.UNITARY``.
        start_time (float, optional): Offset added to the layer's time grid
            (``layer.gen_ts()``) in Hamiltonian mode; ignored by the other
            modes. Defaults to 0.
        **kwargs: Optional ``solver_options`` (SolverOptions) used in
            Hamiltonian mode. If absent,
            ``SolverOptions.create(progress_meter=False)`` is used.

    Returns:
        dict: ``{"result": ..., "start_time": ...}``.
        - ``result`` is a Qarray holding the layer's output. In Hamiltonian
          mode it holds the solver output at every time point; otherwise it
          is a one-element Qarray.
        - ``start_time`` is where the next layer should start. This is the
          last time point in Hamiltonian mode, and the unchanged input value
          otherwise.

    Raises:
        ValueError: If ``mode`` (after resolving ``SimulateMode.DEFAULT``) is
            not UNITARY, HAMILTONIAN or KRAUS.
    """
    state = initial_state

    if mode == SimulateMode.DEFAULT:
        mode = layer._default_simulate_mode

    if mode == SimulateMode.UNITARY:
        U = layer.gen_U()
        # Density matrices transform as U rho U^dagger, state vectors as U psi.
        if state.is_dm():
            state = U @ state @ U.dag()
        else:
            state = U @ state
        result = Qarray.from_list([state])

    elif mode == SimulateMode.HAMILTONIAN:
        # Only build the default options when the caller supplied none.
        if "solver_options" in kwargs:
            solver_options = kwargs["solver_options"]
        else:
            solver_options = SolverOptions.create(progress_meter=False)

        Ht = layer.gen_Ht()
        c_ops = layer.gen_c_ops()
        # Shift the layer-local time grid onto the circuit's global time axis.
        ts = layer.gen_ts() + start_time

        # Use mesolve for density-matrix input or when collapse operators are
        # present; otherwise sesolve suffices.
        if state.is_dm() or (c_ops is not None and len(c_ops) > 0):
            result = mesolve(Ht, state, ts, c_ops=c_ops, solver_options=solver_options)
        else:
            result = sesolve(Ht, state, ts, solver_options=solver_options)

        start_time = ts[-1]

    elif mode == SimulateMode.KRAUS:
        KM = layer.gen_KM()
        rho = ket2dm(state)
        # NOTE(review): collapse() presumably sums the Kraus-conjugated
        # states over the Kraus batch dimension — confirm against Qarray.
        state = (KM @ rho @ KM.dag()).collapse()
        result = Qarray.from_list([state])

    else:
        # Previously this fell through to an UnboundLocalError on `result`.
        raise ValueError(f"Unsupported simulate mode: {mode!r}")

    return {"result": result, "start_time": start_time}