Coverage for jaxquantum / core / solvers.py: 88%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""Solvers""" 

2 

3from diffrax import ( 

4 diffeqsolve, 

5 ODETerm, 

6 SaveAt, 

7 PIDController, 

8 TqdmProgressMeter, 

9 NoProgressMeter, 

10) 

11from flax import struct 

12from jax import Array 

13from typing import Callable, Optional, Union 

14import diffrax 

15import jax.numpy as jnp 

16import warnings 

17import tqdm 

18import logging 

19 

20 

21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data 

22from jaxquantum.core.conversions import jnp2jqt 

23from jaxquantum.core.operators import identity_like, multi_mode_basis_set 

24from jaxquantum.utils.utils import robust_isscalar 

25 

26# ---- 

27 

28 

@struct.dataclass
class SolverOptions:
    """Options controlling the diffrax ODE solve used by the solvers below.

    Every field is declared with ``pytree_node=False`` so that it is stored
    as static metadata of the flax struct rather than as a pytree leaf. This
    matters when a SolverOptions instance crosses a JAX transformation
    boundary: a string solver name cannot be a traced leaf, and ``max_steps``
    must be a concrete Python int for ``diffrax.diffeqsolve``.
    """

    # Whether to display a tqdm progress bar during the solve.
    progress_meter: bool = struct.field(pytree_node=False, default=True)
    # Name of a diffrax solver class, instantiated via getattr(diffrax, solver)().
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    # Maximum number of integration steps forwarded to diffeqsolve.
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    # Relative tolerance of the PID step-size controller.
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    # Absolute tolerance of the PID step-size controller.
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ):
        """Build a SolverOptions instance, with defaults for every field."""
        return cls(progress_meter, solver, max_steps, rtol, atol)

47 

48 

class CustomProgressMeter(TqdmProgressMeter):
    """Diffrax tqdm progress meter with a customised bar appearance."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        # Overrides the bar factory inherited from diffrax's TqdmProgressMeter
        # (presumably called once per solve to create the tqdm bar).
        # Progress is reported on a 0-100 scale, hence total=100 and unit "%".
        return tqdm.tqdm(
            total=100,
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
            bar_format=(
                "{desc}: {percentage:3.0f}% |{bar}| "
                "[{elapsed}<{remaining}, {rate_fmt}{postfix}]"
            ),
        )

56 

57 

def solve(f, ρ0, tlist, saveat_tlist, args, solver_options: Optional[
        SolverOptions] = None):
    """Integrate the ODE ``dy/dt = f(t, y, args)`` with diffrax.

    Args:
        f: right-hand side of the ODE, called by diffrax as ``f(t, y, args)``.
        ρ0: initial state, passed to ``diffeqsolve`` as ``y0``.
        tlist: time list. ``tlist[0]`` and ``tlist[-1]`` are the integration
            bounds and ``tlist[1] - tlist[0]`` is the initial step size, so
            it must contain at least two entries.
        saveat_tlist: times at which to save the state. Pass ``[-1]`` (or
            the scalar ``-1``) to save only at the final time. Scalars and
            Python sequences are accepted and converted with
            ``jnp.atleast_1d``.
        args: additional arguments forwarded to ``f``.
        solver_options: SolverOptions selecting the solver, step limit,
            tolerances and progress meter. Defaults to
            ``SolverOptions.create()``.

    Returns:
        The diffrax solution object; the saved states are in ``sol.ys``.
    """
    term = ODETerm(f)

    # Normalise so that a scalar (e.g. -1) or a plain list works like an array.
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    # A lone -1 is the sentinel for "save only the final state".
    if saveat_tlist.shape == (1,) and saveat_tlist[0] == -1:
        saveat = SaveAt(t1=True)
    else:
        saveat = SaveAt(ts=saveat_tlist)

    if solver_options is None:
        solver_options = SolverOptions.create()

    # The solver is looked up by class name on the diffrax module (e.g. "Tsit5").
    solver = getattr(diffrax, solver_options.solver)()
    stepsize_controller = PIDController(
        rtol=solver_options.rtol, atol=solver_options.atol
    )
    progress_meter = (
        CustomProgressMeter()
        if solver_options.progress_meter
        else NoProgressMeter()
    )

    with warnings.catch_warnings():
        # NOTE: suppresses diffrax's complex-dtype warning for complex states.
        warnings.filterwarnings("ignore",
                                message="Complex dtype support in Diffrax",
                                category=UserWarning)
        sol = diffeqsolve(
            term,
            solver,
            t0=tlist[0],
            t1=tlist[-1],
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            args=args,
            max_steps=solver_options.max_steps,
            progress_meter=progress_meter,
        )

    return sol

112 

113 

def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Lindblad master-equation solver.

    Args:
        H: time-independent Qarray Hamiltonian, a function ``t -> Qarray``,
            or a scalar (interpreted as that scalar times the identity).
        rho0: initial state. It is converted to a dense density matrix with
            ``to_dm()``; for pure-state evolution prefer sesolve.
        tlist: time list.
        saveat_tlist: times at which to save the state.
            If -1 or [-1], save only at final time.
            If None, save at all times in tlist. Default: None.
        c_ops: Qarray list of collapse operators. None means no dissipation.
        solver_options: SolverOptions with solver options.

    Returns:
        Qarray of the states at the saved times, carrying the dims of the
        initial density matrix.
    """
    save_times = jnp.atleast_1d(tlist if saveat_tlist is None else saveat_tlist)
    jump_ops = Qarray.from_list([]) if c_ops is None else c_ops

    # Without collapse operators a non-operator input is better served by sesolve.
    if len(jump_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    initial_dm = rho0.to_dm().to_dense()

    if robust_isscalar(H):
        # A bare number means that number times the identity on the state space.
        H = H * identity_like(initial_dm)

    # Expose the Hamiltonian to the array-level solver as a function of t.
    if isinstance(H, Qarray):
        def hamiltonian_data(t):
            return H.data
    else:
        def hamiltonian_data(t):
            return H(t).data

    states = _mesolve_data(
        hamiltonian_data,
        initial_dm.data,
        tlist,
        save_times,
        jump_ops.data,
        solver_options=solver_options,
    )
    return jnp2jqt(states, dims=initial_dm.dims)

170 

171 

def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Lindblad master-equation solver operating on raw arrays.

    Args:
        H: function ``t -> Hamiltonian array`` (trailing two dims are the
            matrix dims; leading dims are batch dims).
        rho0: initial density-matrix array, shape ``(*batch, N, N)``.
        tlist: time list.
        saveat_tlist: list of times at which to save the state.
            If [-1], save only at final time.
        c_ops: array of collapse operators stacked along axis 0, i.e. shape
            ``(K, *batch, N, N)``, dense or sparse (``dag_data`` handles both).
            None or an empty array means purely unitary evolution.
        solver_options: SolverOptions with solver options.

    Returns:
        Array of states at the saved times (``sol.ys`` from diffrax).
    """

    c_ops = c_ops if c_ops is not None else jnp.array([])

    # The "consider sesolve" warning for non-density-matrix inputs is issued
    # by `mesolve`, which still has the Qarray qtype available.

    ρ0 = rho0

    # Shape inference: when c_ops contains batched operators (e.g. shape
    # (1, B, N, N)), the initial state ρ0 must be broadcast to (B, N, N) so
    # that the ODE RHS produces consistently shaped output.
    #
    # The output batch shape is the broadcast of:
    #   c_ops batch dims  → c_ops.shape[1:-2]  (outer c_op index stripped)
    #   H batch dims      → H(0.0).shape[:-2]
    #   ρ0 batch dims     → ρ0.shape[:-2]
    # Note: H is evaluated once at t=0.0 only to read its shape.
    H0_shape = H(0.0).shape
    if len(c_ops) == 0:
        batch_shape = jnp.broadcast_shapes(H0_shape[:-2], ρ0.shape[:-2])
    else:
        batch_shape = jnp.broadcast_shapes(
            c_ops.shape[1:-2], H0_shape[:-2], ρ0.shape[:-2]
        )
    # broadcast_to (rather than jnp.resize, which tiles the flattened data
    # cyclically) keeps each batch entry aligned with its original ρ0 slice,
    # including when ρ0 has size-1 batch dims in non-leading positions.
    ρ0 = jnp.broadcast_to(ρ0, batch_shape + ρ0.shape[-2:])

    # Precompute the adjoint once, outside the ODE hot-loop.
    # dag_data dispatches to the correct impl (dense or sparse) automatically,
    # so c_ops_dag is BCOO when c_ops is sparse and a dense array otherwise.
    c_ops_dag = dag_data(c_ops) if len(c_ops) != 0 else c_ops

    def f(
        t: float,
        rho: Array,
        args,
    ):
        c_ops_val, c_ops_dag_val = args
        H_val = H(t)  # type: ignore

        # Unitary part: -i [H, ρ]
        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        # Compute the Lindblad dissipator D[L](ρ) = L ρ L† - ½(L†L ρ + ρ L†L)
        # using only (sparse L) @ (dense rho) operations to support BCOO
        # collapse operators natively — no dense @ sparse required:
        #
        #   L ρ L†  = dag( L @ dag(L @ ρ) )   avoids the dense @ L† step
        #   L†L ρ   = L† @ (L @ ρ)            BCOO @ dense → dense ✓
        #   ρ L†L   = dag(L†L ρ)              dag of dense ✓ (ρ Hermitian)
        Lrho = c_ops_val @ rho
        LrhoLdag = dag_data(c_ops_val @ dag_data(Lrho))
        LdagLrho = c_ops_dag_val @ Lrho
        rhoLdagL = dag_data(LdagLrho)

        rho_dot_delta = 0.5 * (2 * LrhoLdag - LdagLrho - rhoLdagL)

        # Sum the contributions of all collapse operators (axis 0).
        rho_dot_delta = jnp.sum(rho_dot_delta, axis=0)

        rho_dot += rho_dot_delta

        return rho_dot

    sol = solve(f, ρ0, tlist, saveat_tlist, (c_ops, c_ops_dag),
                solver_options=solver_options)

    return sol.ys

272 

273 

def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger Equation solver for pure states.

    Args:
        H: time-independent Qarray Hamiltonian, a function ``t -> Qarray``,
            or a scalar (interpreted as that scalar times the identity).
        rho0: initial pure state. It is converted with ``to_ket()``; it must
            NOT be an operator / density matrix. For density-matrix
            evolution, please use mesolve.
        tlist: time list
        saveat_tlist: list of times at which to save the state.
            If -1 or [-1], save only at final time.
            If None, save at all times in tlist. Default: None.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray of the kets at the saved times.

    Raises:
        ValueError: if ``rho0`` is an operator (density matrix).
    """

    saveat_tlist = saveat_tlist if saveat_tlist is not None else tlist

    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    ψ = rho0

    if ψ.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ψ = ψ.to_ket().to_dense()

    if robust_isscalar(H):
        H = H * identity_like(ψ)  # treat scalar H as a multiple of the identity

    dims = ψ.dims
    ψ = ψ.data

    # Expose the Hamiltonian to the array-level solver as a function of t.
    if isinstance(H, Qarray):
        Ht_data = lambda t: H.data
    else:
        Ht_data = lambda t: H(t).data

    ys = _sesolve_data(Ht_data, ψ, tlist, saveat_tlist,
                       solver_options=solver_options)

    return jnp2jqt(ys, dims=dims)

324 

325 

def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
):
    """Schrödinger Equation solver operating on raw arrays.

    Args:
        H: function ``t -> Hamiltonian array`` (trailing two dims are the
            matrix dims; leading dims are batch dims).
        rho0: initial state-vector array (not a density matrix).
        tlist: time list
        saveat_tlist: list of times at which to save the state.
            If [-1], save only at final time.
        solver_options: SolverOptions with solver options

    Returns:
        Array of states at the saved times (``sol.ys`` from diffrax).
    """

    ψ = rho0

    def f(t: float, ψₜ: Array, _):
        H_val = H(t)  # type: ignore

        ψₜ_dot = -1j * (H_val @ ψₜ)

        return ψₜ_dot

    # Evaluate the RHS once to learn the broadcast batch shape of H @ ψ, then
    # broadcast ψ to it. broadcast_to (unlike jnp.resize, which tiles the
    # flattened data cyclically) keeps each batch entry aligned with the
    # matching slice of ψ, including for size-1 batch dims.
    ψ_test = f(0, ψ, None)
    ψ = jnp.broadcast_to(ψ, ψ_test.shape)

    sol = solve(f, ψ, tlist, saveat_tlist, None, solver_options=solver_options)
    return sol.ys

362 

363# ---- 

364 

365# propagators 

366# ---- 

367 

def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    saveat_tlist: Optional[Array] = None,
    solver_options=None
):
    """Generate the propagator for a (possibly time-dependent) Hamiltonian.

    Args:
        H (Qarray or callable):
            A Qarray static Hamiltonian OR
            a function that takes a time argument and returns a Hamiltonian.
        ts (float or Array):
            A single time point or
            an Array of time points.
        saveat_tlist: list of times at which to save the propagator.
            If -1 or [-1], save only at final time.
            If None, save at all times in the integration time list.
            Default: None. Only used when H is callable.
        solver_options: SolverOptions forwarded to the ODE solver when H is
            callable. Default: None (uses SolverOptions.create()).

    Returns:
        Qarray or batched Qarray:
            For a static Qarray H, ``expm(-1j * H * ts)``.
            For a callable H, the propagator(s) obtained by evolving every
            basis state. With a scalar ``ts`` the integration runs over
            ``[0, ts]``, so with ``saveat_tlist=None`` propagators at both
            t=0 and t=ts are returned; pass ``saveat_tlist=[-1]`` to keep
            only the final one. A scalar ``ts == 0`` returns the identity.
    """

    if isinstance(H, Qarray):
        # Static Hamiltonian: closed-form exponential, no ODE solve needed.
        return (-1j * H * ts).expm()

    if robust_isscalar(ts):
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    # Evolve each basis state; after squeezing the ket axis and swapping the
    # last two axes, column k holds the evolved k-th basis state.
    # NOTE(review): assumes multi_mode_basis_set returns a batch of kets whose
    # leading batch axis indexes the basis — TODO confirm.
    basis_states = multi_mode_basis_set(H_first.space_dims)
    results = sesolve(
        H,
        basis_states,
        ts,
        saveat_tlist=saveat_tlist,
        solver_options=solver_options,
    )
    propagators_data = results.data.squeeze(-1).mT
    propagators = Qarray.create(propagators_data, dims=H_first.space_dims)

    return propagators