Coverage for jaxquantum/utils/utils.py: 91%
55 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""
2JAX Utils
3"""
5from numbers import Number
6from typing import Dict
8from jax import lax, Array, device_put, config
9from jax._src.scipy.special import gammaln
10import jax.numpy as jnp
11import numpy as np
13from typing import Literal
# Importing this module switches JAX to 64-bit (double precision) mode as a
# side effect. Use set_precision("single") to switch back to 32-bit.
config.update("jax_enable_x64", True)
def device_put_params(params: Dict, non_device_params=None):
    """
    Transfer the numeric entries of a parameter dictionary with ``device_put``.

    Every value that is a ``numbers.Number`` or a ``numpy.ndarray`` is replaced
    by the result of ``jax.device_put``. Values of any other type, and every
    key listed in ``non_device_params``, are left untouched.

    Args:
        params: dictionary of parameters. It is modified in place.
        non_device_params: optional collection of keys to skip. ``None``
            (the default) means no key is skipped.

    Returns:
        The same ``params`` object that was passed in, after the update.
    """
    skipped_keys = [] if non_device_params is None else non_device_params
    # Only existing keys are reassigned, so the dict never changes size while
    # it is being iterated.
    for key, value in params.items():
        if key in skipped_keys:
            continue
        if isinstance(value, (Number, np.ndarray)):
            params[key] = device_put(value)
    return params
def comb(N, k):
    """
    NCk

    #TODO: replace with jsp.special.comb once issue is closed:
    https://github.com/google/jax/issues/9709

    Computed as ``exp(gammaln(N + 1) - (gammaln(k + 1) + gammaln(N + 1 - k)))``,
    so the result is a floating-point approximation.

    Args:
        N: total items
        k: # of items to choose

    Returns:
        NCk: N choose k
    """
    # jnp arithmetic promotes dtypes. The lax primitives require both operands
    # to share a dtype, which rejects float inputs and mixed int32/int64
    # operands when combined with the Python literal 1.
    N_plus_1 = jnp.add(N, 1)
    k_plus_1 = jnp.add(k, 1)
    log_numerator = gammaln(N_plus_1)
    log_denominator = jnp.add(gammaln(k_plus_1), gammaln(jnp.subtract(N_plus_1, k)))
    return jnp.exp(jnp.subtract(log_numerator, log_denominator))
def complex_to_real_iso_matrix(A):
    """
    Build the real block matrix ``[[Re(A), -Im(A)], [Im(A), Re(A)]]``.

    Args:
        A: array whose real and imaginary parts are taken with ``jnp.real``
            and ``jnp.imag`` (presumably a 2-D complex matrix).

    Returns:
        The block matrix assembled with ``jnp.block``.
    """
    real_part = jnp.real(A)
    imag_part = jnp.imag(A)
    return jnp.block([[real_part, -imag_part], [imag_part, real_part]])
def real_to_complex_iso_matrix(A):
    """
    Rebuild a complex matrix from the left half of a real block matrix.

    With ``h = A.shape[0] // 2``, the top-left ``h x h`` block is used as the
    real part and the bottom-left block (rows ``h`` onward, first ``h``
    columns) as the imaginary part. The right half of ``A`` is not read.

    Args:
        A: 2-D array (presumably produced by ``complex_to_real_iso_matrix``).

    Returns:
        ``A[:h, :h] + 1j * A[h:, :h]``.
    """
    half = A.shape[0] // 2
    real_block = A[:half, :half]
    imag_block = A[half:, :half]
    return real_block + 1j * imag_block
def complex_to_real_iso_vector(v):
    """
    Stack the real part of ``v`` on top of its imaginary part.

    Args:
        v: array whose real and imaginary parts are taken with ``jnp.real``
            and ``jnp.imag`` (presumably a complex column vector of shape
            ``(n, 1)``).

    Returns:
        ``jnp.block([[Re(v)], [Im(v)]])``.
    """
    real_part = jnp.real(v)
    imag_part = jnp.imag(v)
    return jnp.block([[real_part], [imag_part]])
def real_to_complex_iso_vector(v):
    """
    Rebuild a complex vector from a stacked real/imaginary vector.

    With ``h = v.shape[0] // 2``, the first ``h`` rows are used as the real
    part and the remaining rows as the imaginary part.

    Args:
        v: 2-D array (indexed with two axes; presumably of shape ``(2n, 1)``).

    Returns:
        ``v[:h, :] + 1j * v[h:, :]``.
    """
    half = v.shape[0] // 2
    real_rows = v[:half, :]
    imag_rows = v[half:, :]
    return real_rows + 1j * imag_rows
def imag_times_iso_vector(v):
    """
    Multiply a stacked real/imaginary vector by the imaginary unit.

    For a stacked vector ``[a; b]`` standing for ``a + i b``, multiplying by
    ``i`` gives ``-b + i a``, i.e. the stacked vector ``[-b; a]``.

    Args:
        v: 2-D array (indexed with two axes); the first ``v.shape[0] // 2``
            rows are treated as ``a`` and the remaining rows as ``b``.

    Returns:
        ``jnp.block([[-b], [a]])``.
    """
    half = v.shape[0] // 2
    upper = v[:half, :]
    lower = v[half:, :]
    return jnp.block([[-lower], [upper]])
def imag_times_iso_matrix(A):
    """
    Multiply a real block-form matrix by the imaginary unit.

    With ``h = A.shape[0] // 2``, ``Ar = A[:h, :h]`` and ``Ai = A[h:, :h]``
    stand for the complex matrix ``Ar + i Ai``. Multiplying by ``i`` gives
    ``-Ai + i Ar``, whose block form is ``[[-Ai, -Ar], [Ar, -Ai]]``.

    Args:
        A: 2-D array; only its left half is read.

    Returns:
        The block matrix ``[[-Ai, -Ar], [Ar, -Ai]]``.
    """
    half = A.shape[0] // 2
    real_block = A[:half, :half]
    imag_block = A[half:, :half]
    return jnp.block([[-imag_block, -real_block], [real_block, -imag_block]])
def conj_transpose_iso_matrix(A):
    """
    Conjugate-transpose a matrix given in real block form.

    With ``h = A.shape[0] // 2``, ``Ar = A[:h, :h]`` and ``Ai = A[h:, :h]``
    stand for ``Ar + i Ai``. Its conjugate transpose is ``Ar.T - i Ai.T``,
    whose block form is ``[[Ar.T, Ai.T], [-Ai.T, Ar.T]]``.

    Args:
        A: 2-D array; only its left half is read.

    Returns:
        The block matrix ``[[Ar.T, Ai.T], [-Ai.T, Ar.T]]``.
    """
    half = A.shape[0] // 2
    real_t = A[:half, :half].T
    imag_t = A[half:, :half].T
    return jnp.block([[real_t, imag_t], [-imag_t, real_t]])
def robust_isscalar(val):
    """
    Report whether ``val`` is a scalar.

    ``val`` counts as a scalar when it is a ``numbers.Number``, when
    ``jnp.isscalar`` says so, or when it is a JAX ``Array`` with an empty
    shape (a 0-dimensional array).

    Args:
        val: the object to test.

    Returns:
        ``True`` if any of the checks above succeeds, ``False`` otherwise.
    """
    if isinstance(val, Number) or jnp.isscalar(val):
        return True
    # Explicit shape check for 0-d JAX arrays, in case jnp.isscalar did not
    # already classify them as scalars.
    return isinstance(val, Array) and len(val.shape) == 0
# =====================================================
# Precision
def set_precision(precision: Literal["single", "double"]):
    """
    Set the precision of JAX operations.

    Updates the ``jax_enable_x64`` configuration flag: ``'single'`` turns it
    off and ``'double'`` turns it on.

    Args:
        precision: 'single' or 'double'

    Raises:
        ValueError: if precision is not 'single' or 'double'
    """
    # Validate first so an invalid value never touches the JAX configuration.
    if precision != "single" and precision != "double":
        raise ValueError("precision must be 'single' or 'double'")
    config.update("jax_enable_x64", precision == "double")