S. R. Jha, S. Chowdhury, G. Rolleri, M. Hays, J. A. Grover, W. D. Oliver
Docs: jaxquantum.org | Discord: discord.gg/frWqbjvZ4s
jaxquantum is a unified JAX-native toolkit for quantum hardware design, simulation, and control — auto-differentiable and accelerated on CPU, GPU, and TPU. It serves as a QuTiP drop-in replacement and absorbs the prior bosonic and qcsys projects.
Highlights
- Superconducting devices — ready-to-use Transmon, Fluxonium, and Resonator models with eigenspectrum, wavefunctions, and parameter sweeps. See the devices tutorial.
- Bosonic codes — Cat, GKP, and Binomial qubit encodings with logical gates and phase-space visualization. See the bosonic codes tutorial.
- Gate-based circuits — hierarchical circuits with unitary, Hamiltonian, and Kraus simulation modes; gradient-based gate optimization. See the circuits tutorial.
- Sparse backends —
SparseDIAandBCOOstorage for large Hilbert spaces with the same API as dense. See the sparse backends tutorial. - First-class JAX — use
jax.vmapfor parameter sweeps,jax.jitfor compiled simulation, andjax.gradfor differentiable physics out of the box.
Installation
pip install jaxquantum
For GPU (NVIDIA, CUDA13) or TPU, use the [gpu] or [tpu] extras. For the latest development version, install directly from source:
pip install git+https://github.com/EQuS/jaxquantum.git
For development (editable + dev/docs extras): pip install -e ".[dev,docs]". See the installation guide for full details, hardware checks, and troubleshooting.
Quick Start
from jax import jit
import jaxquantum as jqt
import jax.numpy as jnp
import matplotlib.pyplot as plt
N = 100
omega_a = 2.0*jnp.pi*5.0
kappa = 2*jnp.pi*jnp.array([1,2]) # Batching to explore two different kappa values!
initial_state = jqt.displace(N, 0.1) @ jqt.basis(N,0)
initial_state_dm = initial_state.to_dm()
ts = jnp.linspace(0, 4*2*jnp.pi/omega_a, 101)
a = jqt.destroy(N)
n = a.dag() @ a
c_ops = jqt.Qarray.from_list([jnp.sqrt(kappa)*a])
@jit
def Ht(t):
H0 = omega_a*n
return H0
solver_options = jqt.SolverOptions.create(progress_meter=True)
states = jqt.mesolve(Ht, initial_state_dm, ts, c_ops=c_ops, solver_options=solver_options)
nt = jnp.real(jqt.overlap(n, states))
a_real = jnp.real(jqt.overlap(a, states))
a_imag = jnp.imag(jqt.overlap(a, states))
fig, axs = plt.subplots(2,1, dpi=200, figsize=(6,5))
ax = axs[0]
ax.plot(ts, a_real[:,0], label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 0
ax.plot(ts, a_real[:,1], "--", label=r"$Re[\langle a(t)\rangle]$", color="blue") # Batch kappa value 1
ax.plot(ts, a_imag[:,0], label=r"$Re[\langle a(t)\rangle]$", color="red") # Batch kappa value 0
ax.plot(ts, a_imag[:,1], "--", label=r"$Re[\langle a(t)\rangle]$", color="red") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
ax = axs[1]
ax.plot(ts, nt[:,0], label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 0
ax.plot(ts, nt[:,1], "--", label=r"$Re[\langle n(t)\rangle]$", color="green") # Batch kappa value 1
ax.set_xlabel("Time (ns)")
ax.set_ylabel("Expectations")
ax.legend()
fig.tight_layout()
Acknowledgements & History
Core Devs: Shantanu R. Jha, Shoumik Chowdhury, Gabriele Rolleri
This package was initially a small part of bosonic. In early 2022, jaxquantum was extracted and made into its own package. This package was briefly announced to the world at APS March Meeting 2023 and released to a select few academic groups shortly after. Since then, this package has been open sourced and developed while conducting research in the Engineering Quantum Systems Group at MIT with advice and support from Prof. William D. Oliver.
Citation
Thank you for taking the time to try our package out. If you found it useful in your research, please cite us as follows:
@software{jha2024jaxquantum,
author = {Shantanu R. Jha and Shoumik Chowdhury and Gabriele Rolleri and Max Hays and Jeff A. Grover and William D. Oliver},
title = {JAXQuantum: An auto-differentiable and hardware-accelerated toolkit for quantum hardware design, simulation, and control},
url = {https://jaxquantum.org},
version = {0.3.0},
year = {2024},
}
S. R. Jha, S. Chowdhury, G. Rolleri, M. Hays, J. A. Grover, and W. D. Oliver. "JAXQuantum: An auto-differentiable and hardware-accelerated toolkit for quantum hardware design, simulation, and control," jaxquantum.org (2025).
Contributions & Contact
This package is open source and, as such, very open to contributions. Please don't hesitate to open an issue, report a bug, request a feature, or create a pull request. We are also open to deeper collaborations to create a tool that is more useful for everyone. If a discussion would be helpful, please email shanjha@mit.edu to set up a meeting.