Coverage for jaxquantum/core/conversions.py: 100%
33 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""
2Converting between different object types.
3"""
5from numbers import Number
6from jax import config, Array
7from qutip import Qobj
8from typing import Optional, Union, List
9import jax.numpy as jnp
10import numpy as np
13from jaxquantum.core.qarray import Qarray
14from jaxquantum.core.dims import DIMS_TYPE
# NOTE: global side effect at import time -- enables 64-bit dtypes in JAX for
# the whole process. Presumably required so that the ``jnp.complex128``
# default used by ``qt2jqt`` is honored; confirm before removing.
config.update("jax_enable_x64", True)


# Convert between QuTiP and JAX
# ===============================================================
def qt2jqt(qt_obj, dtype=jnp.complex128):
    """Convert a QuTiP object into a ``Qarray``.

    Values that are already ``Qarray`` instances, as well as ``None``, are
    handed back untouched.

    Args:
        qt_obj: QuTiP object to convert (or a ``Qarray`` / ``None``).
        dtype: JAX dtype used for the dense data array.

    Returns:
        A ``Qarray`` built from ``qt_obj.full()`` and ``qt_obj.dims``, or the
        input itself when no conversion is needed.
    """
    already_converted = qt_obj is None or isinstance(qt_obj, Qarray)
    if already_converted:
        return qt_obj

    dense = jnp.array(qt_obj.full(), dtype=dtype)
    return Qarray.create(dense, dims=qt_obj.dims)
def jqt2qt(jqt_obj):
    """Convert a ``Qarray`` into a QuTiP ``Qobj``.

    ``Qobj`` instances and ``None`` are returned unchanged. A batched
    ``Qarray`` is converted entry by entry.

    Args:
        jqt_obj: Qarray to convert (or a ``Qobj`` / ``None``).

    Returns:
        A QuTiP ``Qobj`` for an unbatched input; for a batched input, a list
        holding ``jqt2qt(jqt_obj[i])`` for each ``i`` in ``range(len(jqt_obj))``.
    """
    if isinstance(jqt_obj, Qobj) or jqt_obj is None:
        return jqt_obj

    if jqt_obj.is_batched:
        # Recurse: an entry may presumably still carry batch dimensions of
        # its own, in which case it is unrolled into a nested list.
        return [jqt2qt(jqt_obj[i]) for i in range(len(jqt_obj))]

    # QuTiP expects dims as a pair of lists: [row_dims, col_dims].
    dims = [list(jqt_obj.dims[0]), list(jqt_obj.dims[1])]
    return Qobj(np.array(jqt_obj.data), dims=dims)
def extract_dims(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
    """Normalize ``dims`` into nested ``[row_dims, col_dims]`` form for ``arr``.

    If ``dims`` is already nested (its first entry is not a number) it is
    returned unchanged. A flat list of subsystem sizes is expanded based on
    the trailing two axes of ``arr``:

    * square trailing axes (operator): ``[dims, dims]``
    * a single row, ``arr.shape[-2] == 1`` (bra): ``[[1, ...], dims]``
    * anything else, including 1-D input (ket): ``[dims, [1, ...]]``

    Args:
        arr: JAX array (or any object exposing ``.shape``) described by dims.
        dims: Flat list of subsystem sizes, already-nested Qarray dims, or
            ``None``.

    Returns:
        Nested Qarray dims, or ``None`` when ``dims`` is ``None``.
    """
    if dims is None:
        # Matches the ``None`` default in the signature instead of crashing
        # on ``dims[0]``.
        return None

    if isinstance(dims[0], Number):
        shape = arr.shape
        ones = [1] * len(dims)
        if len(shape) < 2:
            # 1-D input has no column axis; treat it as a ket.
            dims = [dims, ones]
        elif shape[-2] == shape[-1]:
            dims = [dims, dims]
        elif shape[-2] == 1:
            # Row vector: the subsystem sizes live on the column side.
            dims = [ones, dims]
        else:
            dims = [dims, ones]  # defaults to ket
    return dims
def jnp2jqt(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
    """JAX array -> Qarray.

    Args:
        arr: JAX array holding the data.
        dims: Qarray dims, either nested or a flat list of subsystem sizes
            (flat lists are expanded via ``extract_dims``). ``None`` is
            passed straight through to ``Qarray.create``.

    Returns:
        Qarray.
    """
    if dims is not None:
        dims = extract_dims(arr, dims)
    return Qarray.create(arr, dims=dims)