Skip to content

conversions

Conversions between QuTiP objects, JAX arrays, and jaxquantum Qarrays.

extract_dims(arr, dims=None)

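Normalize dims for a JAX array or Qarray: a flat list of subsystem sizes becomes [dims, dims] when the array's trailing matrix is square (an operator) and [dims, [1] * len(dims)] otherwise (a ket); already-nested dims are returned unchanged.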

Parameters:

Name Type Description Default
arr Array

JAX array or Qarray.

required
dims Optional[Union[DIMS_TYPE, List[int]]]

Qarray dims.

None

Returns:

Type Description

Qarray dims.

Source code in jaxquantum/core/conversions.py (lines 60–76)
def extract_dims(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
    """Normalize Qarray dims for a JAX array or Qarray.

    A flat list of subsystem sizes (e.g. ``[2, 3]``) is expanded into the
    two-row Qarray form, choosing the layout from the shape of ``arr``:

    * square trailing matrix (``arr.shape[-2] == arr.shape[-1]``) ->
      operator dims ``[dims, dims]``;
    * anything else -> ket dims ``[dims, [1] * len(dims)]``.

    Dims whose first element is not a number (i.e. already nested) are
    returned unchanged.

    Args:
        arr: JAX array or Qarray. Only ``arr.shape`` is read, and only when
            ``dims`` is a flat list of numbers.
        dims: Qarray dims, either flat or already nested. ``None`` (the
            default) and empty dims are returned unchanged.

    Returns:
        Qarray dims in nested form, or ``dims`` itself when it is ``None``,
        empty, or already nested.
    """
    if dims is None or len(dims) == 0:
        # Nothing to normalize, and indexing dims[0] below would fail.
        return dims
    if isinstance(dims[0], Number):
        # Flat list of subsystem sizes: infer operator vs. ket from the shape.
        is_op = arr.shape[-2] == arr.shape[-1]
        if is_op:
            dims = [dims, dims]
        else:
            dims = [dims, [1] * len(dims)]  # non-square: treated as a ket
    return dims

jnp2jqt(arr, dims=None)

JAX array -> Qarray.

Parameters:

Name Type Description Default
arr Array

JAX array.

required
dims Optional[Union[DIMS_TYPE, List[int]]]

Qarray dims.

None

Returns:

Type Description

Qarray.

Source code in jaxquantum/core/conversions.py (lines 79–90)
def jnp2jqt(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
    """JAX array -> Qarray.

    Args:
        arr: JAX array holding the state/operator data.
        dims: Qarray dims, either nested or a flat list of subsystem sizes
            (a flat list is normalized with ``extract_dims`` using the shape
            of ``arr``). ``None`` is forwarded to ``Qarray.create`` as-is.

    Returns:
        Qarray wrapping ``arr``.
    """
    # Only normalize when dims were given; None is passed straight through.
    dims = extract_dims(arr, dims) if dims is not None else None
    return Qarray.create(arr, dims=dims)

jqt2qt(jqt_obj)

Qarray -> QuTiP state.

Parameters:

Name Type Description Default
jqt_obj

Qarray (possibly batched). Its dims are reused as the QuTiP dims; the function takes no separate dims argument.

required

Returns:

Type Description

QuTiP state (a Qobj, or a list of Qobj for a batched Qarray).

Source code in jaxquantum/core/conversions.py (lines 37–57)
def jqt2qt(jqt_obj):
    """Qarray -> QuTiP state.

    Args:
        jqt_obj: Qarray to convert. ``None`` and objects that are already
            QuTiP ``Qobj`` instances are returned unchanged.

    Returns:
        QuTiP ``Qobj`` whose dims are copied from ``jqt_obj.dims``. For a
        batched Qarray, a list with one converted entry per index
        ``0 .. len(jqt_obj) - 1`` (nested lists if the entries are
        themselves batched).
    """
    if isinstance(jqt_obj, Qobj) or jqt_obj is None:
        return jqt_obj

    if jqt_obj.is_batched:
        # Index explicitly up to len() rather than iterating the object;
        # each element is converted recursively.
        return [jqt2qt(jqt_obj[i]) for i in range(len(jqt_obj))]

    # Convert each dims row to a plain list and the data to a NumPy array
    # before building the QuTiP object.
    dims = [list(jqt_obj.dims[0]), list(jqt_obj.dims[1])]
    return Qobj(np.array(jqt_obj.data), dims=dims)

qt2jqt(qt_obj, dtype=jnp.complex128)

QuTiP state -> Qarray.

Parameters:

Name Type Description Default
qt_obj

QuTiP state.

required
dtype

JAX dtype.

complex128

Returns:

Type Description

Qarray.

Source code in jaxquantum/core/conversions.py (lines 22–34)
def qt2jqt(qt_obj, dtype=jnp.complex128):
    """Convert a QuTiP object into a jaxquantum Qarray.

    Args:
        qt_obj: QuTiP state/operator to convert. ``None`` and objects that
            are already ``Qarray`` instances are returned unchanged.
        dtype: JAX dtype used for the resulting array data.

    Returns:
        Qarray built from the dense matrix ``qt_obj.full()`` with the same
        dims as ``qt_obj``.
    """
    already_converted = qt_obj is None or isinstance(qt_obj, Qarray)
    if already_converted:
        return qt_obj

    # Densify the QuTiP data, then move it into a JAX array of the requested dtype.
    dense_matrix = qt_obj.full()
    jax_data = jnp.array(dense_matrix, dtype=dtype)
    return Qarray.create(jax_data, dims=qt_obj.dims)