Coverage for jaxquantum/core/solvers.py: 87%
121 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Solvers"""
3from diffrax import (
4 diffeqsolve,
5 ODETerm,
6 SaveAt,
7 PIDController,
8 TqdmProgressMeter,
9 NoProgressMeter,
10)
11from flax import struct
12from jax import Array
13from typing import Callable, Optional, Union
14import diffrax
15import jax.numpy as jnp
16import warnings
17import tqdm
18import logging
21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data
22from jaxquantum.core.conversions import jnp2jqt
23from jaxquantum.core.operators import identity_like, multi_mode_basis_set
24from jaxquantum.utils.utils import robust_isscalar
26# ----
@struct.dataclass
class SolverOptions:
    """Configuration for the diffrax-based solvers in this module.

    Every field is declared with ``pytree_node=False`` so it is stored as static
    pytree metadata instead of being treated as a (traceable) leaf.

    Attributes:
        progress_meter: whether to show a tqdm progress bar while solving.
        solver: name of a diffrax solver class, looked up with
            ``getattr(diffrax, solver)``.
        max_steps: maximum number of steps passed to ``diffeqsolve``.
        rtol: relative tolerance of the adaptive step-size controller.
        atol: absolute tolerance of the adaptive step-size controller.
    """

    # Each field is a bare ``struct.field`` call. Wrapping it in ``(..., )``
    # would make the default a tuple containing a Field object and silently
    # drop the ``pytree_node=False`` metadata. Defaults mirror ``create``.
    progress_meter: bool = struct.field(pytree_node=False, default=True)
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ) -> "SolverOptions":
        """Build a ``SolverOptions`` instance.

        Args:
            progress_meter: whether to show a progress bar. Default: True.
            solver: diffrax solver class name. Default: "Tsit5".
            max_steps: maximum number of solver steps. Default: 100_000.
            rtol: relative tolerance. Default: 1e-7.
            atol: absolute tolerance. Default: 1e-9.

        Returns:
            The configured ``SolverOptions``.
        """
        return cls(progress_meter, solver, max_steps, rtol, atol)
class CustomProgressMeter(TqdmProgressMeter):
    """``TqdmProgressMeter`` variant with a customised bar appearance.

    Only ``_init_bar`` is overridden: the bar counts from 0 to 100 in units of
    ``%``, is coloured magenta, and uses shaded block characters.
    """

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create and return the tqdm bar used to display solver progress."""
        fmt = (
            "{desc}: {percentage:3.0f}% |{bar}| "
            "[{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        )
        bar_kwargs = dict(
            total=100,
            bar_format=fmt,
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
        )
        return tqdm.tqdm(**bar_kwargs)
def solve(
    f,
    ρ0,
    tlist,
    saveat_tlist,
    args,
    solver_options: Optional[SolverOptions] = None,
):
    """Integrate ``dy/dt = f(t, y, args)`` with a diffrax solver.

    Args:
        f: vector field with signature ``f(t, y, args)``; it is wrapped in
            ``diffrax.ODETerm``.
        ρ0: initial state at ``tlist[0]``.
        tlist: integration times. The solve runs from ``tlist[0]`` to
            ``tlist[-1]`` with initial step ``tlist[1] - tlist[0]``, so at
            least two entries are required.
        saveat_tlist: array (at least 1-D) of times at which to save the state.
        args: extra argument forwarded to ``f``.
        solver_options: SolverOptions selecting the solver, tolerances, step
            limit and progress meter. ``SolverOptions.create()`` defaults are
            used when None.

    Returns:
        The solution object returned by ``diffrax.diffeqsolve``; the saved
        states are in ``sol.ys``.
    """
    term = ODETerm(f)

    # A single save time equal to the final time is requested via ``t1=True``
    # instead of an explicit ``ts`` array.
    if saveat_tlist.shape[0] == 1 and saveat_tlist == tlist[-1]:
        saveat = SaveAt(t1=True)
    else:
        saveat = SaveAt(ts=saveat_tlist)

    if solver_options is None:
        solver_options = SolverOptions.create()

    # Instantiate the diffrax solver class named in the options (e.g. "Tsit5").
    solver = getattr(diffrax, solver_options.solver)()
    stepsize_controller = PIDController(
        rtol=solver_options.rtol, atol=solver_options.atol
    )
    progress_meter = (
        CustomProgressMeter()
        if solver_options.progress_meter
        else NoProgressMeter()
    )

    with warnings.catch_warnings():
        # NOTE: suppresses diffrax's complex-dtype UserWarning; the solvers in
        # this module integrate complex-valued states.
        warnings.filterwarnings(
            "ignore",
            message="Complex dtype support in Diffrax",
            category=UserWarning,
        )
        sol = diffeqsolve(
            term,
            solver,
            t0=tlist[0],
            t1=tlist[-1],
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            args=args,
            max_steps=solver_options.max_steps,
            progress_meter=progress_meter,
        )

    return sol
def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum master equation (Lindblad) solver.

    Args:
        H: time-independent Hamiltonian Qarray, or a function of time that
            returns the Hamiltonian Qarray.
        rho0: initial state. It is converted with ``to_dm()`` before solving.
            For pure-state evolution without collapse operators, prefer
            ``sesolve``.
        tlist: time list
        saveat_tlist: times at which to save the state. If None, save at all
            times in tlist. Default: None.
        c_ops: Qarray of collapse operators. If None, an empty list is used
            and the evolution is purely unitary.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray of evolved density matrices at the saved times.
    """
    if saveat_tlist is None:
        saveat_tlist = tlist
    save_times = jnp.atleast_1d(saveat_tlist)

    if c_ops is None:
        c_ops = Qarray.from_list([])

    # Without collapse operators, a non-operator initial state is better
    # handled by the cheaper Schrödinger solver.
    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    rho0_dm = rho0.to_dm()
    dims = rho0_dm.dims
    rho0_data = rho0_dm.data
    c_ops_data = c_ops.data

    # Expose the Hamiltonian to the array-level solver as "time -> raw data".
    if isinstance(H, Qarray):
        def H_data(t):
            return H.data
    else:
        def H_data(t):
            return H(t).data if H is not None else None

    states = _mesolve_data(
        H_data,
        rho0_data,
        tlist,
        save_times,
        c_ops_data,
        solver_options=solver_options,
    )
    return jnp2jqt(states, dims=dims)
def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Integrate the Lindblad master equation on raw arrays.

    Args:
        H: function of time returning the Hamiltonian data array.
        rho0: initial density-matrix data.
        tlist: integration times (passed to ``solve``).
        saveat_tlist: array of times at which to save the state.
        c_ops: collapse-operator data stacked along axis 0; the dissipator is
            summed over that axis. None or an empty array means unitary
            evolution.
        solver_options: SolverOptions with solver options

    Returns:
        The states saved by ``solve`` (``sol.ys``).
    """
    c_ops = c_ops if c_ops is not None else jnp.array([])

    ρ0 = rho0 + 0.0j

    # Apply the generator once at t=0 to learn the batch shape that results
    # from matmul-broadcasting H, rho0 and (if present) the collapse operators.
    if len(c_ops) == 0:
        test_data = H(0.0) @ ρ0
    else:
        test_data = c_ops[0] @ H(0.0) @ ρ0

    # broadcast_to (not jnp.resize): resize tiles the flattened data
    # cyclically, which misplaces entries whenever rho0 must be expanded
    # along a non-leading batch axis.
    ρ0 = jnp.broadcast_to(ρ0, test_data.shape)

    def f(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        """Lindblad RHS: -i[H, rho] + sum_k (L rho L^† - 1/2 {L^†L, rho})."""
        H_val = H(t)  # type: ignore
        H_val = H_val + 0.0j

        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        c_ops_val_dag = dag_data(c_ops_val)

        rho_dot_delta = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag
            - rho @ c_ops_val_dag @ c_ops_val
            - c_ops_val_dag @ c_ops_val @ rho
        )

        # Sum the dissipator contributions over the collapse-operator axis.
        rho_dot_delta = jnp.sum(rho_dot_delta, axis=0)

        rho_dot += rho_dot_delta

        return rho_dot

    sol = solve(f, ρ0, tlist, saveat_tlist, c_ops,
                solver_options=solver_options)

    return sol.ys
def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger Equation solver.

    Args:
        H: time dependent Hamiltonian function or time-independent Qarray.
        rho0: initial state vector. It is converted with ``to_ket()``. Density
            matrices (operators) are rejected; use ``jqt.mesolve`` for them.
        tlist: time list
        saveat_tlist: list of times at which to save the state. If None, save
            at all times in tlist. Default: None.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray of evolved kets at the saved times.

    Raises:
        ValueError: if ``rho0`` is an operator (density matrix).
    """
    saveat_tlist = saveat_tlist if saveat_tlist is not None else tlist

    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    ψ = rho0

    if ψ.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ψ = ψ.to_ket()
    dims = ψ.dims
    ψ = ψ.data

    # Expose the Hamiltonian to the array-level solver as "time -> raw data".
    if isinstance(H, Qarray):
        Ht_data = lambda t: H.data
    else:
        Ht_data = lambda t: H(t).data if H is not None else None

    ys = _sesolve_data(Ht_data, ψ, tlist, saveat_tlist,
                       solver_options=solver_options)

    return jnp2jqt(ys, dims=dims)
def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
):
    """Integrate the Schrödinger equation on raw arrays.

    Args:
        H: function of time returning the Hamiltonian data array.
        rho0: initial state-vector (ket) data, as passed by ``sesolve``. The
            parameter name is kept for backward compatibility.
        tlist: integration times (passed to ``solve``).
        saveat_tlist: array of times at which to save the state.
        solver_options: SolverOptions with solver options

    Returns:
        The states saved by ``solve`` (``sol.ys``).
    """
    ψ = rho0 + 0.0j

    def f(t: float, ψₜ: Array, _):
        """Right-hand side dψ/dt = -i H(t) ψ."""
        H_val = H(t)  # type: ignore
        H_val = H_val + 0.0j
        return -1j * (H_val @ ψₜ)

    # Evaluate the RHS once to obtain the broadcast batch shape, then
    # broadcast the initial ket to it. jnp.resize would tile the flattened
    # data cyclically and misplace entries for non-leading batch axes.
    ψ_test = f(0, ψ, None)
    ψ = jnp.broadcast_to(ψ, ψ_test.shape)

    sol = solve(f, ψ, tlist, saveat_tlist, None, solver_options=solver_options)
    return sol.ys
316# ----
318# propagators
319# ----
def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    solver_options=None,
):
    """Generate the propagator for a (possibly time-dependent) Hamiltonian.

    Args:
        H (Qarray or callable):
            A Qarray static Hamiltonian OR
            a function that takes a time argument and returns a Hamiltonian.
        ts (float or Array):
            A single time point or
            an Array of time points.
        solver_options (SolverOptions, optional):
            Options forwarded to ``sesolve`` when ``H`` is callable. Ignored
            for a static Qarray Hamiltonian, which uses ``expm`` directly.

    Returns:
        Qarray or List[Qarray]:
            For a static ``H``: ``expm(-1j * H * ts)``.
            For a callable ``H``: identity at ``ts == 0``, otherwise the
            propagators obtained by evolving every basis state with
            ``sesolve``, one per saved time.
    """
    ts_is_scalar = robust_isscalar(ts)

    if isinstance(H, Qarray):
        return (-1j * H * ts).expm()

    if ts_is_scalar:
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        # NOTE(review): sesolve saves at every time in [0, ts], so a non-zero
        # scalar ts currently yields two propagators (t=0 and t=ts) rather
        # than a single operator. Confirm what callers expect.
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    # Evolve each basis state of the Hilbert space. Stacking the evolved
    # states (with the trailing ket axis squeezed) forms the propagator.
    # NOTE(review): entry i holds U|i>, i.e. column i of U. If
    # multi_mode_basis_set stacks basis kets along the leading axis, the
    # stacked result is U^T rather than U. Verify with a non-symmetric H.
    basis_states = multi_mode_basis_set(H_first.space_dims)
    results = sesolve(H, basis_states, ts, solver_options=solver_options)
    propagators_data = results.data.squeeze(-1)
    return Qarray.create(propagators_data, dims=H_first.space_dims)