Coverage for jaxquantum / circuits / simulate.py: 92%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""Circuit simulation methods.""" 

2 

3from flax import struct 

4from jax import config 

5from typing import List 

6 

7from jaxquantum.core.qarray import Qarray, ket2dm 

8from jaxquantum.circuits.circuits import Circuit, Layer 

9from jaxquantum.circuits.constants import SimulateMode 

10from jaxquantum.core.solvers import mesolve, sesolve, SolverOptions 

11 

12 

13config.update("jax_enable_x64", True) 

14 

15 

@struct.dataclass
class Results:
    """Ordered container for the entries recorded while simulating a circuit.

    The entries live in a plain Python list that is declared as a non-pytree
    field (``pytree_node=False``). ``append`` mutates that list in place.
    """

    # Recorded entries, in the order they were appended.
    results: List[Qarray] = struct.field(pytree_node=False)

    @classmethod
    def create(cls, results: List[Qarray]):
        """Build an instance wrapping ``results``.

        Args:
            results (List[Qarray]): Initial list of entries. The list object
                is passed straight to the constructor (no copy is made here).

        Returns:
            A new instance of the class ``create`` was called on.
        """
        # Use ``cls`` rather than the hard-coded class name so that the
        # alternate constructor honours the class it is invoked on.
        return cls(results=results)

    def __getitem__(self, j: int):
        """Return the entry stored at position ``j``."""
        return self.results[j]

    def __str__(self):
        """Return the same text as ``__repr__``."""
        return repr(self)

    def __repr__(self):
        """Return the string form of the underlying list of entries."""
        return str(self.results)

    def append(self, result: Qarray):
        """Append ``result`` to the end of the stored list (in place)."""
        self.results.append(result)

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.results)

38 

39 

def simulate(
    circuit: Circuit, initial_state: Qarray, mode: SimulateMode = SimulateMode.DEFAULT, **kwargs
) -> Results:
    """
    Propagate a quantum state through every layer of a circuit, in order.

    Args:
        circuit (Circuit): The circuit to simulate; its ``layers`` are applied
            one after another.
        initial_state (Qarray): The state fed into the first layer.
        mode (SimulateMode, optional): Simulation mode handed to every layer.
            Defaults to SimulateMode.DEFAULT, in which case each layer's own
            default simulate mode is used.
        **kwargs: Extra keyword arguments forwarded unchanged to the per-layer
            simulation routine.

    Returns:
        Results: The recorded history. Entry 0 wraps ``initial_state``; each
            following entry is the result produced by the corresponding layer.
    """

    current_state = initial_state
    elapsed = 0

    trajectory = Results.create([])
    trajectory.append(Qarray.from_list([current_state]))

    for layer in circuit.layers:
        outcome = _simulate_layer(
            layer, current_state, mode=mode, start_time=elapsed, **kwargs
        )
        layer_states, elapsed = outcome["result"], outcome["start_time"]
        trajectory.append(layer_states)
        # The last state of this layer is the input of the next one.
        current_state = layer_states[-1]

    return trajectory

74 

75 

def _simulate_layer(
    layer: Layer, initial_state: Qarray, mode: SimulateMode = SimulateMode.UNITARY, start_time: float = 0, **kwargs
) -> dict:
    """
    Simulates the evolution of a quantum state through a single layer.

    Args:
        layer (Layer): The layer through which the quantum state evolves.
            Depending on the mode, its ``gen_U``, ``gen_KM`` or
            ``gen_Ht``/``gen_c_ops``/``gen_ts`` generators are used.
        initial_state (Qarray): The quantum state entering the layer.
        mode (SimulateMode, optional): The mode of simulation. One of
            SimulateMode.UNITARY, SimulateMode.HAMILTONIAN,
            SimulateMode.KRAUS, or SimulateMode.DEFAULT to use the layer's
            own default simulate mode. Defaults to SimulateMode.UNITARY.
        start_time (float, optional): Offset added to the layer's time grid
            in HAMILTONIAN mode; returned unchanged by the other modes.
            Defaults to 0.
        **kwargs: Only used in HAMILTONIAN mode. ``solver_options`` is
            extracted (defaulting to
            ``SolverOptions.create(progress_meter=False)``); everything else
            is forwarded to the solver call.

    Returns:
        dict: A dictionary with two keys:
            ``"result"``: the Qarray produced by the layer (the solver output
            in HAMILTONIAN mode, otherwise the single evolved state wrapped
            with ``Qarray.from_list``);
            ``"start_time"``: the start time for the next layer (the last
            time point in HAMILTONIAN mode, otherwise the value passed in).

    Raises:
        ValueError: If the resolved mode is not UNITARY, HAMILTONIAN or KRAUS.
    """

    state = initial_state

    if mode == SimulateMode.DEFAULT:
        mode = layer._default_simulate_mode

    if mode == SimulateMode.UNITARY:
        U = layer.gen_U()
        if state.is_dm():
            state = U @ state @ U.dag()
        else:
            state = U @ state

        result = Qarray.from_list([state])

    elif mode == SimulateMode.HAMILTONIAN:
        solver_options = kwargs.pop("solver_options", SolverOptions.create(progress_meter=False))

        Ht = layer.gen_Ht()
        c_ops = layer.gen_c_ops()
        # Shift the layer's time grid so it starts where the previous layer ended.
        ts = layer.gen_ts() + start_time

        # Use the master-equation solver when the input is a density matrix
        # or collapse operators are present; otherwise the Schrodinger solver.
        if state.is_dm() or (c_ops is not None and len(c_ops) > 0):
            result = mesolve(Ht, state, ts, c_ops=c_ops, solver_options=solver_options, **kwargs)
        else:
            result = sesolve(Ht, state, ts, solver_options=solver_options, **kwargs)

        start_time = ts[-1]

    elif mode == SimulateMode.KRAUS:
        KM = layer.gen_KM()

        state = ket2dm(state)
        state = (KM @ state @ KM.dag()).collapse()
        result = Qarray.from_list([state])

    else:
        # Previously an unhandled mode fell through and failed at the return
        # statement with an UnboundLocalError on ``result``.
        raise ValueError(f"Unsupported simulation mode: {mode!r}")

    return {
        "result": result,
        "start_time": start_time,
    }