Coverage for jaxquantum/utils/utils.py: 91%

55 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1""" 

2JAX Utils 

3""" 

4 

5from numbers import Number 

6from typing import Dict 

7 

8from jax import lax, Array, device_put, config 

9from jax._src.scipy.special import gammaln 

10import jax.numpy as jnp 

11import numpy as np 

12 

13from typing import Literal 

14 

# Import-time side effect: turn on 64-bit (double-precision) JAX dtypes
# globally as soon as this module is loaded. set_precision("single") below
# flips the same flag back to False.
config.update("jax_enable_x64", True)

16 

17 

def device_put_params(params: Dict, non_device_params=None):
    """Transfer plain numeric entries of ``params`` with ``jax.device_put``.

    The dictionary is modified in place: every value that is a Python
    ``numbers.Number`` or a ``numpy.ndarray`` is replaced by the result of
    ``device_put`` (called without an explicit device). All other values,
    e.g. strings or values that are already JAX arrays, are left untouched.

    Args:
        params: Mapping of parameter name to value.
        non_device_params: Optional container of keys that must be skipped
            even when their value would otherwise be transferred. ``None``
            means nothing is skipped.

    Returns:
        The same ``params`` object that was passed in.
    """
    skipped = () if non_device_params is None else non_device_params
    transferable_types = (Number, np.ndarray)
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for name, value in params.items():
        if name not in skipped and isinstance(value, transferable_types):
            params[name] = device_put(value)
    return params

26 

27 

def comb(N, k):
    """
    NCk: binomial coefficient "N choose k".

    TODO: replace with jsp.special.comb once the upstream issue is closed:
    https://github.com/google/jax/issues/9709

    Evaluated in log space as
    exp(gammaln(N + 1) - (gammaln(k + 1) + gammaln(N + 1 - k))),
    so the result is always a floating-point array and may differ from the
    exact integer value by rounding error.

    Args:
        N: total items (Python scalar or array-like; broadcast against ``k``)
        k: # of items to choose (Python scalar or array-like)

    Returns:
        NCk: N choose k, as a floating-point JAX array.
    """
    # jnp arithmetic (unlike lax.add / lax.sub) applies JAX type promotion
    # and broadcasting, so mixed int/float inputs such as comb(5.0, 2) no
    # longer fail with a lax dtype-mismatch TypeError.
    N = jnp.asarray(N)
    k = jnp.asarray(k)
    N_plus_1 = N + 1
    # Same association as the log-gamma identity: ln N! - (ln k! + ln (N-k)!)
    log_nck = gammaln(N_plus_1) - (gammaln(k + 1) + gammaln(N_plus_1 - k))
    return jnp.exp(log_nck)

50 

51 

def complex_to_real_iso_matrix(A):
    """Embed a complex matrix into a real block matrix of twice the size.

    Builds ``[[Re A, -Im A], [Im A, Re A]]`` with ``jnp.block``.

    Args:
        A: Complex (or real) matrix.

    Returns:
        The real block matrix described above.
    """
    re_part = jnp.real(A)
    im_part = jnp.imag(A)
    upper_row = [re_part, -im_part]
    lower_row = [im_part, re_part]
    return jnp.block([upper_row, lower_row])

54 

55 

def real_to_complex_iso_matrix(A):
    """Recover a complex matrix from its real block-matrix embedding.

    Reads the real part from the top-left quadrant and the imaginary part
    from the bottom-left quadrant (split at ``A.shape[0] // 2``).

    Args:
        A: Real block matrix of the form ``[[Re, -Im], [Im, Re]]``.

    Returns:
        ``Re + 1j * Im``.
    """
    half = A.shape[0] // 2
    top_left = A[:half, :half]
    bottom_left = A[half:, :half]
    return top_left + 1j * bottom_left

59 

60 

def complex_to_real_iso_vector(v):
    """Embed a complex vector as a real vector: real parts stacked above imaginary parts.

    Uses ``jnp.block([[Re v], [Im v]])``. For a column vector of shape
    ``(n, 1)`` the result has shape ``(2n, 1)``.

    Args:
        v: Complex (or real) vector.

    Returns:
        The stacked real representation.
    """
    stacked_parts = [[jnp.real(v)], [jnp.imag(v)]]
    return jnp.block(stacked_parts)

63 

64 

def real_to_complex_iso_vector(v):
    """Recover a complex vector from its stacked real representation.

    The first ``v.shape[0] // 2`` rows are taken as the real part and the
    remaining rows as the imaginary part. ``v`` is indexed with two
    subscripts, so it must be at least 2-D (e.g. a column vector).

    Args:
        v: Real array of the form ``[[Re], [Im]]``.

    Returns:
        ``Re + 1j * Im``.
    """
    split = v.shape[0] // 2
    upper_rows, lower_rows = v[:split, :], v[split:, :]
    return upper_rows + 1j * lower_rows

68 

69 

def imag_times_iso_vector(v):
    """Multiply a stacked real representation ``[[Re], [Im]]`` by ``1j``.

    Since ``1j * (x + 1j*y) = -y + 1j*x``, the result is ``[[-Im], [Re]]``.
    ``v`` is indexed with two subscripts, so it must be at least 2-D.

    Args:
        v: Real array of the form ``[[Re], [Im]]``.

    Returns:
        Real array ``[[-Im], [Re]]``.
    """
    split = v.shape[0] // 2
    real_rows = v[:split, :]
    imag_rows = v[split:, :]
    return jnp.block([[-imag_rows], [real_rows]])

73 

74 

def imag_times_iso_matrix(A):
    """Multiply a real block-matrix embedding ``[[Re, -Im], [Im, Re]]`` by ``1j``.

    Reads ``Re`` from the top-left quadrant and ``Im`` from the bottom-left
    quadrant. Since ``1j * (Re + 1j*Im) = -Im + 1j*Re``, it returns the
    embedding ``[[-Im, -Re], [Re, -Im]]``.

    Args:
        A: Real block matrix of the form ``[[Re, -Im], [Im, Re]]``.

    Returns:
        Real block matrix ``[[-Im, -Re], [Re, -Im]]``.
    """
    half = A.shape[0] // 2
    re_block = A[:half, :half]
    im_block = A[half:, :half]
    neg_im = -im_block
    return jnp.block([[neg_im, -re_block], [re_block, neg_im]])

80 

81 

def conj_transpose_iso_matrix(A):
    """Conjugate-transpose a matrix given in real block-matrix form.

    Reads ``Re`` from the top-left quadrant and ``Im`` from the bottom-left
    quadrant. Since ``(Re + 1j*Im)^dagger = Re^T - 1j*Im^T``, it returns the
    embedding ``[[Re^T, Im^T], [-Im^T, Re^T]]``.

    Args:
        A: Real block matrix of the form ``[[Re, -Im], [Im, Re]]``.

    Returns:
        Real block matrix ``[[Re^T, Im^T], [-Im^T, Re^T]]``.
    """
    half = A.shape[0] // 2
    re_t = A[:half, :half].T
    im_t = A[half:, :half].T
    upper_row = [re_t, im_t]
    lower_row = [-im_t, re_t]
    return jnp.block([upper_row, lower_row])

87 

88 

def robust_isscalar(val):
    """Return a truthy value when ``val`` should be treated as a scalar.

    A value counts as scalar when it is a ``numbers.Number``, when
    ``jnp.isscalar`` accepts it, or when it is a JAX ``Array`` whose
    shape is empty (0-dimensional).

    Args:
        val: Any object.

    Returns:
        Whether ``val`` is scalar-like.
    """
    scalar_like = isinstance(val, Number) or jnp.isscalar(val)
    if not scalar_like and isinstance(val, Array):
        # Fall back to checking for a zero-dimensional JAX array.
        scalar_like = len(val.shape) == 0
    return scalar_like

94 

95 

96# ===================================================== 

97 

98# Precision 

99 

def set_precision(precision: Literal["single", "double"]):
    """
    Set the precision of JAX operations by toggling ``jax_enable_x64``.

    Args:
        precision: 'single' (disables 64-bit types) or 'double'
            (enables 64-bit types)

    Raises:
        ValueError: if precision is not 'single' or 'double'
    """
    if precision not in ("single", "double"):
        raise ValueError("precision must be 'single' or 'double'")
    use_x64 = precision == "double"
    config.update("jax_enable_x64", use_x64)