Coverage for jaxquantum/core/solvers.py: 100%
95 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Solvers"""
3from diffrax import (
4 diffeqsolve,
5 ODETerm,
6 SaveAt,
7 PIDController,
8 TqdmProgressMeter,
9 NoProgressMeter,
10)
11from flax import struct
12from jax import Array
13from typing import Callable, Optional, Union
14import diffrax
15import jax.numpy as jnp
16import warnings
17import tqdm
18import logging
21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data
22from jaxquantum.core.conversions import jnp2jqt
24# ----
@struct.dataclass
class SolverOptions:
    """Configuration for the diffrax-based solvers in this module.

    Every field is declared with ``pytree_node=False`` so it is carried as
    static metadata rather than as a traced leaf when an instance passes
    through JAX transformations.

    Attributes:
        progress_meter: whether ``solve`` displays a tqdm progress bar.
        solver: name of a diffrax solver class, looked up with
            ``getattr(diffrax, solver)`` and instantiated with no arguments.
        max_steps: maximum number of integration steps passed to
            ``diffeqsolve``.
    """

    # Each field is a bare ``struct.field(...)`` (not wrapped in a tuple) so the
    # ``pytree_node=False`` metadata actually applies; the defaults mirror
    # ``create()`` so direct construction behaves the same way.
    progress_meter: bool = struct.field(pytree_node=False, default=True)
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    max_steps: int = struct.field(pytree_node=False, default=100_000)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
    ):
        """Build a ``SolverOptions``.

        Args:
            progress_meter: show a tqdm progress bar during the solve.
            solver: name of the diffrax solver class to use.
            max_steps: maximum number of solver steps.

        Returns:
            A new ``SolverOptions`` instance.
        """
        return cls(progress_meter, solver, max_steps)
class CustomProgressMeter(TqdmProgressMeter):
    """Tqdm-based diffrax progress meter with a custom bar layout and colour."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the tqdm bar used by the meter (100 units, measured in percent)."""
        layout = "{desc}: {percentage:3.0f}% |{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        options = dict(
            total=100,
            bar_format=layout,
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
        )
        return tqdm.tqdm(**options)
def solve(
    f,
    ρ0,
    tlist,
    args,
    solver_options: Optional[SolverOptions] = None,
    *,
    rtol: float = 1e-6,
    atol: float = 1e-6,
):
    """Integrate ``dy/dt = f(t, y, args)`` over ``tlist`` with diffrax.

    Args:
        f: right-hand side with signature ``f(t, y, args)``; it is wrapped in
            a ``diffrax.ODETerm``.
        ρ0: initial value of the state at ``tlist[0]``.
        tlist: times at which the solution is saved. ``tlist[0]`` and
            ``tlist[-1]`` are the integration bounds and
            ``tlist[1] - tlist[0]`` is the initial step size, so at least two
            time points are required.
        args: extra argument forwarded unchanged to ``f``.
        solver_options: ``SolverOptions`` selecting the diffrax solver, the
            step budget and whether to show a progress bar. Defaults to
            ``SolverOptions.create()``.
        rtol: relative tolerance of the adaptive PID step-size controller.
        atol: absolute tolerance of the adaptive PID step-size controller.

    Returns:
        The solution object returned by ``diffrax.diffeqsolve``; its ``ys``
        attribute holds the states saved at the times in ``tlist``.
    """
    # Right-hand side and the times at which the solution is recorded.
    term = ODETerm(f)
    saveat = SaveAt(ts=tlist)

    # Solver selection and adaptive step-size control.
    if solver_options is None:
        solver_options = SolverOptions.create()
    solver = getattr(diffrax, solver_options.solver)()
    stepsize_controller = PIDController(rtol=rtol, atol=atol)

    with warnings.catch_warnings():
        # NOTE: suppresses complex dtype warning in diffrax
        warnings.simplefilter("ignore", UserWarning)
        progress_meter = (
            CustomProgressMeter()
            if solver_options.progress_meter
            else NoProgressMeter()
        )
        sol = diffeqsolve(
            term,
            solver,
            t0=tlist[0],
            t1=tlist[-1],
            dt0=tlist[1] - tlist[0],
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            args=args,
            max_steps=solver_options.max_steps,
            progress_meter=progress_meter,
        )

    return sol
def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum master (Lindblad) equation solver.

    Args:
        H: time-independent Hamiltonian as a ``Qarray``, or a function
            ``t -> Qarray`` giving a time-dependent Hamiltonian.
        rho0: initial state. It is converted with ``to_dm()``. When there
            are no collapse operators and ``rho0`` is not an operator, a
            warning suggests using ``sesolve`` instead.
        tlist: times at which the state is saved.
        c_ops: ``Qarray`` holding the collapse operators; ``None`` means no
            collapse operators.
        solver_options: ``SolverOptions`` forwarded to the underlying solver.

    Returns:
        ``Qarray`` holding the density matrix at each time in ``tlist``.

    Raises:
        TypeError: if ``H`` is neither a ``Qarray`` nor callable.
    """
    c_ops = c_ops if c_ops is not None else Qarray.from_list([])

    if len(c_ops) == 0 and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    ρ0 = rho0.to_dm()
    dims = ρ0.dims

    # Reduce H to a function returning raw array data. A non-callable H used to
    # fail later with an obscure ``None @ array`` error; fail clearly here.
    if isinstance(H, Qarray):
        Ht_data = lambda t: H.data
    elif callable(H):
        Ht_data = lambda t: H(t).data
    else:
        raise TypeError(
            f"H must be a Qarray or a callable returning a Qarray, got {type(H).__name__}."
        )

    ys = _mesolve_data(
        Ht_data, ρ0.data, tlist, c_ops.data, solver_options=solver_options
    )

    return jnp2jqt(ys, dims=dims)
def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    c_ops: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Integrate the Lindblad master equation on raw arrays.

    Solves ``dρ/dt = -i[H(t), ρ] + Σ_k (L_k ρ L_k† - ½{L_k† L_k, ρ})``.

    Args:
        H: function mapping a time to the Hamiltonian array.
        rho0: initial density-matrix data.
        tlist: times at which the state is saved.
        c_ops: collapse operators stacked along axis 0 (the dissipator is
            summed over that axis). ``None`` or an empty array gives purely
            unitary evolution.
        solver_options: ``SolverOptions`` forwarded to ``solve``.

    Returns:
        Array with the density matrix at each time in ``tlist``.
    """
    c_ops = c_ops if c_ops is not None else jnp.array([])

    # Promote to complex so the state dtype matches the complex right-hand side.
    ρ0 = rho0 + 0.0j

    # Probe the right-hand-side shape at t=0 and resize ρ0 to it, so the state
    # matches the output of H (and c_ops) when they carry extra leading dims.
    if len(c_ops) == 0:
        test_data = H(0.0) @ ρ0
    else:
        test_data = c_ops[0] @ H(0.0) @ ρ0
    ρ0 = jnp.resize(ρ0, test_data.shape)

    def f(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        H_val = H(t) + 0.0j  # type: ignore

        # Unitary part: -i [H, ρ]
        rho_dot = -1j * (H_val @ rho - rho @ H_val)

        if len(c_ops_val) == 0:
            return rho_dot

        # Dissipator, with L†L computed once and reused for both anticommutator terms.
        c_ops_val_dag = dag_data(c_ops_val)
        LdL = c_ops_val_dag @ c_ops_val
        rho_dot_delta = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag - rho @ LdL - LdL @ rho
        )

        # Sum the contributions of the individual collapse operators (axis 0).
        return rho_dot + jnp.sum(rho_dot_delta, axis=0)

    sol = solve(f, ρ0, tlist, c_ops, solver_options=solver_options)

    return sol.ys
def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger equation solver.

    Args:
        H: time-independent Hamiltonian as a ``Qarray``, or a function
            ``t -> Qarray`` giving a time-dependent Hamiltonian.
        rho0: initial state vector. It is passed through ``to_ket()``.
            Density matrices (operators) are rejected; use ``mesolve`` for
            those.
        tlist: times at which the state is saved.
        solver_options: ``SolverOptions`` forwarded to the underlying solver.

    Returns:
        ``Qarray`` holding the state vector at each time in ``tlist``.

    Raises:
        ValueError: if ``rho0`` is an operator (density matrix).
        TypeError: if ``H`` is neither a ``Qarray`` nor callable.
    """
    if rho0.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ψ = rho0.to_ket()
    dims = ψ.dims

    # Reduce H to a function returning raw array data. A non-callable H used to
    # fail later with an obscure ``None @ array`` error; fail clearly here.
    if isinstance(H, Qarray):
        Ht_data = lambda t: H.data
    elif callable(H):
        Ht_data = lambda t: H(t).data
    else:
        raise TypeError(
            f"H must be a Qarray or a callable returning a Qarray, got {type(H).__name__}."
        )

    ys = _sesolve_data(Ht_data, ψ.data, tlist, solver_options=solver_options)

    return jnp2jqt(ys, dims=dims)
def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Integrate the Schrödinger equation ``dψ/dt = -i H(t) ψ`` on raw arrays.

    Args:
        H: function mapping a time to the Hamiltonian array.
        rho0: initial state-vector data. Despite the parameter name, this is
            a ket, not a density matrix.
        tlist: times at which the state is saved.
        solver_options: ``SolverOptions`` forwarded to ``solve``.

    Returns:
        Array with the state vector at each time in ``tlist``.
    """
    # Promote to complex so the state dtype matches the complex right-hand side.
    ψ = rho0 + 0.0j

    def f(t: float, ψₜ: Array, _):
        H_val = H(t) + 0.0j  # type: ignore
        return -1j * (H_val @ ψₜ)

    # Probe the right-hand-side shape at t=0 (same convention as _mesolve_data)
    # and resize ψ to it, so the state matches H @ ψ when H has extra leading dims.
    ψ_test = f(0.0, ψ, None)
    ψ = jnp.resize(ψ, ψ_test.shape)

    sol = solve(f, ψ, tlist, None, solver_options=solver_options)
    return sol.ys
289 # ----
291 # propagators
292 # ----
294 # def propagator(
295 # H: Union[Qarray, Callable[[float], Qarray]],
296 # t: Union[float, Array],
297 # solver_options=None
298 # ):
299 # """ Generate the propagator for a time dependent Hamiltonian.
301 # Args:
302 # H (Qarray or callable):
303 # A Qarray static Hamiltonian OR
304 # a function that takes a time argument and returns a Hamiltonian.
305 # ts (float or Array):
306 # A single time point or
307 # an Array of time points.
309 # Returns:
310 # Qarray or List[Qarray]:
311 # The propagator for the Hamiltonian at time t.
312 # OR a list of propagators for the Hamiltonian at each time in t.
314 # """
316 # t_is_scalar = robust_isscalar(t)
318 # if isinstance(H, Qarray):
319 # dims = H.dims
320 # if t_is_scalar:
321 # if t == 0:
322 # return identity_like(H)
324 # return jnp2jqt(propagator_0_data(H.data,t), dims=dims)
325 # else:
326 # f = lambda t: propagator_0_data(H.data,t)
327 # return jnp2jqt(vmap(f)(t), dims)
328 # else:
329 # dims = H(0.0).dims
330 # H_data = lambda t: H(t).data
331 # if t_is_scalar:
332 # if t == 0:
333 # return identity_like(H(0.0))
335 # ts = jnp.linspace(0,t,2)
336 # return jnp2jqt(
337 # propagator_t_data(H_data, ts, solver_options=solver_options)[1],
338 # dims=dims
339 # )
340 # else:
341 # ts = t
342 # U_props = propagator_t_data(H_data, ts, solver_options=solver_options)
343 # return jnp2jqt(U_props, dims)
345 # def propagator_0_data(
346 # H0: Array,
347 # t: float
348 # ):
349 # """ Generate the propagator for a time independent Hamiltonian.
351 # Args:
352 # H0 (Qarray): The Hamiltonian.
354 # Returns:
355 # Qarray: The propagator for the time independent Hamiltonian.
356 # """
357 # return jsp.linalg.expm(-1j * H0 * t)
359 # def propagator_t_data(
360 # Ht: Callable[[float], Array],
361 # ts: Array,
362 # solver_options=None
363 # ):
364 """ Generate the propagator for a time dependent Hamiltonian.
366 Args:
367 ts (float): The final time of the propagator.
368 Warning: Do not send in t. In this case, just do exp(-1j*Ht(0.0)).
369 Ht (callable): A function that takes a time argument and returns a Hamiltonian.
370 solver_options (dict): Options to pass to the solver.
372 Returns:
373 Qarray: The propagator for the time dependent Hamiltonian for the time range [0, t_final].
374 """
375 # N = Ht(0).shape[0]
376 # basis_states = jnp.eye(N)
378 # def propogate_state(initial_state):
379 # return sesolve_data(initial_state, ts, Ht=Ht, solver_options=solver_options)
381 # U_prop = vmap(propogate_state)(basis_states)
382 # U_prop = U_prop.transpose(1,0,2) # move time axis to the front
383 # return U_prop