Coverage for jaxquantum/core/conversions.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2Converting between different object types. 

3""" 

4 

5from numbers import Number 

6from jax import config, Array 

7from qutip import Qobj 

8from typing import Optional, Union, List 

9import jax.numpy as jnp 

10import numpy as np 

11 

12 

13from jaxquantum.core.qarray import Qarray 

14from jaxquantum.core.dims import DIMS_TYPE 

15 

16 

# Enable 64-bit dtypes in JAX globally as an import-time side effect of this
# module. Without it JAX silently downcasts 64-bit dtypes (e.g. the
# jnp.complex128 default used by qt2jqt below) to their 32-bit counterparts.
config.update("jax_enable_x64", True)

18 

19 

# Convert between QuTiP and JAX
# ===============================================================
def qt2jqt(qt_obj, dtype=jnp.complex128):
    """Convert a QuTiP object into a ``Qarray``.

    Objects that are already a ``Qarray``, and ``None``, are passed through
    untouched.

    Args:
        qt_obj: QuTiP object (read via its ``.full()`` and ``.dims``), a
            ``Qarray``, or ``None``.
        dtype: JAX dtype for the resulting data array.

    Returns:
        A ``Qarray`` holding ``qt_obj``'s dense data and dims, or ``qt_obj``
        itself in the pass-through cases.
    """
    passthrough = qt_obj is None or isinstance(qt_obj, Qarray)
    if passthrough:
        return qt_obj

    dense = qt_obj.full()
    data = jnp.array(dense, dtype=dtype)
    return Qarray.create(data, dims=qt_obj.dims)

35 

36 

def jqt2qt(jqt_obj):
    """Convert a ``Qarray`` into a QuTiP ``Qobj``.

    Objects that are already a ``Qobj``, and ``None``, are returned
    unchanged. A batched ``Qarray`` is converted element by element along its
    leading batch axis; each element is converted recursively, so any
    further batch axes produce nested lists.

    Args:
        jqt_obj: Qarray to convert (or a ``Qobj`` / ``None``).

    Returns:
        A QuTiP ``Qobj`` for an unbatched input. For a batched input, a list
        of the converted elements (nested when elements are still batched).
    """
    if jqt_obj is None or isinstance(jqt_obj, Qobj):
        return jqt_obj

    if jqt_obj.is_batched:
        # Index explicitly via len()/[] rather than iterating over jqt_obj:
        # only __len__ and __getitem__ are known to be supported here.
        return [jqt2qt(jqt_obj[i]) for i in range(len(jqt_obj))]

    # QuTiP expects dims as [row_dims, col_dims] lists.
    dims = [list(jqt_obj.dims[0]), list(jqt_obj.dims[1])]
    return Qobj(np.array(jqt_obj.data), dims=dims)

58 

59 

def extract_dims(
    arr: Array, dims: Optional[Union["DIMS_TYPE", List[int]]] = None
):
    """Normalize ``dims`` into ``[row_dims, col_dims]`` form using ``arr``'s shape.

    A flat list of subsystem dimensions (e.g. ``[2, 3]``) is expanded
    according to the trailing two axes of ``arr``:

    * square trailing matrix -> operator: ``[dims, dims]``
    * single row (``(..., 1, N)``, ``N != 1``) -> bra: ``[[1] * k, dims]``
    * anything else -> ket (the default): ``[dims, [1] * k]``

    where ``k = len(dims)``. Nested dims (first element not a number) and
    ``None`` are returned unchanged.

    Args:
        arr: Array whose trailing two axes give the matrix shape; leading
            axes (if any) are ignored.
        dims: Flat list of subsystem dims, already-nested Qarray dims, or
            ``None``.

    Returns:
        Qarray dims in ``[row_dims, col_dims]`` form, or ``dims`` unchanged
        if it was ``None`` or already nested.
    """
    # The signature allows None; previously this crashed on dims[0].
    if dims is None:
        return None

    if not isinstance(dims[0], Number):
        return dims

    n_rows, n_cols = arr.shape[-2], arr.shape[-1]
    trivial = [1] * len(dims)
    if n_rows == n_cols:
        return [dims, dims]
    if n_rows == 1:
        # Row vector: bra dims, matching QuTiP's [row_dims, col_dims] layout.
        return [trivial, dims]
    return [dims, trivial]  # defaults to ket

77 

78 

def jnp2jqt(arr: Array, dims: Optional[Union[DIMS_TYPE, List[int]]] = None):
    """Wrap a JAX array in a ``Qarray``.

    Args:
        arr: JAX array holding the data.
        dims: Qarray dims. Either nested ``[row_dims, col_dims]`` dims, or a
            flat list of subsystem dims, which ``extract_dims`` expands based
            on ``arr``'s shape. If ``None``, it is forwarded to
            ``Qarray.create`` as-is (presumably letting it infer default dims
            — confirm against ``Qarray.create``).

    Returns:
        Qarray wrapping ``arr``.
    """
    # Only normalize dims when they were supplied; None is passed straight
    # through to Qarray.create.
    dims = extract_dims(arr, dims) if dims is not None else None
    return Qarray.create(arr, dims=dims)