Coverage for jaxquantum/core/solvers.py: 86%

118 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1"""Solvers""" 

2 

3from diffrax import ( 

4 diffeqsolve, 

5 ODETerm, 

6 SaveAt, 

7 PIDController, 

8 TqdmProgressMeter, 

9 NoProgressMeter, 

10) 

11from flax import struct 

12from jax import Array 

13from typing import Callable, Optional, Union 

14import diffrax 

15import jax.numpy as jnp 

16import warnings 

17import tqdm 

18import logging 

19 

20 

21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data 

22from jaxquantum.core.conversions import jnp2jqt 

23from jaxquantum.core.operators import identity_like, multi_mode_basis_set 

24from jaxquantum.utils.utils import robust_isscalar 

25 

26# ---- 

27 

28 

@struct.dataclass
class SolverOptions:
    """Configuration for the diffrax-based solvers in this module.

    Every field is declared with ``pytree_node=False``. The defaults
    mirror those of :meth:`create`, so the class can also be constructed
    directly with only ``progress_meter`` supplied.

    NOTE: the previous declarations wrapped ``struct.field(...)`` in a
    one-element tuple (stray trailing comma), so the field spec was never
    applied and the default value was a tuple holding a ``Field`` object.
    """

    # Whether `solve` shows a progress bar while integrating.
    progress_meter: bool = struct.field(pytree_node=False)
    # Name of the solver class looked up on the `diffrax` module.
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    # Forwarded to `diffeqsolve` as `max_steps`.
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    # Tolerances handed to the `PIDController` step size controller.
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ):
        """Build a ``SolverOptions`` with defaults for every option.

        Args:
            progress_meter: show a progress bar during the solve.
            solver: name of the diffrax solver class to instantiate.
            max_steps: maximum number of solver steps.
            rtol: relative tolerance for adaptive step sizing.
            atol: absolute tolerance for adaptive step sizing.

        Returns:
            A new ``SolverOptions`` instance.
        """
        return cls(progress_meter, solver, max_steps, rtol, atol)

47 

48 

class CustomProgressMeter(TqdmProgressMeter):
    """``TqdmProgressMeter`` variant that supplies its own tqdm bar."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the percentage-based tqdm bar used by this meter."""
        layout = (
            "{desc}: {percentage:3.0f}% |{bar}| "
            "[{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        )
        bar_kwargs = dict(
            total=100,
            bar_format=layout,
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
        )
        return tqdm.tqdm(**bar_kwargs)

56 

57 

def solve(
    f, ρ0, tlist, saveat_tlist, args, solver_options: Optional[SolverOptions] = None
):
    """Integrate an ODE with diffrax using the configured solver.

    Args:
        f: function defining the ODE; wrapped in a diffrax ``ODETerm``.
        ρ0: initial state, passed to diffrax as ``y0``.
        tlist: time list. Its first and last entries bound the
            integration and its first spacing is the initial step size.
        saveat_tlist: array of times at which to save the state.
            A length-1 array holding -1 saves only at the final time.
        args: additional arguments forwarded to ``f``.
        solver_options: ``SolverOptions`` instance; when falsy,
            ``SolverOptions.create()`` is used.

    Returns:
        The solution object returned by ``diffeqsolve``.
    """
    ode_term = ODETerm(f)

    # A single -1 entry is the sentinel for "save only at the final time".
    final_only = saveat_tlist.shape[0] == 1 and saveat_tlist == -1
    if final_only:
        save_spec = SaveAt(t1=True)
    else:
        save_spec = SaveAt(ts=saveat_tlist)

    opts = solver_options or SolverOptions.create()

    # The solver is looked up by class name on the diffrax module.
    solver_cls = getattr(diffrax, opts.solver)
    integrator = solver_cls()
    controller = PIDController(rtol=opts.rtol, atol=opts.atol)

    t_start, t_stop = tlist[0], tlist[-1]

    with warnings.catch_warnings():
        # NOTE: suppresses complex dtype warning in diffrax
        warnings.filterwarnings(
            "ignore",
            message="Complex dtype support in Diffrax",
            category=UserWarning,
        )
        if opts.progress_meter:
            meter = CustomProgressMeter()
        else:
            meter = NoProgressMeter()
        solution = diffeqsolve(
            ode_term,
            integrator,
            t0=t_start,
            t1=t_stop,
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=save_spec,
            stepsize_controller=controller,
            args=args,
            max_steps=opts.max_steps,
            progress_meter=meter,
        )

    return solution

112 

113 

def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum master equation solver operating on Qarray inputs.

    Args:
        H: time-independent Hamiltonian (Qarray) or a function of time
            returning one.
        rho0: initial state; converted with ``to_dm()`` before solving.
        tlist: time list.
        saveat_tlist: times at which to save the state.
            If -1 or [-1], save only at the final time.
            If None, save at all times in ``tlist``. Default: None.
        c_ops: Qarray of collapse operators; an empty Qarray list is
            used when None.
        solver_options: SolverOptions with solver options.

    Returns:
        The saved states, converted back to a Qarray with the dims of
        the initial density matrix.
    """
    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if c_ops is None:
        c_ops = Qarray.from_list([])

    # Without collapse operators and with a non-operator initial state,
    # the Schrödinger solver is the suggested alternative.
    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    density_matrix = rho0.to_dm()
    dims = density_matrix.dims
    rho0_data = density_matrix.data
    c_ops_data = c_ops.data

    # Normalise H to a callable of time that returns raw array data.
    if isinstance(H, Qarray):

        def Ht_data(t):
            return H.data

    else:

        def Ht_data(t):
            return H(t).data

    ys = _mesolve_data(
        Ht_data,
        rho0_data,
        tlist,
        saveat_tlist,
        c_ops_data,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)

166 

167 

def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Quantum master equation solver operating on raw array data.

    Args:
        H: function of time returning the Hamiltonian data.
        rho0: initial density matrix data.
        tlist: time list.
        saveat_tlist: times at which to save the state, forwarded to
            ``solve``.
        c_ops: collapse operator data, indexed along its first axis;
            an empty array is used when None.
        solver_options: SolverOptions with solver options.

    Returns:
        The ``ys`` attribute of the solution returned by ``solve``.
    """
    if c_ops is None:
        c_ops = jnp.array([])

    # The "no collapse operators + non-density-matrix" warning is issued
    # by `mesolve`, not here.
    no_collapse_ops = len(c_ops) == 0

    # Evaluate one product up front so the initial state can be resized
    # to the shape the right-hand side will produce.
    if no_collapse_ops:
        shape_probe = H(0.0) @ rho0
    else:
        shape_probe = c_ops[0] @ H(0.0) @ rho0

    initial_state = jnp.resize(rho0, shape_probe.shape)  # ensure correct shape

    if not no_collapse_ops:
        leading_dims = c_ops.shape[:-2]
        n_rows, n_cols = c_ops.shape[-2], c_ops.shape[-1]
        c_ops = c_ops.reshape(*leading_dims, n_rows, n_cols)

    def f(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        """Right-hand side of the master equation at time ``t``."""
        H_val = H(t)  # type: ignore

        unitary_part = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return unitary_part

        c_ops_val_dag = dag_data(c_ops_val)

        per_operator = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag
            - rho @ c_ops_val_dag @ c_ops_val
            - c_ops_val_dag @ c_ops_val @ rho
        )

        # Add up the contributions along the first axis.
        dissipative_part = jnp.sum(per_operator, axis=0)

        return unitary_part + dissipative_part

    sol = solve(
        f,
        initial_state,
        tlist,
        saveat_tlist,
        c_ops,
        solver_options=solver_options,
    )

    return sol.ys

243 

244 

def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger equation solver operating on Qarray inputs.

    Args:
        H: time-independent Hamiltonian (Qarray) or a function of time
            returning one.
        rho0: initial state; must not have operator qtype. It is
            converted with ``to_ket()`` before solving.
        tlist: time list.
        saveat_tlist: times at which to save the state.
            If -1 or [-1], save only at the final time.
            If None, save at all times in ``tlist``. Default: None.
        solver_options: SolverOptions with solver options.

    Returns:
        The saved states, converted back to a Qarray with the dims of
        the initial ket.

    Raises:
        ValueError: if ``rho0`` has qtype ``Qtypes.oper``.
    """
    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if rho0.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ket = rho0.to_ket()
    dims = ket.dims
    ket_data = ket.data

    # Normalise H to a callable of time that returns raw array data.
    if isinstance(H, Qarray):

        def Ht_data(t):
            return H.data

    else:

        def Ht_data(t):
            return H(t).data

    ys = _sesolve_data(
        Ht_data,
        ket_data,
        tlist,
        saveat_tlist,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)

291 

292 

def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
):
    """Schrödinger equation solver operating on raw array data.

    Args:
        H: function of time returning the Hamiltonian data.
        rho0: initial state data (despite the name, this is the state
            that the Hamiltonian is applied to directly).
        tlist: time list.
        saveat_tlist: times at which to save the state, forwarded to
            ``solve``.
        solver_options: SolverOptions with solver options.

    Returns:
        The ``ys`` attribute of the solution returned by ``solve``.
    """

    def f(t: float, state: Array, _):
        """Right-hand side: ``-1j * H(t) @ state``."""
        H_val = H(t)  # type: ignore
        return -1j * (H_val @ state)

    # Evaluate the right-hand side once so the initial state can be
    # resized to the shape it produces.
    shape_probe = f(0, rho0, None)
    initial_state = jnp.resize(rho0, shape_probe.shape)  # ensure correct shape

    sol = solve(
        f, initial_state, tlist, saveat_tlist, None, solver_options=solver_options
    )
    return sol.ys

329 

330# ---- 

331 

332# propagators 

333# ---- 

334 

def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    saveat_tlist: Optional[Array] = None,
    solver_options=None
):
    """Generate the propagator for a static or time dependent Hamiltonian.

    Args:
        H (Qarray or callable):
            A Qarray static Hamiltonian OR
            a function that takes a time argument and returns a Hamiltonian.
        ts (float or Array):
            A single time point or
            an Array of time points.
        saveat_tlist: list of times at which to save the state,
            forwarded to ``sesolve``.
            If -1 or [-1], save only at final time.
            If None, save at all times in the time list. Default: None.
            Not used when ``H`` is a static Qarray.
        solver_options: SolverOptions forwarded to ``sesolve``.
            Not used when ``H`` is a static Qarray.

    Returns:
        Qarray:
            For a static ``H``, ``(-1j * H * ts).expm()``.
            For a callable ``H`` with scalar ``ts == 0``, the identity
            built from ``H(0.0)``.
            Otherwise a Qarray built from the ``sesolve`` results for the
            basis states; a scalar ``ts`` is integrated over ``[0.0, ts]``.
    """
    # A static Hamiltonian needs no ODE solve.
    if isinstance(H, Qarray):
        return (-1j * H * ts).expm()

    if robust_isscalar(ts):
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        # Integrate from time zero up to the requested time.
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    basis_states = multi_mode_basis_set(H_first.space_dims)
    # Fix: solver_options was accepted but never passed on, so
    # user-supplied options were silently ignored.
    results = sesolve(
        H,
        basis_states,
        ts,
        saveat_tlist=saveat_tlist,
        solver_options=solver_options,
    )
    propagators_data = results.data.squeeze(-1)
    propagators = Qarray.create(propagators_data, dims=H_first.space_dims)

    return propagators