Coverage for jaxquantum/core/solvers.py: 87%

121 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Solvers""" 

2 

3from diffrax import ( 

4 diffeqsolve, 

5 ODETerm, 

6 SaveAt, 

7 PIDController, 

8 TqdmProgressMeter, 

9 NoProgressMeter, 

10) 

11from flax import struct 

12from jax import Array 

13from typing import Callable, Optional, Union 

14import diffrax 

15import jax.numpy as jnp 

16import warnings 

17import tqdm 

18import logging 

19 

20 

21from jaxquantum.core.qarray import Qarray, Qtypes, dag_data 

22from jaxquantum.core.conversions import jnp2jqt 

23from jaxquantum.core.operators import identity_like, multi_mode_basis_set 

24from jaxquantum.utils.utils import robust_isscalar 

25 

26# ---- 

27 

28 

@struct.dataclass
class SolverOptions:
    """Configuration consumed by ``solve`` when calling diffrax.

    Every field is declared with ``pytree_node=False`` so flax treats it as
    static metadata instead of a pytree leaf.

    The fields after ``progress_meter`` carry the same defaults as
    ``SolverOptions.create``. They used to be wrapped in one-element tuples by
    stray trailing commas, which made the tuple itself the default value and
    dropped the ``pytree_node=False`` marking.
    """

    # Whether ``solve`` shows a progress bar while integrating.
    progress_meter: bool = struct.field(pytree_node=False)
    # Name of the solver class, looked up as an attribute of ``diffrax``.
    solver: str = struct.field(pytree_node=False, default="Tsit5")
    # Forwarded to ``diffeqsolve`` as ``max_steps``.
    max_steps: int = struct.field(pytree_node=False, default=100_000)
    # Relative tolerance handed to the ``PIDController``.
    rtol: float = struct.field(pytree_node=False, default=1e-7)
    # Absolute tolerance handed to the ``PIDController``.
    atol: float = struct.field(pytree_node=False, default=1e-9)

    @classmethod
    def create(
        cls,
        progress_meter: bool = True,
        solver: str = "Tsit5",
        max_steps: int = 100_000,
        rtol: float = 1e-7,
        atol: float = 1e-9,
    ):
        """Build a ``SolverOptions`` instance, filling in defaults.

        Args:
            progress_meter: show a progress bar during integration.
            solver: name of the diffrax solver class to instantiate.
            max_steps: maximum number of solver steps.
            rtol: relative tolerance of the step size controller.
            atol: absolute tolerance of the step size controller.

        Returns:
            A new ``SolverOptions``.
        """
        return cls(
            progress_meter=progress_meter,
            solver=solver,
            max_steps=max_steps,
            rtol=rtol,
            atol=atol,
        )

47 

48 

class CustomProgressMeter(TqdmProgressMeter):
    """``TqdmProgressMeter`` variant with a custom bar layout and colour."""

    @staticmethod
    def _init_bar() -> tqdm.tqdm:
        """Create the tqdm bar: 100 total steps, counted in units of ``%``."""
        layout = "{desc}: {percentage:3.0f}% |{bar}| [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        bar_kwargs = dict(
            total=100,
            bar_format=layout,
            unit="%",
            colour="MAGENTA",
            ascii="░▒█",
        )
        return tqdm.tqdm(**bar_kwargs)

56 

57 

def solve(
    f,
    ρ0,
    tlist,
    saveat_tlist,
    args,
    solver_options: Optional[SolverOptions] = None,
):
    """Integrate an ODE with diffrax using the configured solver.

    Args:
        f: right-hand side function; it is wrapped in a diffrax ``ODETerm``.
        ρ0: initial state, passed to ``diffeqsolve`` as ``y0``.
        tlist: time grid. ``tlist[0]`` and ``tlist[-1]`` are the integration
            bounds and ``tlist[1] - tlist[0]`` is the initial step size.
        saveat_tlist: times at which the state is saved. If it holds a single
            entry equal to ``tlist[-1]``, only the final state is saved.
        args: extra argument forwarded to ``diffeqsolve`` as ``args``.
        solver_options: SolverOptions with solver options. Falls back to
            ``SolverOptions.create()`` when not given.

    Returns:
        solution object returned by ``diffeqsolve``
    """

    # f and ts
    term = ODETerm(f)
    final_state_only = saveat_tlist.shape[0] == 1 and saveat_tlist == tlist[-1]
    saveat = SaveAt(t1=True) if final_state_only else SaveAt(ts=saveat_tlist)

    # solver
    opts = solver_options or SolverOptions.create()
    integrator = getattr(diffrax, opts.solver)()
    controller = PIDController(rtol=opts.rtol, atol=opts.atol)

    # solve!
    with warnings.catch_warnings():
        # NOTE: suppresses complex dtype warning in diffrax
        warnings.filterwarnings(
            "ignore",
            message="Complex dtype support in Diffrax",
            category=UserWarning,
        )
        t_start = tlist[0]
        t_end = tlist[-1]
        first_step = tlist[1] - tlist[0]
        if opts.progress_meter:
            meter = CustomProgressMeter()
        else:
            meter = NoProgressMeter()
        sol = diffeqsolve(
            term,
            integrator,
            t0=t_start,
            t1=t_end,
            dt0=first_step,
            y0=ρ0,
            saveat=saveat,
            stepsize_controller=controller,
            args=args,
            max_steps=opts.max_steps,
            progress_meter=meter,
        )

    return sol

105 

106 

def mesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Quantum Master Equation solver operating on Qarrays.

    Unwraps the Qarray inputs to raw arrays, delegates the integration to
    ``_mesolve_data`` and wraps the result back into a Qarray.

    Args:
        H: time-independent Hamiltonian (Qarray) or a function of time
            returning the Hamiltonian as a Qarray.
        rho0: initial state. It is converted with ``to_dm()`` before solving;
            for statevector evolution consider ``sesolve``.
        tlist: time list
        saveat_tlist: times at which to save the state. If None, the state is
            saved at all times in tlist. Default: None.
        c_ops: Qarray holding the collapse operators. Default: empty.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray built from the saved states, carrying the dims of the initial
        density matrix.
    """

    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if c_ops is None:
        c_ops = Qarray.from_list([])

    no_collapse_ops = len(c_ops) == 0
    if no_collapse_ops and rho0.qtype != Qtypes.oper:
        logging.warning(
            "Consider using `jqt.sesolve()` instead, as `c_ops` is an empty list and the initial state is not a density matrix."
        )

    density_matrix = rho0.to_dm()
    dims = density_matrix.dims
    rho0_data = density_matrix.data
    c_ops_data = c_ops.data

    # Expose the Hamiltonian as a function of time that returns raw data.
    if isinstance(H, Qarray):

        def hamiltonian_data(t):
            return H.data

    else:

        def hamiltonian_data(t):
            return H(t).data if H is not None else None

    ys = _mesolve_data(
        hamiltonian_data,
        rho0_data,
        tlist,
        saveat_tlist,
        c_ops_data,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)

157 

158 

def _mesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    c_ops: Optional[Qarray] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Array:
    """Quantum Master Equation solver operating on raw arrays.

    Args:
        H: function of time returning the Hamiltonian as an Array.
        rho0: initial density matrix data.
        tlist: time list
        saveat_tlist: times at which the state is saved.
        c_ops: collapse operator data; ``mesolve`` passes the ``.data`` of its
            collapse-operator Qarray here. The contributions of the operators
            are summed over axis 0. Default: empty array.
        solver_options: SolverOptions with solver options

    Returns:
        ``ys`` of the solution returned by ``solve``.
    """

    if c_ops is None:
        c_ops = jnp.array([])

    # The "consider sesolve" warning for an empty `c_ops` is issued by
    # `mesolve`, so it is not repeated here.

    state0 = rho0 + 0.0j  # promote to a complex dtype

    # Evaluate one representative product and resize the initial state to its
    # shape (ensure correct shape).
    # NOTE(review): presumably this carries batch dimensions of H / c_ops over
    # to the state — confirm.
    if len(c_ops) == 0:
        probe = H(0.0) @ state0
    else:
        probe = c_ops[0] @ H(0.0) @ state0
    state0 = jnp.resize(state0, probe.shape)

    if len(c_ops) != 0:
        leading_dims = c_ops.shape[:-2]
        c_ops = c_ops.reshape(*leading_dims, c_ops.shape[-2], c_ops.shape[-1])

    def lindblad_rhs(
        t: float,
        rho: Array,
        c_ops_val: Array,
    ):
        """Time derivative of ``rho``: commutator term plus dissipators."""
        H_t = H(t) + 0.0j  # type: ignore

        commutator_term = -1j * (H_t @ rho - rho @ H_t)
        if len(c_ops_val) == 0:
            return commutator_term

        c_ops_val_dag = dag_data(c_ops_val)

        # One dissipator per collapse operator, stacked along axis 0.
        dissipators = 0.5 * (
            2 * c_ops_val @ rho @ c_ops_val_dag
            - rho @ c_ops_val_dag @ c_ops_val
            - c_ops_val_dag @ c_ops_val @ rho
        )

        return commutator_term + jnp.sum(dissipators, axis=0)

    sol = solve(
        lindblad_rhs,
        state0,
        tlist,
        saveat_tlist,
        c_ops,
        solver_options=solver_options,
    )

    return sol.ys

232 

233 

def sesolve(
    H: Union[Qarray, Callable[[float], Qarray]],
    rho0: Qarray,
    tlist: Array,
    saveat_tlist: Optional[Array] = None,
    solver_options: Optional[SolverOptions] = None,
) -> Qarray:
    """Schrödinger Equation solver operating on Qarrays.

    Unwraps the Qarray inputs to raw arrays, delegates the integration to
    ``_sesolve_data`` and wraps the result back into a Qarray.

    Args:
        H: time-independent Hamiltonian (Qarray) or a function of time
            returning the Hamiltonian as a Qarray.
        rho0: initial state. It must not be an operator (density matrix); it
            is converted with ``to_ket()`` before solving. For density
            matrices use ``mesolve``.
        tlist: time list
        saveat_tlist: times at which to save the state. If None, the state is
            saved at all times in tlist. Default: None.
        solver_options: SolverOptions with solver options

    Returns:
        Qarray built from the saved states, carrying the dims of the initial
        ket.

    Raises:
        ValueError: if ``rho0`` has qtype ``Qtypes.oper``.
    """

    if saveat_tlist is None:
        saveat_tlist = tlist
    saveat_tlist = jnp.atleast_1d(saveat_tlist)

    if rho0.qtype == Qtypes.oper:
        raise ValueError(
            "Please use `jqt.mesolve` for initial state inputs in density matrix form."
        )

    ket = rho0.to_ket()
    dims = ket.dims
    ket_data = ket.data

    # Expose the Hamiltonian as a function of time that returns raw data.
    if isinstance(H, Qarray):

        def hamiltonian_data(t):
            return H.data

    else:

        def hamiltonian_data(t):
            return H(t).data if H is not None else None

    ys = _sesolve_data(
        hamiltonian_data,
        ket_data,
        tlist,
        saveat_tlist,
        solver_options=solver_options,
    )

    return jnp2jqt(ys, dims=dims)

278 

279 

def _sesolve_data(
    H: Callable[[float], Array],
    rho0: Array,
    tlist: Array,
    saveat_tlist: Array,
    solver_options: Optional[SolverOptions] = None,
):
    """Schrödinger Equation solver operating on raw arrays.

    Args:
        H: function of time returning the Hamiltonian as an Array.
        rho0: initial state data; ``sesolve`` passes ket data here.
        tlist: time list
        saveat_tlist: times at which the state is saved.
        solver_options: SolverOptions with solver options

    Returns:
        ``ys`` of the solution returned by ``solve``.
    """

    state0 = rho0 + 0.0j  # promote to a complex dtype

    def schrodinger_rhs(t: float, psi_t: Array, _):
        """Time derivative of the state: ``-i H(t) ψ``."""
        H_t = H(t) + 0.0j  # type: ignore
        return -1j * (H_t @ psi_t)

    # Evaluate the right-hand side once and resize the initial state to the
    # shape of the result (ensure correct shape).
    probe = schrodinger_rhs(0, state0, None)
    state0 = jnp.resize(state0, probe.shape)

    sol = solve(
        schrodinger_rhs,
        state0,
        tlist,
        saveat_tlist,
        None,
        solver_options=solver_options,
    )
    return sol.ys

315 

316# ---- 

317 

318# propagators 

319# ---- 

320 

def propagator(
    H: Union[Qarray, Callable[[float], Qarray]],
    ts: Union[float, Array],
    solver_options=None
):
    """Generate the propagator for a static or time dependent Hamiltonian.

    Args:
        H (Qarray or callable):
            A Qarray static Hamiltonian OR
            a function that takes a time argument and returns a Hamiltonian.
        ts (float or Array):
            A single time point or
            an Array of time points.
        solver_options (SolverOptions, optional):
            Options forwarded to ``sesolve`` when ``H`` is a function. Not
            used for a static Qarray Hamiltonian, which is exponentiated
            directly.

    Returns:
        Qarray or List[Qarray]:
            The propagator for the Hamiltonian at time t.
            OR a list of propagators for the Hamiltonian at each time in t.

    """

    if isinstance(H, Qarray):
        # Static Hamiltonian: closed form, no ODE solve needed.
        return (-1j * H * ts).expm()

    if robust_isscalar(ts):
        H_first = H(0.0)
        if ts == 0:
            return identity_like(H_first)
        # The solver needs a time grid, so integrate from 0 to ts.
        # NOTE(review): sesolve saves at every time in `ts` by default, so the
        # result presumably has a leading time axis of length 2 (t=0 and
        # t=ts) rather than a single propagator — confirm this is intended.
        ts = jnp.array([0.0, ts])
    else:
        H_first = H(ts[0])

    # Evolve the basis states returned by multi_mode_basis_set together.
    basis_states = multi_mode_basis_set(H_first.space_dims)
    # Forward the caller's solver options; they were previously accepted but
    # never passed on.
    results = sesolve(H, basis_states, ts, solver_options=solver_options)
    # Drop the trailing singleton axis.
    # NOTE(review): assumes the remaining axes are (time, basis index,
    # component) — verify the row/column convention of the result.
    propagators_data = results.data.squeeze(-1)
    propagators = Qarray.create(propagators_data, dims=H_first.space_dims)

    return propagators