Coverage for jaxquantum/core/solvers.py: 100%

95 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Solvers""" 

2 

3from diffrax import ( 

4 diffeqsolve, 

5 ODETerm, 

6 SaveAt, 

7 PIDController, 

8 TqdmProgressMeter, 

9 NoProgressMeter, 

10) 

11from flax import struct 

12from jax import Array 

13from typing import Callable, Optional, Union 

14import diffrax 

15import jax.numpy as jnp 

16import warnings 

17import tqdm 

18import logging 

19 

20 

21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data 

22from jaxquantum.core.conversions import jnp2jqt 

23 

24# ---- 

25 

26 

@struct.dataclass
class SolverOptions:
    """Configuration shared by the diffrax-based solvers in this module.

    Every field is declared with ``pytree_node=False`` so it is carried as
    static metadata rather than as an array leaf when a ``SolverOptions``
    instance passes through JAX transformations. That matters for all three:
    ``solver`` is a string, and ``max_steps`` is forwarded to
    ``diffrax.diffeqsolve`` as a plain integer.

    Attributes:
        progress_meter: If true, ``solve`` shows a tqdm-based progress bar;
            otherwise no progress meter is used.
        solver: Name of a solver class exported by ``diffrax``
            (e.g. ``"Tsit5"``); it is instantiated with no arguments.
        max_steps: Maximum number of integration steps passed to
            ``diffrax.diffeqsolve``.
    """

    progress_meter: bool = struct.field(pytree_node=False, default=True)
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    max_steps: int = struct.field(pytree_node=False, default=100_000)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
    ) -> "SolverOptions":
        """Build a ``SolverOptions`` from the given (or default) settings.

        Args:
            progress_meter: Whether to display a progress bar.
            solver: Name of the diffrax solver class to use.
            max_steps: Maximum number of integration steps.

        Returns:
            A new ``SolverOptions`` instance.
        """
        return cls(progress_meter, solver, max_steps)

41 

42 

class CustomProgressMeter(TqdmProgressMeter):
    """diffrax ``TqdmProgressMeter`` with a restyled progress bar.

    Only ``_init_bar`` is overridden here; all other behaviour comes from the
    parent ``TqdmProgressMeter``.
    """

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the tqdm bar used to display integration progress.

        The bar runs from 0 to 100 in ``%`` units, is drawn in magenta, and
        uses shaded block characters for the fill.
        """
        fmt = (
            "{desc}: {percentage:3.0f}% |{bar}| "
            "[{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        )
        return tqdm.tqdm(
            bar_format=fmt,
            total=100,
            ascii="░▒█",
            colour="MAGENTA",
            unit="%",
        )

50 

51 

def solve(f, ρ0, tlist, args, solver_options: Optional[SolverOptions] = None):
    """Integrate ``dy/dt = f(t, y, args)`` with diffrax, saving at ``tlist``.

    Args:
        f: Vector field with signature ``f(t, y, args)``; it is wrapped in a
            ``diffrax.ODETerm``.
        ρ0: Initial value of the state at ``tlist[0]``.
        tlist: Time points at which the solution is saved. Integration runs
            from ``tlist[0]`` to ``tlist[-1]`` and the initial step size is
            ``tlist[1] - tlist[0]``, so at least two points are required.
        args: Extra arguments forwarded to ``f`` on every evaluation.
        solver_options: Solver configuration; ``SolverOptions.create()``
            defaults are used when omitted.

    Returns:
        The solution object returned by ``diffrax.diffeqsolve``; the saved
        states are available as ``sol.ys``.

    Raises:
        ValueError: If ``tlist`` contains fewer than two time points.
    """
    # dt0 below needs tlist[1]; fail with a clear message instead of an
    # opaque IndexError.
    if len(tlist) < 2:
        raise ValueError(
            f"`tlist` must contain at least two time points, got {len(tlist)}."
        )

    # f and ts
    term = ODETerm(f)
    saveat = SaveAt(ts=tlist)

    # solver
    if solver_options is None:
        solver_options = SolverOptions.create()

    # Look the solver class up by name on the diffrax module, e.g. "Tsit5".
    solver = getattr(diffrax, solver_options.solver)()
    stepsize_controller = PIDController(rtol=1e-6, atol=1e-6)

    progress_meter = (
        CustomProgressMeter() if solver_options.progress_meter else NoProgressMeter()
    )

    # solve!
    with warnings.catch_warnings():
        # diffrax emits a UserWarning for complex-valued states. Silence only
        # that warning so unrelated UserWarnings still reach the user.
        warnings.filterwarnings(
            "ignore", message="Complex dtype support", category=UserWarning
        )
        sol = diffeqsolve(
            term,
            solver,
            t0=tlist[0],
            t1=tlist[-1],
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            args=args,
            max_steps=solver_options.max_steps,
            progress_meter=progress_meter,
        )

    return sol

95 

96 

def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum master equation (Lindblad) solver.

    Args:
        H: Time-independent Hamiltonian as a Qarray, or a function mapping a
            time ``t`` to the Hamiltonian Qarray at that time.
        rho0: Initial state. It is converted with ``to_dm()`` before
            evolution, so a non-operator state is accepted. If it is not a
            density matrix and there are no collapse operators, a warning
            suggests using ``sesolve`` instead.
        tlist: Time points at which the state is saved.
        c_ops: Qarray list of collapse operators; ``None`` means none.
        solver_options: SolverOptions with solver options.

    Returns:
        Qarray holding the state at each time in ``tlist``.

    Raises:
        TypeError: If ``H`` is neither a Qarray nor a callable.
    """
    c_ops = c_ops if c_ops is not None else Qarray.from_list([])

    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    ρ0 = rho0.to_dm()
    dims = ρ0.dims

    # The array-level solver expects H as a function of time returning raw
    # array data.
    if isinstance(H, Qarray):

        def Ht_data(t):
            return H.data

    elif callable(H):

        def Ht_data(t):
            return H(t).data

    else:
        raise TypeError(
            "H must be a Qarray or a callable returning a Qarray, "
            f"got {type(H).__name__}."
        )

    ys = _mesolve_data(
        Ht_data, ρ0.data, tlist, c_ops.data, solver_options=solver_options
    )

    return jnp2jqt(ys, dims=dims)

140 

141 

def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Array-level Lindblad master-equation solver used by ``mesolve``.

    Args:
        H: Function ``t -> Array`` returning the Hamiltonian data at time t.
        rho0: Initial density-matrix data.
        tlist: Time points at which the state is saved.
        c_ops: Stacked collapse-operator data. The leading axis indexes the
            individual operators, whose dissipator terms are summed over
            axis 0. ``None`` or an empty array means no dissipation.
        solver_options: Solver configuration forwarded to ``solve``.

    Returns:
        The saved states (``sol.ys``), one entry per time in ``tlist``.
    """
    c_ops = c_ops if c_ops is not None else jnp.array([])

    ρ0 = rho0 + 0.0j

    # Probe the output shape of the right-hand side so any batch dimensions
    # carried by H or the collapse operators are applied to the initial state.
    if len(c_ops) == 0:
        test_data = H(0.0) @ ρ0
    else:
        test_data = c_ops[0] @ H(0.0) @ ρ0

    # Use broadcasting rather than jnp.resize. Resize tiles the flattened data
    # and misaligns batched states when a new batch axis is inserted between
    # existing ones.
    ρ0 = jnp.broadcast_to(ρ0, test_data.shape)

    def f(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        H_val = H(t) + 0.0j  # type: ignore

        # Unitary part: -i[H, rho].
        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        c_ops_val_dag = dag_data(c_ops_val)

        # Dissipator L rho L† - 1/2 {L†L, rho} for every collapse operator,
        # summed over the operator axis.
        rho_dot_delta = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag
            - rho @ c_ops_val_dag @ c_ops_val
            - c_ops_val_dag @ c_ops_val @ rho
        )

        return rho_dot + jnp.sum(rho_dot_delta, axis=0)

    sol = solve(f, ρ0, tlist, c_ops, solver_options=solver_options)

    return sol.ys

213 

214 

def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger equation solver.

    Args:
        H: Time-independent Hamiltonian as a Qarray, or a function mapping a
            time ``t`` to the Hamiltonian Qarray at that time.
        rho0: Initial state vector. It must not be a density matrix (use
            ``mesolve`` for those); it is converted with ``to_ket()`` before
            evolution.
        tlist: Time points at which the state is saved.
        solver_options: SolverOptions with solver options.

    Returns:
        Qarray holding the state at each time in ``tlist``.

    Raises:
        ValueError: If ``rho0`` is an operator (density matrix).
        TypeError: If ``H`` is neither a Qarray nor a callable.
    """
    ψ = rho0

    if ψ.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ψ = ψ.to_ket()
    dims = ψ.dims

    # The array-level solver expects H as a function of time returning raw
    # array data.
    if isinstance(H, Qarray):

        def Ht_data(t):
            return H.data

    elif callable(H):

        def Ht_data(t):
            return H(t).data

    else:
        raise TypeError(
            "H must be a Qarray or a callable returning a Qarray, "
            f"got {type(H).__name__}."
        )

    ys = _sesolve_data(Ht_data, ψ.data, tlist, solver_options=solver_options)

    return jnp2jqt(ys, dims=dims)

252 

253 

def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Array-level Schrödinger-equation solver used by ``sesolve``.

    Args:
        H: Function ``t -> Array`` returning the Hamiltonian data at time t.
        rho0: Initial state-vector data.
        tlist: Time points at which the state is saved.
        solver_options: Solver configuration forwarded to ``solve``.

    Returns:
        The saved states (``sol.ys``), one entry per time in ``tlist``.
    """
    ψ = rho0 + 0.0j

    def f(t: float, ψₜ: Array, _):
        H_val = H(t) + 0.0j  # type: ignore
        # dψ/dt = -i H ψ
        return -1j * (H_val @ ψₜ)

    # Probe the output shape so batch dimensions carried by H are applied to
    # the initial state. Broadcasting (unlike jnp.resize, which tiles the
    # flattened data) keeps each batch element aligned with its own state.
    ψ_test = f(0.0, ψ, None)
    ψ = jnp.broadcast_to(ψ, ψ_test.shape)

    sol = solve(f, ψ, tlist, None, solver_options=solver_options)
    return sol.ys

288 

289 # ---- 

290 

291 # propagators 

292 # ---- 

293 

294 # def propagator( 

295 # H: Union[Qarray, Callable[[float], Qarray]], 

296 # t: Union[float, Array], 

297 # solver_options=None 

298 # ): 

299 # """ Generate the propagator for a time dependent Hamiltonian. 

300 

301 # Args: 

302 # H (Qarray or callable): 

303 # A Qarray static Hamiltonian OR 

304 # a function that takes a time argument and returns a Hamiltonian. 

305 # ts (float or Array): 

306 # A single time point or 

307 # an Array of time points. 

308 

309 # Returns: 

310 # Qarray or List[Qarray]: 

311 # The propagator for the Hamiltonian at time t. 

312 # OR a list of propagators for the Hamiltonian at each time in t. 

313 

314 # """ 

315 

316 # t_is_scalar = robust_isscalar(t) 

317 

318 # if isinstance(H, Qarray): 

319 # dims = H.dims 

320 # if t_is_scalar: 

321 # if t == 0: 

322 # return identity_like(H) 

323 

324 # return jnp2jqt(propagator_0_data(H.data,t), dims=dims) 

325 # else: 

326 # f = lambda t: propagator_0_data(H.data,t) 

327 # return jnp2jqt(vmap(f)(t), dims) 

328 # else: 

329 # dims = H(0.0).dims 

330 # H_data = lambda t: H(t).data 

331 # if t_is_scalar: 

332 # if t == 0: 

333 # return identity_like(H(0.0)) 

334 

335 # ts = jnp.linspace(0,t,2) 

336 # return jnp2jqt( 

337 # propagator_t_data(H_data, ts, solver_options=solver_options)[1], 

338 # dims=dims 

339 # ) 

340 # else: 

341 # ts = t 

342 # U_props = propagator_t_data(H_data, ts, solver_options=solver_options) 

343 # return jnp2jqt(U_props, dims) 

344 

345 # def propagator_0_data( 

346 # H0: Array, 

347 # t: float 

348 # ): 

349 # """ Generate the propagator for a time independent Hamiltonian. 

350 

351 # Args: 

352 # H0 (Qarray): The Hamiltonian. 

353 

354 # Returns: 

355 # Qarray: The propagator for the time independent Hamiltonian. 

356 # """ 

357 # return jsp.linalg.expm(-1j * H0 * t) 

358 

359 # def propagator_t_data( 

360 # Ht: Callable[[float], Array], 

361 # ts: Array, 

362 # solver_options=None 

363 # ): 

#     """Generate the propagator for a time dependent Hamiltonian.
#
#     NOTE: disabled together with the rest of the propagator code above;
#     this docstring is kept as a comment so it is not a live statement.
#
#     Args:
#         ts (float): The final time of the propagator.
#             Warning: Do not send in t. In this case, just do exp(-1j*Ht(0.0)).
#         Ht (callable): A function that takes a time argument and returns a Hamiltonian.
#         solver_options (dict): Options to pass to the solver.
#
#     Returns:
#         Qarray: The propagator for the time dependent Hamiltonian for the time range [0, t_final].
#     """
#     N = Ht(0).shape[0]
#     basis_states = jnp.eye(N)
#
#     def propogate_state(initial_state):
#         return sesolve_data(initial_state, ts, Ht=Ht, solver_options=solver_options)
#
#     U_prop = vmap(propogate_state)(basis_states)
#     U_prop = U_prop.transpose(1,0,2) # move time axis to the front
#     return U_prop