# !pip install git+https://github.com/EQuS/jaxquantum.git # Uncomment this line to install jaxquantum (e.g. when running in Colab).
import jaxquantum as jqt
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
/opt/miniconda3/envs/jqt-env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Qarray¶
The Qarray is the fundamental building block of jaxquantum. It is heavily inspired by QuTiP's Qobj and built to be compatible with JAX patterns.
N = 50
# Vacuum |0> in a 50-level Fock space, displaced by alpha = 2.
state = jqt.basis(N, 0)
displacement = jqt.displace(N, 2.0)
displaced_state = displacement @ state
# Inspect the metadata of the corresponding density matrix.
displaced_state.to_dm().header
'Quantum array: dims = ((50,), (50,)), bdims = (), shape = (50, 50), type = oper, impl = dense'
# Phase-space grid (100 points from -4 to 4) used for both Wigner axes.
pts = jnp.linspace(start=-4, stop=4, num=100)
jqt.plot_wigner(displaced_state, pts)
(array([[<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>]],
dtype=object),
<matplotlib.contour.QuadContourSet at 0x162221e90>)
Batching¶
A crucial difference between a Qarray and a QuTiP Qobj is that a Qarray carries batch dimensions (bdims). These let us seamlessly use batching and NumPy-style broadcasting in calculations with Qarray objects.
N = 50
# A (2, 3) grid of displacement amplitudes; each entry becomes one batch element.
alpha_grid = jnp.array([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]])
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, alpha_grid) @ state
# The header reports the batch shape in bdims.
displaced_state.header
'Quantum array: dims = ((50,), (1,)), bdims = (2, 3), shape = (2, 3, 50, 1), type = ket, impl = dense'
alphas_2d = jnp.array([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]])
xvec = jnp.linspace(-4, 4, 61)
# One subtitle per batch element, arranged with the same (2, 3) shape as alphas_2d.
subtitle_grid = np.array(
    [[f"α = {float(amp):.1f}" for amp in row] for row in alphas_2d]
)
jqt.plot_wigner(
    displaced_state,
    xvec,
    xvec,
    subtitles=subtitle_grid,
    figtitle="Displaced vacuum states",
)
plt.show()
Constructing a batched Qarray manually¶
N = 50
# Three individually constructed displacement operators.
a = jqt.displace(N, 0.0)
b = jqt.displace(N, 1.0)
c = jqt.displace(N, 2.0)
# Stack them by hand into a (2, 3) batch ...
arr1 = jqt.Qarray.from_array([[a, b, c], [a, b, c]])
# ... and build the same (2, 3) batch with a single vectorized call.
arr2 = jqt.displace(N, jnp.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]))
# Compare every batch entry, not just a single element, so a mismatch anywhere
# in the batch shows up as a nonzero maximum deviation.
jnp.max(jnp.abs(arr1.data - arr2.data))
Array(0., dtype=float64)
N = 20


def mean_photon_number(alpha):
    """Return <n> for the displaced vacuum |alpha> in an N-level truncated space.

    ``N`` is read from module scope. Up to truncation error, the result equals
    |alpha|**2.
    """
    ket = jqt.displace(N, alpha) @ jqt.basis(N, 0)
    rho = ket.to_dm()
    return jnp.real(jqt.tr(jqt.num(N) @ rho))
# Evaluate <n> for 50 amplitudes in one vectorized call.
alphas = jnp.linspace(0.0, 3.0, 50)
batched_n = jax.vmap(mean_photon_number)
n_vals = batched_n(alphas)

# Overlay the numerical result on the analytic |alpha|^2 curve.
plt.plot(alphas, n_vals, label=r"$\langle\hat{n}\rangle$ (vmap)")
plt.plot(alphas, alphas**2, "--", label=r"$|\alpha|^2$ (analytic)")
plt.xlabel(r"$\alpha$")
plt.ylabel(r"$\langle\hat{n}\rangle$")
plt.title("Photon number of displaced vacuum")
plt.legend()
plt.show()
jit¶
!!! warning "jaxquantum functions are not JIT-compiled by default."
Every call to jqt.mesolve, jqt.sesolve, jqt.displace, etc. runs in eager mode unless you explicitly wrap it with jax.jit. Wrap your own functions with @jax.jit whenever you call them more than once — the first call pays the compilation cost, and all subsequent calls run at full XLA speed.
Use jax.jit to JIT-compile quantum computations. The first call triggers compilation; subsequent calls run at full speed.
@jax.jit
def mean_photon_number_jit(alpha):
    """JIT-compiled <n> for the displaced vacuum |alpha> (same body as mean_photon_number).

    ``N`` is read from module scope when the function is traced.
    """
    displaced = jqt.displace(N, alpha) @ jqt.basis(N, 0)
    number_op = jqt.num(N)
    return jnp.real(jqt.tr(number_op @ displaced.to_dm()))


alpha_val = jnp.array(2.0)
# The first call triggers compilation; later calls reuse the compiled code.
result = mean_photon_number_jit(alpha_val)
expected = float(alpha_val) ** 2
print(f"<n> = {result:.4f} (expected {expected:.4f})")
<n> = 4.0000 (expected 4.0000)
import time

N_sim = 15
tlist = jnp.linspace(0.0, 20.0, 200)


@jax.jit
def decaying_oscillator(kappa):
    """Solve the master equation for (|1> + |0>)/norm under H = n with loss rate kappa.

    ``N_sim`` and ``tlist`` are read from module scope when the function is traced.
    Returns the ``jqt.mesolve`` result over ``tlist``.
    """
    hamiltonian = 1.0 * jqt.num(N_sim)
    superposition = jqt.basis(N_sim, 1) + jqt.basis(N_sim, 0)
    rho0 = superposition.unit().to_dm()
    # Single collapse operator sqrt(kappa) * a models energy loss.
    loss_ops = jqt.Qarray.from_list([jnp.sqrt(kappa) * jqt.destroy(N_sim)])
    opts = jqt.SolverOptions.create(progress_meter=False)
    return jqt.mesolve(hamiltonian, rho0, tlist, c_ops=loss_ops, solver_options=opts)


n_op = jqt.num(N_sim)
kappas = [0.1, 0.3, 0.5]
for idx, kappa in enumerate(kappas):
    start = time.perf_counter()
    result = decaying_oscillator(jnp.array(kappa))
    # JAX dispatches asynchronously; wait for the result so the timing is real.
    jax.block_until_ready(result.data)
    elapsed = time.perf_counter() - start
    # Only the first call includes compilation time.
    label = "run" if idx else "compile + run"
    print(f"κ = {kappa} ({label}): {elapsed*1000:.0f} ms")
    n_t = jnp.real(jqt.tr(n_op @ result))
    plt.plot(tlist, n_t, label=rf"$\kappa = {kappa}$")

plt.xlabel("Time")
plt.ylabel(r"$\langle\hat{n}\rangle$")
plt.legend()
plt.title("JIT-compiled mesolve: decaying oscillator")
plt.show()
κ = 0.1 (compile + run): 1049 ms κ = 0.3 (run): 8 ms κ = 0.5 (run): 7 ms
Animating batched simulations¶
With gif=True, plot_wigner (and plot_qp / plot_qfunc) can render a batched state as an animation instead of a tiled subplot grid. Below we vmap decaying_oscillator over several loss rates and animate the Wigner function in time. Batch axis 1 (time) is selected as the animation axis via gif_params. The remaining batch axis 0 (κ) is laid out within each frame as a row of subplots, one column per κ.
kappas_arr = jnp.array([0.1, 0.3, 0.5])
# Vectorize the JIT-compiled solver over the loss rates. The returned Qarray's
# data carries a leading kappa axis; slicing it exposes the batch dims as
# (kappa, time).
result = jax.vmap(decaying_oscillator)(kappas_arr)
stride = 8  # keep every 8th time step for a snappier gif
states_sub = result[:, ::stride]
ts_sub = tlist[::stride]
print('Animated states bdims:', states_sub.bdims)

xvec = jnp.linspace(-4, 4, 41)
kappa_titles = np.array([f"κ = {float(k):.2f}" for k in kappas_arr])
animation_opts = {
    "ts": ts_sub,
    "interval_ms": 80,
    "batch_animation_axis": 1,  # time axis
    # "save_path": "decaying_oscillator.gif",
}
anim = jqt.plot_wigner(
    states_sub,
    xvec,
    subtitles=kappa_titles,
    figtitle="Decaying oscillator Wigner",
    gif=True,
    gif_params=animation_opts,
)
anim
Animated states bdims: (3, 25)
The same gif=True / gif_params interface works on plot_cf_wigner (and plot_cf), animating the Wigner characteristic function in the same way. Each frame now shows real and imaginary parts side by side per column.
# Reuses states_sub, ts_sub and kappas_arr from the previous cell.
# Uses the same 41 points as the Wigner grid but over [-3, 3] instead of [-4, 4].
# This gives finer spacing in the region where the CF features are visible.
cf_xvec = jnp.linspace(-3, 3, 41)
cf_titles = np.array([f"κ = {float(k):.2f}" for k in kappas_arr])
cf_gif_params = dict(
    ts=ts_sub,
    interval_ms=80,
    batch_animation_axis=1,  # time axis
    # save_path="decaying_oscillator_cf.gif",
)
anim_cf = jqt.plot_cf_wigner(
    states_sub,
    cf_xvec,
    subtitles=cf_titles,
    figtitle="Decaying oscillator CF (Wigner)",
    gif=True,
    gif_params=cf_gif_params,
)
anim_cf
grad¶
Use jax.grad to differentiate through quantum computations. For a displaced vacuum $|\alpha\rangle$ with real $\alpha$, the mean photon number is $\langle\hat{n}\rangle = |\alpha|^2 = \alpha^2$, so $\frac{d\langle\hat{n}\rangle}{d\alpha} = 2\alpha$.
# Differentiate the JIT-compiled <n>(alpha) with respect to alpha.
grad_fn = jax.grad(mean_photon_number_jit)
alpha_val = jnp.array(1.5)
alpha_f = float(alpha_val)
n_at_alpha = mean_photon_number_jit(alpha_val)
dn_dalpha = grad_fn(alpha_val)
print(f"<n> at alpha={alpha_f}: {n_at_alpha:.4f}")
print(f"d<n>/dalpha at alpha={alpha_f}: {dn_dalpha:.4f}")
print(f"Analytic (2*alpha) : {2*alpha_f:.4f}")
<n> at alpha=1.5: 2.2500 d<n>/dalpha at alpha=1.5: 3.0000 Analytic (2*alpha) : 3.0000
Implementation Backends¶
By default, Qarrays use a dense matrix representation. For large Hilbert spaces, jaxquantum supports two sparse backends via the implementation argument:
| Backend | implementation= | Best for |
|---|---|---|
| Dense | "dense" (default) | Small spaces, general use |
| SparseDIA | "sparse_dia" | Ladder operators ($\hat{a}$, $\hat{n}$) |
| BCOO | "sparse_bcoo" | Arbitrary sparse operators |
All three backends share the same Qarray API — switching is a single keyword argument. For a detailed comparison with memory and timing benchmarks, see the Sparse Backends tutorial.
N = 20
# The same annihilation operator stored in each of the three backends.
a_dense = jqt.destroy(N)
a_dia = jqt.destroy(N, implementation="sparse_dia")
a_bcoo = jqt.destroy(N, implementation="sparse_bcoo")
backends = [(a_dense, "dense"), (a_dia, "sparse_dia"), (a_bcoo, "sparse_bcoo")]
for op, _ in backends:
    print(op.impl_type)

# The same operation works identically across all backends:
# for the Fock state |3>, <a^dag a> = 3.
state = jqt.basis(N, 3)
rho = state.to_dm()
for op, name in backends:
    n_expect = jnp.real(jqt.tr(op.dag() @ op @ rho))
    print(f"<n> via {name}: {n_expect:.1f}")  # expected: 3.0
QarrayImplType.DENSE QarrayImplType.SPARSE_DIA QarrayImplType.SPARSE_BCOO <n> via dense: 3.0 <n> via sparse_dia: 3.0 <n> via sparse_bcoo: 3.0