In [1]:
Copied!
import jaxquantum as jqt
import jax.numpy as jnp
import jaxquantum as jqt
import jax.numpy as jnp
Qarray
The Qarray is the fundamental building block of jaxquantum. It is heavily inspired by QuTiP's Qobj and built to be compatible with JAX patterns.
In [2]:
Copied!
# Displace the n=0 basis state by alpha = 2.0 and inspect the header of its
# density matrix (the header reports dims, bdims, shape and type).
N = 50  # Hilbert-space dimension; matches dims = ((50,), (50,)) in the output
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, 2.0) @ state
displaced_state.to_dm().header
Out[2]:
'Quantum array: dims = ((50,), (50,)), bdims = (), shape = (50, 50), type = oper'
In [3]:
Copied!
# Plot the Wigner function of the displaced state.
# `pts` holds 100 evenly spaced phase-space coordinates in [-4, 4]
# (presumably used for both the Re[alpha] and Im[alpha] axes).
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state, pts)
Out[3]:
(<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>, <matplotlib.contour.QuadContourSet at 0x7f99dc0c1400>)
Batching
A crucial difference between a Qarray and a Qobj is that a Qarray has batch dimensions (`bdims`). These let us seamlessly use batching and NumPy broadcasting in calculations with Qarray objects.
In [4]:
Copied!
# Passing a 2x3 array of displacements to `jqt.displace` yields a batched
# operator. Applying it to a single (unbatched) ket broadcasts the product,
# giving a batch of kets with bdims = (2, 3).
N = 50
state = jqt.basis(N, 0)
alphas = jnp.array([[0.0, 0.5, 1.0],
                    [1.5, 2.0, 2.5]])
displaced_state = jqt.displace(N, alphas) @ state
displaced_state.header
Out[4]:
'Quantum array: dims = ((50,), (1,)), bdims = (2, 3), shape = (2, 3, 50, 1), type = ket'
In [5]:
Copied!
# Index into the (2, 3) batch to plot individual states. With the
# displacement grid [[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]] from the previous cell:
#   [0][0] -> alpha = 0.0, [0][2] -> alpha = 1.0, [1][1] -> alpha = 2.0
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state[0][0], pts)
jqt.plot_wigner(displaced_state[0][2], pts)
jqt.plot_wigner(displaced_state[1][1], pts)
Out[5]:
(<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>, <matplotlib.contour.QuadContourSet at 0x7f99ac95d810>)
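Constructing a batched Qarray manually
A batched Qarray can also be assembled from individual Qarrays with `Qarray.from_array`. Below, we build the same 2×3 batch of displacement operators both this way and by passing a batched array of displacements to `jqt.displace`, then check that the two agree.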
In [6]:
Copied!
# Build the same 2x3 batch of displacement operators in two ways:
#   arr1: stack individually constructed Qarrays with `Qarray.from_array`
#   arr2: pass a batched array of displacements directly to `jqt.displace`
N = 50
a = jqt.displace(N, 0.0)
b = jqt.displace(N, 1.0)
c = jqt.displace(N, 2.0)
arr1 = jqt.Qarray.from_array([[a, b, c], [a, b, c]])
arr2 = jqt.displace(N, jnp.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]))
# Compare the full batched data, not just a single element, so that a
# mismatch anywhere in the batch shows up as a non-zero maximum difference.
jnp.max(jnp.abs(arr1.data - arr2.data))
Out[6]:
Array(0., dtype=float64)