Coverage for jaxquantum/utils/utils.py: 100%
48 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""
2JAX Utils
3"""
5from numbers import Number
6from typing import Dict
8from jax import lax, Array, device_put, config
9from jax._src.scipy.special import gammaln
10import jax.numpy as jnp
11import numpy as np
# Turn on 64-bit types (float64 / int64 / complex128) globally for JAX.
# This executes at import time, so merely importing this module changes the
# process-wide JAX configuration for every other user of JAX.
config.update("jax_enable_x64", True)
def device_put_params(params: Dict, non_device_params=None):
    """
    Transfer eligible parameter values with ``jax.device_put``, in place.

    Values that are Python numbers (``numbers.Number``) or NumPy arrays are
    replaced by the result of ``device_put``; every other value (strings,
    existing JAX arrays, arbitrary objects, ...) is left untouched.

    Args:
        params: mapping of parameter names to values. It is modified in place.
        non_device_params: optional collection of keys whose values must be
            skipped entirely. ``None`` (the default) means no keys are skipped.

    Returns:
        The same ``params`` object that was passed in, after conversion.
    """
    skip = [] if non_device_params is None else non_device_params
    # Re-assigning values of existing keys does not change the dict's size,
    # so it is safe to do while iterating over its items.
    for name, value in params.items():
        if name in skip:
            continue
        if isinstance(value, (Number, np.ndarray)):
            params[name] = device_put(value)
    return params
def comb(N, k):
    """
    Compute the binomial coefficient "N choose k" via log-gamma functions.

    The value is evaluated as ``exp(gammaln(N + 1) - gammaln(k + 1)
    - gammaln(N - k + 1))``. The result is therefore a floating-point
    approximation rather than an exact integer, and non-integer ``N``/``k``
    are accepted as well. Array inputs broadcast in the usual NumPy way.

    Args:
        N: total items (scalar or array).
        k: number of items to choose (scalar or array).

    Returns:
        NCk: N choose k, as a floating-point JAX array.
    """
    # TODO: replace with jsp.special.comb once this issue is closed:
    # https://github.com/google/jax/issues/9709
    #
    # Plain arithmetic is used instead of lax.add / lax.sub: lax ops demand
    # identical dtypes, which made e.g. a float N with an integer k raise a
    # TypeError. Regular operators apply NumPy-style type promotion instead.
    n_plus_1 = N + 1
    log_numerator = gammaln(n_plus_1)
    log_denominator = gammaln(k + 1) + gammaln(n_plus_1 - k)
    return jnp.exp(log_numerator - log_denominator)
def complex_to_real_iso_matrix(A):
    """
    Map a complex matrix ``A = Ar + i*Ai`` to its real isomorphic form.

    The result is the real block matrix ``[[Ar, -Ai], [Ai, Ar]]``; for a 2-D
    input of shape (n, m) it has shape (2n, 2m).
    """
    re_part = jnp.real(A)
    im_part = jnp.imag(A)
    upper_row = [re_part, -im_part]
    lower_row = [im_part, re_part]
    return jnp.block([upper_row, lower_row])
def real_to_complex_iso_matrix(A):
    """
    Recover a complex matrix from its real isomorphic block form.

    The real part is the top-left quadrant and the imaginary part is the
    bottom-left quadrant. Both quadrants are sized from the row count, so a
    square input of even size is assumed.
    """
    half = A.shape[0] // 2
    real_block = A[:half, :half]
    imag_block = A[half:, :half]
    return real_block + 1j * imag_block
def complex_to_real_iso_vector(v):
    """
    Map a complex column vector to its real isomorphic form.

    The real parts are stacked on top of the imaginary parts via
    ``jnp.block``.
    """
    real_rows = [jnp.real(v)]
    imag_rows = [jnp.imag(v)]
    return jnp.block([real_rows, imag_rows])
def real_to_complex_iso_vector(v):
    """
    Recover a complex column vector from its real isomorphic form.

    The upper half of the rows holds the real parts and the lower half holds
    the imaginary parts. The input is indexed as 2-D (rows, columns).
    """
    split = v.shape[0] // 2
    upper, lower = v[:split, :], v[split:, :]
    return upper + 1j * lower
def imag_times_iso_vector(v):
    """
    Multiply a real-isomorphic column vector by the imaginary unit ``i``.

    Since ``i * (x + i*y) = -y + i*x``, the new real half is ``-y`` and the
    new imaginary half is ``x``.
    """
    mid = v.shape[0] // 2
    re_rows = v[:mid, :]
    im_rows = v[mid:, :]
    return jnp.block([[-im_rows], [re_rows]])
def imag_times_iso_matrix(A):
    """
    Multiply a real-isomorphic matrix by the imaginary unit ``i``.

    With ``A`` encoding ``Ar + i*Ai`` as ``[[Ar, -Ai], [Ai, Ar]]``, the
    product ``i*(Ar + i*Ai) = -Ai + i*Ar`` is re-encoded in the same layout.
    Only the left column of quadrants is read.
    """
    mid = A.shape[0] // 2
    re_blk, im_blk = A[:mid, :mid], A[mid:, :mid]
    top = [-im_blk, -re_blk]
    bottom = [re_blk, -im_blk]
    return jnp.block([top, bottom])
def conj_transpose_iso_matrix(A):
    """
    Conjugate-transpose a matrix given in real isomorphic form.

    For ``Ar + i*Ai`` the adjoint is ``Ar.T - i*Ai.T``. Re-encoded as
    ``[[Re, -Im], [Im, Re]]``, this gives ``[[Ar.T, Ai.T], [-Ai.T, Ar.T]]``.
    Only the left column of quadrants is read.
    """
    mid = A.shape[0] // 2
    re_t = A[:mid, :mid].T
    im_t = A[mid:, :mid].T
    return jnp.block([[re_t, im_t], [-im_t, re_t]])
def robust_isscalar(val):
    """
    Return whether ``val`` should be treated as a scalar.

    ``val`` is a scalar if it is a ``numbers.Number`` or if
    ``jnp.isscalar`` accepts it. As a fallback, a JAX ``Array`` with zero
    dimensions is also counted as a scalar.
    """
    scalar_like = isinstance(val, Number) or jnp.isscalar(val)
    if scalar_like or not isinstance(val, Array):
        return scalar_like
    return val.ndim == 0