In [ ]:
Copied!
# !pip install git+https://github.com/EQuS/jaxquantum.git # Comment this out if running in colab to install jaxquantum.
# !pip install git+https://github.com/EQuS/jaxquantum.git # Comment this out if running in colab to install jaxquantum.
In [1]:
Copied!
import jaxquantum as jqt
import jax.numpy as jnp
import jaxquantum as jqt
import jax.numpy as jnp
Qarray¶
The Qarray is the fundamental building block of jaxquantum. It is heavily inspired by QuTiP's Qobj and built to be compatible with JAX patterns.
In [2]:
Copied!
N = 50
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, 2.0) @ state
displaced_state.to_dm().header
N = 50
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, 2.0) @ state
displaced_state.to_dm().header
Out[2]:
'Quantum array: dims = ((50,), (50,)), bdims = (), shape = (50, 50), type = oper'
In [3]:
Copied!
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state, pts)
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state, pts)
Out[3]:
(<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>, <matplotlib.contour.QuadContourSet at 0x16ba610a0>)
Batching¶
A crucial difference between a Qarray and Qobj is its bdim or batch dimension. This enables us to seamlessly use batching and numpy broadcasting in calculations using Qarray objects.
In [4]:
Copied!
N = 50
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, jnp.array([[0.0, 0.5, 1.0],[1.5,2.0,2.5]])) @ state
displaced_state.header
N = 50
state = jqt.basis(N, 0)
displaced_state = jqt.displace(N, jnp.array([[0.0, 0.5, 1.0],[1.5,2.0,2.5]])) @ state
displaced_state.header
Out[4]:
'Quantum array: dims = ((50,), (1,)), bdims = (2, 3), shape = (2, 3, 50, 1), type = ket'
In [5]:
Copied!
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state[0][0], pts)
jqt.plot_wigner(displaced_state[0][2], pts)
jqt.plot_wigner(displaced_state[1][1], pts)
pts = jnp.linspace(-4, 4, 100)
jqt.plot_wigner(displaced_state[0][0], pts)
jqt.plot_wigner(displaced_state[0][2], pts)
jqt.plot_wigner(displaced_state[1][1], pts)
Out[5]:
(<Axes: xlabel='Re[$\\alpha$]', ylabel='Im[$\\alpha$]'>, <matplotlib.contour.QuadContourSet at 0x30f62b340>)
Constructing a batched Qarray manually¶
In [6]:
Copied!
N = 50
a = jqt.displace(N, 0.0)
b = jqt.displace(N, 1.0)
c = jqt.displace(N, 2.0)
arr1 = jqt.Qarray.from_array([[a,b,c],[a,b,c]])
arr2 = jqt.displace(N, jnp.array([[0.0, 1.0, 2.0],[0.0, 1.0, 2.0]]))
jnp.max(jnp.abs(arr1[0][1].data-arr2[0][1].data))
N = 50
a = jqt.displace(N, 0.0)
b = jqt.displace(N, 1.0)
c = jqt.displace(N, 2.0)
arr1 = jqt.Qarray.from_array([[a,b,c],[a,b,c]])
arr2 = jqt.displace(N, jnp.array([[0.0, 1.0, 2.0],[0.0, 1.0, 2.0]]))
jnp.max(jnp.abs(arr1[0][1].data-arr2[0][1].data))
Out[6]:
Array(6.10622664e-16, dtype=float64)