Coverage for jaxquantum / core / solvers.py: 88%
128 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""Solvers"""
3from diffrax import (
4 diffeqsolve,
5 ODETerm,
6 SaveAt,
7 PIDController,
8 TqdmProgressMeter,
9 NoProgressMeter,
10)
11from flax import struct
12from jax import Array
13from typing import Callable, Optional, Union
14import diffrax
15import jax.numpy as jnp
16import warnings
17import tqdm
18import logging
21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data
22from jaxquantum.core.conversions import jnp2jqt
23from jaxquantum.core.operators import identity_like, multi_mode_basis_set
24from jaxquantum.utils.utils import robust_isscalar
26# ----
@struct.dataclass
class SolverOptions:
    """Configuration for the diffrax-based solvers in this module.

    Every field is declared with ``pytree_node=False`` so that it is treated
    as static (auxiliary) data, not as a traced leaf, when an instance is
    passed through JAX transformations such as ``jax.jit``.

    Attributes:
        progress_meter: if True, ``solve`` shows a tqdm progress bar
            (``CustomProgressMeter``); otherwise no progress is reported.
        solver: name of a diffrax solver class, instantiated by ``solve``
            via ``getattr(diffrax, solver)()``.
        max_steps: maximum number of steps passed to ``diffeqsolve``.
        rtol: relative tolerance of the ``PIDController`` step-size controller.
        atol: absolute tolerance of the ``PIDController`` step-size controller.
    """

    # Do NOT wrap struct.field(...) in a tuple: ``x: T = (struct.field(...),)``
    # makes the tuple the field's default value and silently turns the field
    # into a (traced) pytree leaf. Defaults mirror ``create`` below.
    progress_meter: bool = struct.field(pytree_node=False, default=True)
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ):
        """Build a ``SolverOptions`` with the given (or default) settings.

        Args:
            progress_meter: show a progress bar during integration.
            solver: diffrax solver class name, e.g. ``"Tsit5"``.
            max_steps: maximum number of solver steps.
            rtol: relative tolerance of the adaptive step-size controller.
            atol: absolute tolerance of the adaptive step-size controller.

        Returns:
            A new ``SolverOptions`` instance.
        """
        return cls(progress_meter, solver, max_steps, rtol, atol)
class CustomProgressMeter(TqdmProgressMeter):
    """Diffrax tqdm progress meter with a magenta, block-character bar."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the tqdm bar; it counts from 0 to 100 in percent units."""
        bar_settings = dict(
            total=100,
            bar_format="{desc}: {percentage:3.0f}% |{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
        )
        return tqdm.tqdm(**bar_settings)
def solve(f, ρ0, tlist, saveat_tlist, args, solver_options: Optional[
        SolverOptions] = None):
    """Integrate ``dy/dt = f(t, y, args)`` with diffrax.

    Integration runs from ``tlist[0]`` to ``tlist[-1]``, with an initial step
    of ``tlist[1] - tlist[0]``. It uses an adaptive ``PIDController`` and the
    solver named in ``solver_options``.

    Args:
        f: right-hand side of the ODE, called as ``f(t, y, args)``.
        ρ0: initial state ``y(tlist[0])``.
        tlist: time grid; only its first, second and last entries are used.
        saveat_tlist: 1-D array of times at which to save the state.
            Pass ``[-1]`` to save only at the final time.
        args: extra argument forwarded to ``f``.
        solver_options: ``SolverOptions``; ``SolverOptions.create()``
            defaults are used when it is ``None``.

    Returns:
        The solution object returned by ``diffrax.diffeqsolve``.
    """
    rhs = ODETerm(f)

    final_time_only = saveat_tlist.shape[0] == 1 and saveat_tlist == -1
    saveat = SaveAt(t1=True) if final_time_only else SaveAt(ts=saveat_tlist)

    opts = solver_options or SolverOptions.create()

    integrator = getattr(diffrax, opts.solver)()
    controller = PIDController(rtol=opts.rtol, atol=opts.atol)

    with warnings.catch_warnings():
        # NOTE: suppresses complex dtype warning in diffrax
        warnings.filterwarnings(
            "ignore",
            message="Complex dtype support in Diffrax",
            category=UserWarning,
        )
        if opts.progress_meter:
            meter = CustomProgressMeter()
        else:
            meter = NoProgressMeter()
        return diffeqsolve(
            rhs,
            integrator,
            t0=tlist[0],
            t1=tlist[-1],
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=controller,
            args=args,
            max_steps=opts.max_steps,
            progress_meter=meter,
        )
def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Solve the Lindblad master equation.

    Args:
        H: Hamiltonian. This may be a time-independent ``Qarray`` or a
            callable ``t -> Qarray``. A scalar is also accepted and is
            promoted to that multiple of the identity.
        rho0: initial state. It is converted with ``to_dm()`` before solving.
            If ``c_ops`` is empty and ``rho0`` is not an operator, a warning
            suggests ``jqt.sesolve`` instead.
        tlist: integration time grid.
        saveat_tlist: times at which to save the state. ``-1`` or ``[-1]``
            saves only the final time. ``None`` (the default) saves at every
            time in ``tlist``.
        c_ops: collapse operators as a single (batched) ``Qarray``. The
            collapse-operator index is its leading batch axis. ``None`` means
            no dissipation.
        solver_options: ``SolverOptions`` forwarded to the ODE solver.

    Returns:
        Qarray of the saved density matrices, with the dims of ``rho0``.
    """
    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if c_ops is None:
        c_ops = Qarray.from_list([])

    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    rho_init = rho0.to_dm().to_dense()

    # A scalar Hamiltonian means "scalar * identity" on the state's space.
    if robust_isscalar(H):
        H = H * identity_like(rho_init)

    if isinstance(H, Qarray):
        def H_data(t):
            return H.data
    else:
        def H_data(t):
            return H(t).data

    ys = _mesolve_data(
        H_data,
        rho_init.data,
        tlist,
        saveat_tlist,
        c_ops.data,
        solver_options=solver_options,
    )
    return jnp2jqt(ys, dims=rho_init.dims)
def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Integrate the Lindblad master equation on raw arrays.

    Args:
        H: function mapping a time ``t`` to the Hamiltonian data array
            (trailing two axes are the matrix axes).
        rho0: initial density-matrix data, shape ``(..., N, N)``.
        tlist: integration time grid (see ``solve``).
        saveat_tlist: 1-D array of save times; ``[-1]`` saves only the final
            state.
        c_ops: stacked collapse-operator data (dense or BCOO). Axis 0 indexes
            the collapse operators; the trailing two axes are the matrix axes.
            ``None`` or an empty array means no dissipation.
        solver_options: ``SolverOptions``; defaults are used when ``None``.

    Returns:
        The saved states (``sol.ys`` from diffrax).
    """
    c_ops = c_ops if c_ops is not None else jnp.array([])

    ρ0 = rho0

    # Shape inference: when c_ops or H carry batch dims (e.g. c_ops of shape
    # (1, B, N, N)), ρ0 must be broadcast to (B, N, N) so that the ODE RHS
    # produces output with the same shape as the state.
    #
    # The output batch shape is the broadcast of:
    #   c_ops batch dims → c_ops.shape[1:-2] (outer c_op index stripped)
    #   H batch dims     → H(0.0).shape[:-2]
    #   ρ0 batch dims    → ρ0.shape[:-2]
    # Note: H(0.0) is evaluated once here only to read its shape.
    H0_shape = H(0.0).shape
    if len(c_ops) == 0:
        batch_shape = jnp.broadcast_shapes(H0_shape[:-2], ρ0.shape[:-2])
    else:
        batch_shape = jnp.broadcast_shapes(
            c_ops.shape[1:-2], H0_shape[:-2], ρ0.shape[:-2]
        )
    # Use broadcast_to rather than resize: resize tiles the flattened data
    # cyclically, which puts states in the wrong slots whenever rho0 has
    # size-1 batch axes that are not the leading axes. broadcast_to is always
    # valid here because batch_shape includes rho0's batch dims.
    ρ0 = jnp.broadcast_to(ρ0, batch_shape + ρ0.shape[-2:])

    # Precompute the adjoint once, outside the ODE hot-loop.
    # dag_data dispatches to the correct impl (dense or sparse) automatically,
    # so c_ops_dag is BCOO when c_ops is sparse and a dense array otherwise.
    c_ops_dag = dag_data(c_ops) if len(c_ops) != 0 else c_ops

    def f(
        t: float,
        rho: Array,
        args,
    ):
        c_ops_val, c_ops_dag_val = args
        H_val = H(t)  # type: ignore

        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        # Compute the Lindblad dissipator D[L](ρ) = L ρ L† - ½(L†L ρ + ρ L†L)
        # using only (sparse L) @ (dense rho) operations to support BCOO
        # collapse operators natively — no dense @ sparse required:
        #
        #   L ρ L† = dag( L @ dag(L @ ρ) )   avoids the dense @ L† step
        #   L†L ρ  = L† @ (L @ ρ)            BCOO @ dense → dense
        #   ρ L†L  = dag(L†L ρ)              dag of dense (ρ Hermitian)
        Lrho = c_ops_val @ rho
        LrhoLdag = dag_data(c_ops_val @ dag_data(Lrho))
        LdagLrho = c_ops_dag_val @ Lrho
        rhoLdagL = dag_data(LdagLrho)

        rho_dot_delta = 0.5 * (2 * LrhoLdag - LdagLrho - rhoLdagL)

        # Sum the dissipator contributions over the collapse-operator axis.
        rho_dot_delta = jnp.sum(rho_dot_delta, axis=0)

        rho_dot += rho_dot_delta

        return rho_dot

    sol = solve(f, ρ0, tlist, saveat_tlist, (c_ops, c_ops_dag),
                solver_options=solver_options)

    return sol.ys
def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Solve the Schrödinger equation for a state vector.

    Args:
        H: Hamiltonian. This may be a time-independent ``Qarray`` or a
            callable ``t -> Qarray``. A scalar is also accepted and is
            promoted to that multiple of the identity.
        rho0: initial state vector. It is converted with ``to_ket()`` before
            solving. Density matrices (``Qtypes.oper``) are rejected; use
            ``jqt.mesolve`` for those.
        tlist: integration time grid.
        saveat_tlist: times at which to save the state. ``-1`` or ``[-1]``
            saves only the final time. ``None`` (the default) saves at every
            time in ``tlist``.
        solver_options: ``SolverOptions`` forwarded to the ODE solver.

    Returns:
        Qarray of the saved kets, with the dims of the initial ket.

    Raises:
        ValueError: if ``rho0`` is a density matrix.
    """
    saveat_tlist = jnp.atleast_1d(tlist if saveat_tlist is None else saveat_tlist)

    if rho0.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ket = rho0.to_ket().to_dense()

    # A scalar Hamiltonian means "scalar * identity" on the state's space.
    if robust_isscalar(H):
        H = H * identity_like(ket)

    H_data = (lambda t: H.data) if isinstance(H, Qarray) else (lambda t: H(t).data)

    ys = _sesolve_data(
        H_data, ket.data, tlist, saveat_tlist, solver_options=solver_options
    )
    return jnp2jqt(ys, dims=ket.dims)
def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Integrate the Schrödinger equation dψ/dt = -i H(t) ψ on raw arrays.

    Args:
        H: function mapping a time ``t`` to the Hamiltonian data array.
        rho0: initial state-vector data (not a density matrix), e.g. ket data
            of shape ``(..., N, 1)``.
        tlist: integration time grid (see ``solve``).
        saveat_tlist: 1-D array of save times; ``[-1]`` saves only the final
            state.
        solver_options: ``SolverOptions``; defaults are used when ``None``.

    Returns:
        The saved states (``sol.ys`` from diffrax).
    """
    ψ = rho0

    def f(t: float, ψₜ: Array, _):
        H_val = H(t)  # type: ignore

        ψₜ_dot = -1j * (H_val @ ψₜ)

        return ψₜ_dot

    # Evaluate the RHS once to learn the broadcast shape of H(t) @ ψ. The
    # state must have that shape so the ODE output matches the state shape.
    # broadcast_to (not resize) keeps every batch element in its own slot;
    # resize would tile the flattened data cyclically. The matmul result's
    # batch shape always contains ψ's batch dims, so the broadcast is valid.
    ψ_test = f(0, ψ, None)
    ψ = jnp.broadcast_to(ψ, ψ_test.shape)

    sol = solve(f, ψ, tlist, saveat_tlist, None, solver_options=solver_options)
    return sol.ys
363# ----
365# propagators
366# ----
def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
):
    """Generate the propagator U(t) for a static or time-dependent Hamiltonian.

    Args:
        H: a static ``Qarray`` Hamiltonian, or a function that takes a time
            argument and returns a Hamiltonian ``Qarray``.
        ts: a single time point or an array of time points. For a callable
            ``H`` this is the integration grid; a scalar ``t`` becomes
            ``[0, t]``.
        saveat_tlist: times at which to save the propagator (callable ``H``
            only). ``-1`` or ``[-1]`` saves only at the final time; ``None``
            saves at every point of the integration grid.
        solver_options: ``SolverOptions`` forwarded to ``sesolve`` (callable
            ``H`` only).

    Returns:
        - Static ``H``: ``(-1j * H * ts).expm()``. In this case
          ``saveat_tlist`` and ``solver_options`` are not used.
        - Callable ``H`` with a scalar ``ts == 0``: the identity.
        - Callable ``H`` otherwise: a Qarray of propagators at the saved
          times. For a scalar ``ts`` with ``saveat_tlist=None`` this includes
          both ``t = 0`` and ``t = ts``. Pass ``saveat_tlist=[-1]`` to keep
          only the final propagator.
    """
    ts_is_scalar = robust_isscalar(ts)

    if isinstance(H, Qarray):
        # Static Hamiltonian: closed form U = exp(-i H t); no ODE solve needed.
        return (-1j * H * ts).expm()

    if ts_is_scalar:
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    # Evolve every basis ket; the k-th evolved basis state is column k of U.
    basis_states = multi_mode_basis_set(H_first.space_dims)
    results = sesolve(
        H,
        basis_states,
        ts,
        saveat_tlist=saveat_tlist,
        solver_options=solver_options,
    )
    # Presumably results.data is (..., n_basis, N, 1): drop the ket column
    # axis, then transpose so the evolved basis states become columns.
    propagators_data = results.data.squeeze(-1).mT
    propagators = Qarray.create(propagators_data, dims=H_first.space_dims)

    return propagators