Coverage for jaxquantum/core/solvers.py: 86%
118 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""Solvers"""
3from diffrax import (
4 diffeqsolve,
5 ODETerm,
6 SaveAt,
7 PIDController,
8 TqdmProgressMeter,
9 NoProgressMeter,
10)
11from flax import struct
12from jax import Array
13from typing import Callable, Optional, Union
14import diffrax
15import jax.numpy as jnp
16import warnings
17import tqdm
18import logging
21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data
22from jaxquantum.core.conversions import jnp2jqt
23from jaxquantum.core.operators import identity_like, multi_mode_basis_set
24from jaxquantum.utils.utils import robust_isscalar
26# ----
@struct.dataclass
class SolverOptions:
    """Configuration consumed by the solvers in this module.

    All fields are declared with ``pytree_node=False``. Each declaration must
    be a bare ``struct.field(...)`` call: wrapping it in a tuple (for example
    with a trailing comma) turns the class attribute into an ordinary tuple
    default and the ``pytree_node`` setting is lost.

    The field defaults mirror the defaults of :meth:`create`.
    """

    # Show a progress bar while integrating.
    progress_meter: bool = struct.field(pytree_node=False, default=True)
    # Name of the solver class, looked up as an attribute of ``diffrax``.
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    # Upper bound on the number of integration steps.
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    # Relative tolerance handed to the step size controller.
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    # Absolute tolerance handed to the step size controller.
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ):
        """Build a ``SolverOptions`` instance.

        Args:
            progress_meter: whether to show a progress bar.
            solver: name of the diffrax solver class to use.
            max_steps: maximum number of integration steps.
            rtol: relative tolerance.
            atol: absolute tolerance.

        Returns:
            A new ``SolverOptions`` holding the given values.
        """
        return cls(
            progress_meter=progress_meter,
            solver=solver,
            max_steps=max_steps,
            rtol=rtol,
            atol=atol,
        )
class CustomProgressMeter(TqdmProgressMeter):
    """``TqdmProgressMeter`` subclass that supplies its own tqdm bar."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the tqdm bar: a 0-100 percent scale with a custom layout.

        Returns:
            A ``tqdm.tqdm`` instance configured with the custom format,
            colour and fill characters.
        """
        bar_settings = {
            "total": 100,
            "bar_format": "{desc}: {percentage:3.0f}% |{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
            "unit": "%",
            "colour": "MAGENTA",
            "ascii": "░▒█",
        }
        return tqdm.tqdm(**bar_settings)
def solve(f, ρ0, tlist, saveat_tlist, args, solver_options: Optional[
        SolverOptions] = None):
    """Integrate the ODE defined by ``f`` using diffrax.

    Args:
        f: function defining the ODE; it is wrapped in an ``ODETerm``.
        ρ0: initial state.
        tlist: time list. The first and last entries are used as the start
            and end times, and the spacing of the first two entries is the
            initial step size.
        saveat_tlist: array of times at which to save the state.
            Pass a one-element array equal to -1 to save only at the final
            time.
        args: additional arguments passed through to ``f``.
        solver_options: ``SolverOptions`` instance; when not given,
            ``SolverOptions.create()`` is used.

    Returns:
        The solution object returned by ``diffeqsolve``.
    """

    # f and ts
    ode_term = ODETerm(f)

    final_only = saveat_tlist.shape[0] == 1 and saveat_tlist == -1
    saveat = SaveAt(t1=True) if final_only else SaveAt(ts=saveat_tlist)

    # solver
    if not solver_options:
        solver_options = SolverOptions.create()

    # The solver class is looked up by name on the diffrax module.
    solver_cls = getattr(diffrax, solver_options.solver)
    solver = solver_cls()
    controller = PIDController(
        rtol=solver_options.rtol,
        atol=solver_options.atol,
    )

    # solve!
    with warnings.catch_warnings():
        # NOTE: suppresses complex dtype warning in diffrax
        warnings.filterwarnings(
            "ignore",
            message="Complex dtype support in Diffrax",
            category=UserWarning,
        )

        if solver_options.progress_meter:
            meter = CustomProgressMeter()
        else:
            meter = NoProgressMeter()

        t_start = tlist[0]
        t_end = tlist[-1]
        first_step = tlist[1] - tlist[0]

        sol = diffeqsolve(
            ode_term,
            solver,
            t0=t_start,
            t1=t_end,
            dt0=first_step,
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=controller,
            args=args,
            max_steps=solver_options.max_steps,
            progress_meter=meter,
        )

    return sol
def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum Master Equation solver.

    Args:
        H: time-independent Hamiltonian (a Qarray) or a function of time
            returning a Qarray.
        rho0: initial state. It is converted with ``to_dm()`` before solving;
            a warning suggesting ``sesolve`` is logged when it is not of
            operator type and there are no collapse operators.
        tlist: time list
        saveat_tlist: list of times at which to save the state.
            If -1 or [-1], save only at final time.
            If None, save at all times in tlist. Default: None.
        c_ops: Qarray of collapse operators; an empty one is used when None.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray built from the saved states, carrying the dims of the initial
        density matrix.
    """

    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if c_ops is None:
        c_ops = Qarray.from_list([])

    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    # From here on work with raw array data; dims are kept to rebuild the
    # Qarray at the end.
    initial_dm = rho0.to_dm()
    dims = initial_dm.dims
    initial_dm_data = initial_dm.data
    c_ops_data = c_ops.data

    if isinstance(H, Qarray):

        def Ht_data(t):
            # Static Hamiltonian: the time argument is ignored.
            return H.data

    else:

        def Ht_data(t):
            return H(t).data

    ys = _mesolve_data(
        Ht_data,
        initial_dm_data,
        tlist,
        saveat_tlist,
        c_ops_data,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)
def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Quantum Master Equation solver operating on raw array data.

    Args:
        H: function of time returning the Hamiltonian as an array.
        rho0: initial density matrix data. Whether the input really is a
            density matrix is checked by the caller (``mesolve``), not here.
        tlist: time list
        saveat_tlist: array of times at which to save the state, passed to
            ``solve`` unchanged. A one-element array equal to -1 saves only
            at the final time.
        c_ops: array of collapse operators, stacked along the leading axis.
            An empty array is used when None.
        solver_options: SolverOptions with solver options

    Returns:
        The ``ys`` attribute of the solution returned by ``solve``.
    """

    if c_ops is None:
        c_ops = jnp.array([])

    has_c_ops = len(c_ops) != 0

    # Evaluate one trial product to find the shape the state takes once it is
    # combined with the Hamiltonian (and a collapse operator), then resize
    # the initial state to that shape.
    if has_c_ops:
        test_data = c_ops[0] @ H(0.0) @ rho0
    else:
        test_data = H(0.0) @ rho0

    ρ0 = jnp.resize(rho0, test_data.shape)  # ensure correct shape

    def f(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        """Right-hand side of the master equation at time ``t``."""
        H_val = H(t)  # type: ignore

        # Commutator term: -i [H, rho]
        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        c_ops_val_dag = dag_data(c_ops_val)

        rho_dot_delta = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag
            - rho @ c_ops_val_dag @ c_ops_val
            - c_ops_val_dag @ c_ops_val @ rho
        )

        # Sum the contributions over the leading axis of c_ops_val.
        rho_dot_delta = jnp.sum(rho_dot_delta, axis=0)

        rho_dot += rho_dot_delta

        return rho_dot

    sol = solve(f, ρ0, tlist, saveat_tlist, c_ops,
                solver_options=solver_options)

    return sol.ys
def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger Equation solver.

    Args:
        H: time-independent Hamiltonian (a Qarray) or a function of time
            returning a Qarray.
        rho0: initial state. Must not be of operator type; it is converted
            with ``to_ket()`` before solving. For density matrices use
            ``mesolve``.
        tlist: time list
        saveat_tlist: list of times at which to save the state.
            If -1 or [-1], save only at final time.
            If None, save at all times in tlist. Default: None.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray built from the saved states, carrying the dims of the initial
        ket.

    Raises:
        ValueError: if ``rho0`` is of operator type.
    """

    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if rho0.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    # From here on work with raw array data; dims are kept to rebuild the
    # Qarray at the end.
    initial_ket = rho0.to_ket()
    dims = initial_ket.dims
    initial_ket_data = initial_ket.data

    if isinstance(H, Qarray):

        def Ht_data(t):
            # Static Hamiltonian: the time argument is ignored.
            return H.data

    else:

        def Ht_data(t):
            return H(t).data

    ys = _sesolve_data(
        Ht_data,
        initial_ket_data,
        tlist,
        saveat_tlist,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)
def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
):
    """Schrödinger Equation solver operating on raw array data.

    Args:
        H: function of time returning the Hamiltonian as an array.
        rho0: initial state data (despite the name, this is the state the
            Hamiltonian is applied to directly).
        tlist: time list
        saveat_tlist: array of times at which to save the state, passed to
            ``solve`` unchanged. A one-element array equal to -1 saves only
            at the final time.
        solver_options: SolverOptions with solver options

    Returns:
        The ``ys`` attribute of the solution returned by ``solve``.
    """

    def schrodinger_rhs(t: float, state: Array, _):
        """Right-hand side -i H(t) |state>; the third argument is unused."""
        hamiltonian = H(t)  # type: ignore
        return -1j * (hamiltonian @ state)

    # Evaluate the right-hand side once to find the shape the state takes
    # after the Hamiltonian is applied, then resize the initial state to it.
    trial = schrodinger_rhs(0, rho0, None)
    initial_state = jnp.resize(rho0, trial.shape)  # ensure correct shape

    solution = solve(
        schrodinger_rhs,
        initial_state,
        tlist,
        saveat_tlist,
        None,
        solver_options=solver_options,
    )
    return solution.ys
330# ----
332# propagators
333# ----
def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
):
    """Generate the propagator for a static or time dependent Hamiltonian.

    Args:
        H (Qarray or callable):
            A Qarray static Hamiltonian OR
            a function that takes a time argument and returns a Hamiltonian.
        ts (float or Array):
            A single time point or
            an Array of time points.
        saveat_tlist: list of times at which to save the state.
            If -1 or [-1], save only at final time.
            If None, save at all times in ts. Default: None.
            Only used when ``H`` is a callable.
        solver_options: SolverOptions forwarded to ``sesolve``.
            Only used when ``H`` is a callable.

    Returns:
        Qarray or List[Qarray]:
            The propagator for the Hamiltonian at time t.
            OR a list of propagators for the Hamiltonian at each time in t.

    """

    # A static Hamiltonian needs no ODE solve: use the matrix exponential.
    if isinstance(H, Qarray):
        return (-1j * H * ts).expm()

    if robust_isscalar(ts):
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        # A single end time is integrated over the interval [0, ts].
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    # Evolve the basis states built from the Hamiltonian's space dims.
    basis_states = multi_mode_basis_set(H_first.space_dims)
    results = sesolve(
        H,
        basis_states,
        ts,
        saveat_tlist=saveat_tlist,
        # Forward the caller's options; previously they were dropped and the
        # defaults were always used.
        solver_options=solver_options,
    )
    propagators_data = results.data.squeeze(-1)
    propagators = Qarray.create(propagators_data, dims=H_first.space_dims)

    return propagators