Welcome to jaxquantum

S. R. Jha, S. Chowdhury, G. Rolleri, M. Hays, J. A. Grover, W. D. Oliver

Docs: equs.github.io/jaxquantum

Community Discord: discord.gg/frWqbjvZ4s

jaxquantum leverages JAX to enable auto-differentiable, hardware-accelerated (CPU, GPU, TPU) simulation of quantum dynamical systems, with tooling for operator construction, unitary evolution, and master equation solving. As such, jaxquantum serves as a QuTiP drop-in replacement written entirely in JAX.

jaxquantum has also recently absorbed the bosonic and qcsys packages, making it a unified toolkit for quantum circuit design, simulation, and control.
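To illustrate what auto-differentiable simulation means in practice, here is a minimal sketch in plain JAX (it deliberately does not use the jaxquantum API). It evolves a driven two-level system and differentiates the final excited-state population with respect to the drive strength. The names `excited_population`, `rabi_freq`, and `d_pop` are hypothetical and chosen only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Pauli-X operator and the ground state |0> of a two-level system.
sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
ground = jnp.array([1.0, 0.0], dtype=jnp.complex64)

def excited_population(rabi_freq, t):
    """Population of |1> after evolving |0> under H = (rabi_freq / 2) * sigma_x for time t."""
    U = expm(-1j * 0.5 * rabi_freq * t * sigma_x)  # unitary propagator exp(-i H t)
    psi = U @ ground
    return jnp.abs(psi[1]) ** 2

# jax.grad differentiates with respect to the first argument (rabi_freq) by default.
d_pop = jax.jit(jax.grad(excited_population))

rabi_freq, t = 2.0, 0.3
pop = excited_population(rabi_freq, t)
grad = d_pop(rabi_freq, t)
print(pop, grad)
```

For this model the population is sin^2(rabi_freq * t / 2), so the gradient can be compared against the closed form (t / 2) * sin(rabi_freq * t). The same pattern of wrapping a simulation in a function and passing it to `jax.grad` is what makes gradient-based optimization of simulation parameters possible.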

Installation

Recommended: Because this project is evolving rapidly, install the latest version of jaxquantum from source:

pip install git+https://github.com/EQuS/jaxquantum.git

To contribute to the package, clone this repository, change into its root directory, and install the package in editable mode:

pip install -e ".[dev,docs]"

This also installs the dev and docs extras, which are useful when developing the package. Because the install is in editable mode, the package picks up new changes automatically after you pull them from the repository.

jaxquantum is also published on PyPI. Run the following command to install the package:

pip install jaxquantum

For more details, please visit the Getting Started > Installation section of our docs.
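After installing, a quick sanity check can confirm that the package is visible to Python and show which backend JAX will use. This is a small sketch using only the standard library and JAX; the helper name `installed_version` is hypothetical and not part of jaxquantum.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

import jax

def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

print("jaxquantum version:", installed_version("jaxquantum"))
print("JAX default backend:", jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
```

If the first line prints `None`, the package is not installed in the active environment.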

An Example

Here's an example of how to set up a simulation in jaxquantum.

from jax import jit
import jaxquantum as jqt 
import jax.numpy as jnp
import matplotlib.pyplot as plt

N = 100

omega_a = 2.0*jnp.pi*5.0
kappa = 2*jnp.pi*jnp.array([1,2]) # Batching to explore two different kappa values!
initial_state = jqt.displace(N, 0.1) @ jqt.basis(N,0)
initial_state_dm = initial_state.to_dm()
ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101)

a = jqt.destroy(N)
n = a.dag() @ a

c_ops = jqt.Qarray.from_list([jnp.sqrt(kappa)*a])

@jit
def Ht(t):
    H0 = omega_a*n
    return H0

solver_options = jqt.SolverOptions.create(progress_meter=True)
states = jqt.mesolve(Ht, initial_state_dm, ts, c_ops=c_ops, solver_options=solver_options) 
nt = jnp.real(jqt.overlap(n, states))
a_real = jnp.real(jqt.overlap(a, states))
a_imag = jnp.imag(jqt.overlap(a, states))

fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, a_real[:,0], label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 0
ax.plot(ts, a_real[:,1], "--", label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 1
ax.plot(ts, a_imag[:,0], label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 0
ax.plot(ts, a_imag[:,1], "--", label=r"$Im[\langle a(t)\rangle]$", color="red") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()

ax = axs[1]
ax.plot(ts, nt[:,0], label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 0
ax.plot(ts, nt[:,1], "--", label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()

[Figure: output of the above code.]
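The simulated curves can be sanity-checked against the closed-form solution for a damped harmonic oscillator. For H = omega_a * n with a single collapse operator sqrt(kappa) * a, the expectation values obey d<a>/dt = -(i * omega_a + kappa / 2) * <a> and d<n>/dt = -kappa * <n>. The sketch below evaluates these solutions in plain JAX and uses `jax.vmap` to batch over the two kappa values. It assumes that `jqt.displace(N, 0.1)` applied to the vacuum prepares a coherent state of amplitude 0.1 (the standard convention) and that truncation at N = 100 is negligible. The names `alpha0`, `analytic_expectations`, `a_ref`, and `n_ref` are hypothetical.

```python
import jax
import jax.numpy as jnp

omega_a = 2.0 * jnp.pi * 5.0
kappas = 2.0 * jnp.pi * jnp.array([1.0, 2.0])  # same two decay rates as the example
alpha0 = 0.1                                   # assumed initial coherent amplitude
ts = jnp.linspace(0.0, 4 * 2 * jnp.pi / omega_a, 101)

def analytic_expectations(kappa):
    """Closed-form <a(t)> and <n(t)> for one decay rate kappa."""
    a_t = alpha0 * jnp.exp(-(1j * omega_a + 0.5 * kappa) * ts)
    n_t = jnp.abs(alpha0) ** 2 * jnp.exp(-kappa * ts)
    return a_t, n_t

# out_axes=1 puts the batch (kappa) axis second, giving arrays of shape (len(ts), 2).
a_ref, n_ref = jax.vmap(analytic_expectations, out_axes=1)(kappas)
print(a_ref.shape, n_ref.shape)
```

With this layout, `jnp.real(a_ref)[:, 0]` and `n_ref[:, 0]` can be plotted next to `a_real[:, 0]` and `nt[:, 0]` from the simulation to compare the two.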

Acknowledgements & History

Core Devs: Shantanu R. Jha, Shoumik Chowdhury, Gabriele Rolleri

This package was initially a small part of bosonic. In early 2022, jaxquantum was extracted and made into its own package. It was briefly announced at the APS March Meeting 2023 and released to a select few academic groups shortly after. Since then, the package has been open-sourced and developed while conducting research in the Engineering Quantum Systems Group at MIT, with advice and support from Prof. William D. Oliver.