utils

JAX Utils

comb(N, k)

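Computes the binomial coefficient N choose k (NCk) as exp(gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)).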

TODO: replace with jsp.special.comb once issue is closed:

https://github.com/google/jax/issues/9709

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| N | | total number of items | required |
| k | | number of items to choose | required |

Returns:

| Name | Type | Description |
|------|------|-------------|
| NCk | | N choose k |

Source code in jaxquantum/utils/utils.py
from jax import lax
from jax.scipy.special import gammaln


def comb(N, k):
    """
    NCk

    TODO: replace with jsp.special.comb once issue is closed:
    https://github.com/google/jax/issues/9709

    Args:
        N: total items
        k: number of items to choose

    Returns:
        NCk: N choose k
    """
    # N! = Gamma(N + 1), so N choose k = exp(ln N! - ln k! - ln (N - k)!).
    one = 1
    N_plus_1 = lax.add(N, one)
    k_plus_1 = lax.add(k, one)
    return lax.exp(
        lax.sub(
            gammaln(N_plus_1), lax.add(gammaln(k_plus_1), gammaln(lax.sub(N_plus_1, k)))
        )
    )
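
As a rough illustration of how the log-gamma formula in the listing above behaves, the sketch below restates it locally and compares the result with exact integer arithmetic. The name `comb_sketch` is hypothetical and is not part of jaxquantum. It uses `jnp` operators rather than `lax` calls so that dtype promotion is handled automatically, which is an assumption made for this example rather than the library's own code.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def comb_sketch(N, k):
    """Hypothetical stand-in: N choose k via log-gamma (not the library function)."""
    # exp(ln N! - ln k! - ln (N - k)!), using Gamma(n + 1) = n!
    return jnp.exp(gammaln(N + 1.0) - gammaln(k + 1.0) - gammaln(N - k + 1.0))


# Scalar call: the result is a float that is close to an integer.
value = comb_sketch(5, 2)

# Elementwise over an array of k values, compiled with jit.
ks = jnp.arange(6)
row = jax.jit(comb_sketch)(5, ks)

# Compare against exact integer arithmetic from the standard library.
expected = [math.comb(5, int(k)) for k in ks]
rounded = [int(x) for x in jnp.round(row)]
print(rounded == expected)
```

Because the value comes from exponentiating a floating-point difference, it is only approximately an integer. Round it when an exact count is needed, and expect the error to grow for large N.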