# !pip install git+https://github.com/EQuS/jaxquantum.git  # Uncomment this line to install jaxquantum (e.g. when running in Colab).
import jaxquantum as jqt
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
Qarray
The Qarray is the fundamental building block of jaxquantum. It is heavily inspired by QuTiP's Qobj and built to be compatible with JAX patterns.
# Build the vacuum |0> in a Fock space truncated to N levels, then apply the
# displacement operator D(2.0) to it.
N = 50
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, 2.0) @ state
# Convert the ket to a density matrix and display its one-line summary header.
displaced_state.to_dm().header
'Quantum array: dims = ((50,), (50,)), bdims = (), shape = (50, 50), type = oper, impl = dense'
# Plot the Wigner function of the displaced state using 100 phase-space
# sample points in [-4, 4].
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state, pts)
# Render the figure explicitly, as the other plotting cells do; this also keeps
# the notebook from echoing plot_wigner's (axes, contour) return value.
plt.show()
(array([[<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>]],
dtype=object),
<matplotlib.contour.QuadContourSet at 0x30f9a8ac0>)
Batching
A crucial difference between a Qarray and a Qobj is that a Qarray carries batch dimensions (bdims). These let us seamlessly use batching and NumPy broadcasting in calculations with Qarray objects.
N = 50
# Displacement amplitudes arranged as a (2, 3) batch. They are named so the
# values are defined once and can be reused (e.g. for plot labels).
alphas_2d = jnp.array([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]])
state = jqt.basis(N, 0)
# A batched array of alphas gives a batched displacement operator; applying it
# to the unbatched vacuum yields a (2, 3) batch of kets.
displaced_state = jqt.displace(N, alphas_2d) @ state
displaced_state.header
'Quantum array: dims = ((50,), (1,)), bdims = (2, 3), shape = (2, 3, 50, 1), type = ket, impl = dense'
# Amplitudes used for the panel labels. They must equal the displacements used
# to build the (2, 3) batch `displaced_state`.
alphas_2d = jnp.array([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]])
# 61 phase-space sample points in [-4, 4], passed for both axes below.
xvec = jnp.linspace(-4, 4, 61)
# subtitles is a (2, 3) array of strings matching the batch shape; presumably
# plot_wigner draws one panel per batch element.
jqt.plot_wigner(
    displaced_state, xvec, xvec,
    subtitles=np.array([[f"α = {float(a):.1f}" for a in row] for row in alphas_2d]),
    figtitle="Displaced vacuum states",
)
plt.show()
Constructing a batched Qarray manually
# Build the same (2, 3) batch of displacement operators in two ways.
N = 50
a = jqt.displace(N, 0.0)
b = jqt.displace(N, 1.0)
c = jqt.displace(N, 2.0)
# 1) Stack individual (unbatched) Qarrays from a nested list.
arr1 = jqt.Qarray.from_array([[a, b, c], [a, b, c]])
# 2) Let jqt.displace broadcast over a (2, 3) array of displacements.
arr2 = jqt.displace(N, jnp.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]))
# Compare every batch element at once (not just arr[0][1]); the largest
# entry-wise difference should be at floating-point round-off level.
jnp.max(jnp.abs(arr1.data - arr2.data))
Array(6.10622664e-16, dtype=float64)
N = 20


def mean_photon_number(alpha):
    """Return <n> for the displaced vacuum |alpha> = D(alpha)|0>.

    The state lives in a Fock space truncated to the global ``N`` levels.
    The result is the real part of Tr(n rho) with rho = |alpha><alpha|.
    """
    vacuum = jqt.basis(N, 0)
    ket = jqt.displace(N, alpha) @ vacuum
    rho = ket.to_dm()
    number_op = jqt.num(N)
    return jnp.real(jqt.tr(number_op @ rho))
# 50 real displacement amplitudes in [0, 3].
alphas = jnp.linspace(0.0, 3.0, 50)
# jax.vmap vectorizes mean_photon_number over the leading axis of `alphas`,
# giving one <n> value per amplitude.
n_vals = jax.vmap(mean_photon_number)(alphas)
# Overlay the numerical result on the analytic prediction |alpha|^2.
plt.plot(alphas, n_vals, label=r"$\langle\hat{n}\rangle$ (vmap)")
plt.plot(alphas, alphas**2, "--", label=r"$|\alpha|^2$ (analytic)")
plt.xlabel(r"$\alpha$")
plt.ylabel(r"$\langle\hat{n}\rangle$")
plt.legend()
plt.title("Photon number of displaced vacuum")
plt.show()
jit
!!! warning "jaxquantum functions are not JIT-compiled by default."
    Every call to jqt.mesolve, jqt.sesolve, jqt.displace, etc. runs in eager mode unless you explicitly wrap it with jax.jit. Wrap your own functions with @jax.jit whenever you call them more than once — the first call pays the compilation cost, and all subsequent calls run at full XLA speed.
Use jax.jit to JIT-compile quantum computations. The first call triggers compilation; subsequent calls run at full speed.
@jax.jit
def mean_photon_number_jit(alpha):
    """JIT-compiled <n> of the displaced vacuum D(alpha)|0>.

    Works in the global ``N``-level Fock space and returns the real part of
    Tr(n |alpha><alpha|). Because ``N`` is a Python global, jax.jit fixes its
    value when the function is first traced.
    """
    displaced = jqt.displace(N, alpha) @ jqt.basis(N, 0)
    density = displaced.to_dm()
    return jnp.real(jqt.tr(jqt.num(N) @ density))
# With alpha = 2.0 the expected mean photon number is |alpha|^2 = 4.
alpha_val = jnp.array(2.0)
# The first call triggers tracing and compilation of mean_photon_number_jit.
result = mean_photon_number_jit(alpha_val)
print(f"<n> = {result:.4f} (expected {float(alpha_val)**2:.4f})")
<n> = 4.0000 (expected 4.0000)
import time

N_sim = 15
tlist = jnp.linspace(0.0, 20.0, 200)


@jax.jit
def decaying_oscillator(kappa):
    """Evolve the Fock state |4> under H = n with loss operator sqrt(kappa)*a.

    Runs jqt.mesolve over the global ``tlist`` in an ``N_sim``-level space,
    with the progress meter disabled, and returns the solver's result.
    """
    options = jqt.SolverOptions.create(progress_meter=False)
    loss_ops = jqt.Qarray.from_list([jnp.sqrt(kappa) * jqt.destroy(N_sim)])
    initial_rho = jqt.ket2dm(jqt.basis(N_sim, 4))
    hamiltonian = 1.0 * jqt.num(N_sim)
    return jqt.mesolve(
        hamiltonian, initial_rho, tlist,
        c_ops=loss_ops,
        solver_options=options,
    )
n_op = jqt.num(N_sim)
kappas = [0.1, 0.3, 0.5]
for i, kappa in enumerate(kappas):
    # Time one solve. block_until_ready waits for JAX's asynchronous dispatch
    # to finish, so the measurement covers the actual computation.
    t0 = time.perf_counter()
    result = decaying_oscillator(jnp.array(kappa))
    jax.block_until_ready(result.data)
    elapsed = time.perf_counter() - t0
    # Only the first call pays the compile cost; later kappas are passed as
    # arrays of the same shape/dtype and reuse the compiled function.
    label = "compile + run" if i == 0 else "run"
    print(f"κ = {kappa} ({label}): {elapsed*1000:.0f} ms")
    n_t = jnp.real(jqt.tr(n_op @ result))
    plt.plot(tlist, n_t, label=rf"$\kappa = {kappa}$")
    # Analytic check: with collapse operator sqrt(kappa)*a, d<n>/dt = -kappa<n>,
    # so starting from |4> (the initial state in decaying_oscillator) the
    # photon number is exactly 4*exp(-kappa*t). Label it only once.
    plt.plot(
        tlist, 4.0 * jnp.exp(-kappa * tlist), "k--",
        linewidth=1,
        label=r"analytic $4e^{-\kappa t}$" if i == 0 else None,
    )
plt.xlabel("Time")
plt.ylabel(r"$\langle\hat{n}\rangle$")
plt.legend()
plt.title("JIT-compiled mesolve: decaying oscillator")
plt.show()
κ = 0.1 (compile + run): 1074 ms κ = 0.3 (run): 8 ms κ = 0.5 (run): 9 ms
grad
Use jax.grad to differentiate through quantum computations. For a displaced vacuum $|\alpha\rangle$, the mean photon number is $\langle\hat{n}\rangle = |\alpha|^2$, so $\frac{d\langle\hat{n}\rangle}{d\alpha} = 2\alpha$.
# Differentiate the JIT-compiled <n>(alpha) with respect to alpha.
grad_fn = jax.grad(mean_photon_number_jit)
alpha_val = jnp.array(1.5)
# Print <n>, the autodiff derivative, and the analytic derivative 2*alpha.
print(f"<n> at alpha={float(alpha_val)}: {mean_photon_number_jit(alpha_val):.4f}")
print(f"d<n>/dalpha at alpha={float(alpha_val)}: {grad_fn(alpha_val):.4f}")
print(f"Analytic (2*alpha) : {2*float(alpha_val):.4f}")
<n> at alpha=1.5: 2.2500 d<n>/dalpha at alpha=1.5: 3.0000 Analytic (2*alpha) : 3.0000
Implementation Backends
By default, Qarrays use a dense matrix representation. For large Hilbert spaces, jaxquantum supports two sparse backends via the implementation argument:
| Backend | implementation= | Best for |
|---|---|---|
| Dense | "dense" (default) | Small spaces, general use |
| SparseDIA | "sparse_dia" | Ladder operators ($\hat{a}$, $\hat{n}$) |
| BCOO | "sparse_bcoo" | Arbitrary sparse operators |
All three backends share the same Qarray API — switching is a single keyword argument. For a detailed comparison with memory and timing benchmarks, see the Sparse Backends tutorial.
N = 20
# The same annihilation operator stored with each of the three backends.
a_dense = jqt.destroy(N)
a_dia = jqt.destroy(N, implementation="sparse_dia")
a_bcoo = jqt.destroy(N, implementation="sparse_bcoo")
print(a_dense.impl_type, a_dia.impl_type, a_bcoo.impl_type, sep="\n")

# Compute <n> = Tr(a^dag a rho) for the Fock state |3> with every backend;
# each should report 3.0.
state = jqt.basis(N, 3)
for name, op in {
    "dense": a_dense,
    "sparse_dia": a_dia,
    "sparse_bcoo": a_bcoo,
}.items():
    n_expect = jnp.real(jqt.tr((op.dag() @ op) @ state.to_dm()))
    print(f"<n> via {name}: {n_expect:.1f}")
QarrayImplType.DENSE QarrayImplType.SPARSE_DIA QarrayImplType.SPARSE_BCOO <n> via dense: 3.0 <n> via sparse_dia: 3.0 <n> via sparse_bcoo: 3.0