
utils

JAX Utils

comb(N, k)

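Compute the binomial coefficient N choose k (NCk). The value is evaluated in log space with the log-gamma function and then exponentiated, so the result is a floating-point approximation of the exact integer.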

TODO: replace with jsp.special.comb once issue is closed:

https://github.com/google/jax/issues/9709
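The log-space evaluation rests on the identity Gamma(n + 1) = n!, which turns the factorial ratio into a difference of log-gamma terms (the same three `gammaln` terms that appear in the source listing):

```latex
\binom{N}{k} = \frac{N!}{k!\,(N-k)!}
= \exp\Bigl(\ln\Gamma(N+1) - \bigl[\ln\Gamma(k+1) + \ln\Gamma(N+1-k)\bigr]\Bigr)
```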

Parameters:

Name   Type   Description                 Default
N             total number of items       required
k             number of items to choose   required

Returns:

Name   Type   Description
NCk           N choose k

Source code in jaxquantum/utils/utils.py
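from jax import lax
from jax.scipy.special import gammaln


def comb(N, k):
    """
    Compute N choose k (the binomial coefficient NCk).

    #TODO: replace with jsp.special.comb once issue is closed:
    https://github.com/google/jax/issues/9709

    Args:
        N: total number of items
        k: number of items to choose

    Returns:
        NCk: N choose k
    """
    one = 1
    N_plus_1 = lax.add(N, one)
    k_plus_1 = lax.add(k, one)
    # N! / (k! (N-k)!) evaluated in log space using ln Gamma(n + 1) = ln(n!),
    # then exponentiated; the result is a floating-point approximation.
    return lax.exp(
        lax.sub(
            gammaln(N_plus_1), lax.add(gammaln(k_plus_1), gammaln(lax.sub(N_plus_1, k)))
        )
    )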
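A minimal, self-contained sketch of the same log-gamma approach, written with `jax.numpy` instead of `lax` and casting inputs to float32 so that integer and float arguments behave alike. The name `comb_sketch` is hypothetical and not part of jaxquantum. Because it is built from elementwise operations, it works under `jax.jit` and broadcasts over arrays; here it computes one row of Pascal's triangle and checks it against Python's exact `math.comb`.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def comb_sketch(N, k):
    """Hypothetical restatement of the log-gamma binomial coefficient.

    Returns a float32 approximation of N choose k, elementwise over arrays.
    """
    N = jnp.asarray(N, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    # ln C(N, k) = ln Gamma(N+1) - ln Gamma(k+1) - ln Gamma(N-k+1)
    log_c = gammaln(N + 1.0) - gammaln(k + 1.0) - gammaln(N - k + 1.0)
    return jnp.exp(log_c)


# One row of Pascal's triangle in a single jit-compiled, vectorized call.
row = jax.jit(comb_sketch)(5, jnp.arange(6))
exact = [math.comb(5, i) for i in range(6)]

print(jnp.round(row))
print(exact)  # [1, 5, 10, 10, 5, 1]
```

Since the result passes through `exp` of a floating-point difference, its relative error grows with the size of the log-gamma terms. For large N in single precision, rounding the output is no longer guaranteed to recover the exact integer; double precision (see `set_precision`) or exact integer arithmetic is safer there.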

set_precision(precision)

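Set the default numerical precision of JAX operations by toggling the `jax_enable_x64` configuration flag: 'single' disables 64-bit types (float32/int32 defaults) and 'double' enables them (float64/int64 defaults).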

Parameters:

Name        Type                          Description            Default
precision   Literal['single', 'double']   'single' or 'double'   required

Raises:

Type         Description
ValueError   if precision is not 'single' or 'double'

Source code in jaxquantum/utils/utils.py
from typing import Literal

from jax import config


def set_precision(precision: Literal["single", "double"]):
    """
    Set the precision of JAX operations.

    Args:
        precision: 'single' or 'double'

    Raises:
        ValueError: if precision is not 'single' or 'double'
    """
    # Toggle JAX's 64-bit mode: off -> 32-bit default dtypes, on -> 64-bit.
    if precision == "single":
        config.update("jax_enable_x64", False)
    elif precision == "double":
        config.update("jax_enable_x64", True)
    else:
        raise ValueError("precision must be 'single' or 'double'")
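A short usage sketch showing the effect of the flag on the default dtype of newly created arrays. The function is restated here (using `jax.config.update`, the same call as the listing) only so the snippet runs on its own without jaxquantum installed; "half" is just an example of an invalid argument.

```python
from typing import Literal

import jax
import jax.numpy as jnp


def set_precision(precision: Literal["single", "double"]):
    # Same logic as the listing above, restated so this snippet is self-contained.
    if precision == "single":
        jax.config.update("jax_enable_x64", False)
    elif precision == "double":
        jax.config.update("jax_enable_x64", True)
    else:
        raise ValueError("precision must be 'single' or 'double'")


set_precision("double")
print(jnp.array(1.0).dtype)  # float64

set_precision("single")
print(jnp.array(1.0).dtype)  # float32

try:
    set_precision("half")
except ValueError as err:
    print(err)  # precision must be 'single' or 'double'
```

The flag only changes the dtypes of arrays created after the call, so calling `set_precision` early, before building arrays or compiling functions, avoids mixing 32-bit and 64-bit values in one computation.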