Coverage for jaxquantum/utils/utils.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2JAX Utils 

3""" 

4 

5from numbers import Number 

6from typing import Dict 

7 

8from jax import lax, Array, device_put, config 

9from jax._src.scipy.special import gammaln 

10import jax.numpy as jnp 

11import numpy as np 

12 

# Module-level side effect: importing this module switches JAX into 64-bit
# mode ("jax_enable_x64") through the global ``jax.config`` object, so JAX's
# default float/int dtypes become float64/int64 for the rest of the process.
config.update("jax_enable_x64", True)

14 

15 

def device_put_params(params: Dict, non_device_params=None):
    """Move host-side numeric parameters onto the default JAX device.

    Every entry of ``params`` whose value is a ``numbers.Number`` or a
    ``np.ndarray`` is replaced, in place, by ``jax.device_put(value)``.
    Keys listed in ``non_device_params`` are skipped, and values of any
    other type (strings, existing JAX arrays, ...) are left untouched.

    Args:
        params: Mapping from parameter name to value. Modified in place.
        non_device_params: Optional collection of keys that must stay on
            the host. ``None`` means "skip nothing".

    Returns:
        The same ``params`` object, after the in-place update.
    """
    excluded = non_device_params if non_device_params is not None else []
    host_types = (Number, np.ndarray)
    for key in params:
        # Short-circuit keeps the original order: exclusion check first,
        # then the type check.
        if key not in excluded and isinstance(params[key], host_types):
            params[key] = device_put(params[key])
    return params

24 

25 

def comb(N, k):
    """Binomial coefficient "N choose k" (NCk), computed via log-gamma.

    Evaluates ``exp(gammaln(N + 1) - (gammaln(k + 1) + gammaln(N + 1 - k)))``
    elementwise. The result is therefore a floating-point approximation of
    the binomial coefficient, not an exact integer.

    The arithmetic uses ``jax.numpy`` instead of raw ``lax`` primitives so
    that JAX's normal type promotion applies. ``N`` and ``k`` may be Python
    scalars or arrays, and may mix integer and floating dtypes.
    ``lax.add`` / ``lax.sub`` require both operands to share one dtype, so
    a float ``N`` combined with the integer literal ``1`` used to raise a
    ``TypeError``.

    Args:
        N: total items
        k: # of items to choose

    Returns:
        NCk: N choose k, as a floating-point array.
    """
    # TODO: replace with jsp.special.comb once issue is closed:
    # https://github.com/google/jax/issues/9709
    N_plus_1 = jnp.add(N, 1)
    log_numerator = gammaln(N_plus_1)
    log_denominator = gammaln(jnp.add(k, 1)) + gammaln(N_plus_1 - k)
    return jnp.exp(log_numerator - log_denominator)

48 

49 

def complex_to_real_iso_matrix(A):
    """Map a complex matrix to its real block representation.

    Writing A = Ar + i*Ai, the result is the real block matrix::

        [[Ar, -Ai],
         [Ai,  Ar]]

    assembled with ``jnp.block``. For a 2-D ``A`` the output has twice as
    many rows and columns as ``A``.
    """
    re_part = jnp.real(A)
    im_part = jnp.imag(A)
    top_row = [re_part, -im_part]
    bottom_row = [im_part, re_part]
    return jnp.block([top_row, bottom_row])

52 

53 

def real_to_complex_iso_matrix(A):
    """Recover a complex matrix from its real block representation.

    Inverse of the ``[[Ar, -Ai], [Ai, Ar]]`` layout. Only the left column of
    blocks is read: the top-left block is the real part and the bottom-left
    block is the imaginary part. The split point is half the row count.
    """
    half = A.shape[0] // 2
    real_block = A[:half, :half]
    imag_block = A[half:, :half]
    return real_block + 1j * imag_block

57 

58 

def complex_to_real_iso_vector(v):
    """Map a complex vector to its real stacked representation.

    Real parts are placed above imaginary parts via ``jnp.block``. A column
    vector of shape (n, 1) becomes shape (2n, 1).

    NOTE: ``jnp.block`` promotes a 1-D input to a row first, so a length-n
    1-D vector yields shape (2, n) (row 0 real, row 1 imaginary).
    """
    real_rows = jnp.real(v)
    imag_rows = jnp.imag(v)
    stacked = [[real_rows], [imag_rows]]
    return jnp.block(stacked)

61 

62 

def real_to_complex_iso_vector(v):
    """Recover a complex vector from its real stacked representation.

    The first half of the rows (axis 0) is taken as the real part and the
    second half as the imaginary part. ``v`` must be at least 2-D, because
    both axes 0 and 1 are indexed.
    """
    split = v.shape[0] // 2
    top = v[:split, :]
    bottom = v[split:, :]
    return top + 1j * bottom

66 

67 

def imag_times_iso_vector(v):
    """Multiply a real-stacked complex vector by the imaginary unit.

    For a stacked vector ``[vr; vi]`` (split in half along axis 0), i*(vr +
    i*vi) = -vi + i*vr, so the result is ``[-vi; vr]`` assembled with
    ``jnp.block``. ``v`` must be at least 2-D, because axes 0 and 1 are
    both indexed.
    """
    split = v.shape[0] // 2
    real_half = v[:split, :]
    imag_half = v[split:, :]
    return jnp.block([[-imag_half], [real_half]])

71 

72 

def imag_times_iso_matrix(A):
    """Multiply a real block-represented complex matrix by the imaginary unit.

    ``A`` is read as the ``[[Ar, -Ai], [Ai, Ar]]`` layout, using only its left
    column of blocks. Since i*(Ar + i*Ai) = -Ai + i*Ar, the result is the
    block matrix ``[[-Ai, -Ar], [Ar, -Ai]]``.
    """
    half = A.shape[0] // 2
    re_blk = A[:half, :half]
    im_blk = A[half:, :half]
    neg_im = -im_blk
    return jnp.block([[neg_im, -re_blk], [re_blk, neg_im]])

78 

79 

def conj_transpose_iso_matrix(A):
    """Conjugate-transpose a real block-represented complex matrix.

    ``A`` is read as the ``[[Ar, -Ai], [Ai, Ar]]`` layout, using only its left
    column of blocks. Since (Ar + i*Ai)^dagger = Ar^T - i*Ai^T, the result is
    the block matrix ``[[Ar^T, Ai^T], [-Ai^T, Ar^T]]``.
    """
    half = A.shape[0] // 2
    re_t = A[:half, :half].T
    im_t = A[half:, :half].T
    upper = [re_t, im_t]
    lower = [-im_t, re_t]
    return jnp.block([upper, lower])

85 

86 

def robust_isscalar(val):
    """Return whether ``val`` should be treated as a scalar.

    ``val`` counts as scalar if it is a ``numbers.Number``, if
    ``jnp.isscalar`` accepts it, or if it is a JAX ``Array`` with zero
    dimensions.
    """
    scalar_like = isinstance(val, Number) or jnp.isscalar(val)
    if not isinstance(val, Array):
        return scalar_like
    # A 0-d JAX array is a scalar regardless of what isscalar reported.
    return scalar_like or (val.ndim == 0)