Utils

as_series(*arrs)

Return arguments as arrays with at least one dimension, promoted to a common inexact dtype.

Each argument is converted with `jnp.array(a, ndmin=1)`. A scalar therefore becomes an array of shape (1,), while 1-d and higher-dimensional inputs keep their shape. All arrays are then promoted to a common floating-point or complex dtype. A single argument is returned as one array, and several arguments are returned as a tuple. The function does not split 2-d inputs into rows and does not check the number of dimensions.

Parameters

arrs : array_like
    One or more array_like inputs.

Returns

a1, a2, ... : jax.Array
    A copy of the input data as arrays with at least one dimension. A single array is returned if one argument is given; otherwise a tuple of arrays is returned.

Source code in jaxquantum/utils/hermgauss.py
```python
def as_series(*arrs):
    """Return arguments as arrays with at least one dimension.

    Each argument is converted with ``jnp.array(a, ndmin=1)``, so a scalar
    becomes an array of shape ``(1,)`` while 1-d and higher-dimensional
    inputs keep their shape. The arrays are then promoted to a common
    inexact (floating-point or complex) dtype.

    Parameters
    ----------
    arrs : array_like
        One or more array_like inputs.

    Returns
    -------
    a1, a2,... : jax.Array
        A copy of the input data as arrays with at least one dimension: a
        single array if one argument is given, otherwise a tuple of arrays.

    """
    # Scalars become shape-(1,) arrays; other inputs keep their shape.
    arrays = tuple(jnp.array(a, ndmin=1) for a in arrs)
    # Promote every array to one shared floating-point or complex dtype.
    arrays = promote_dtypes_inexact(*arrays)
    if len(arrays) == 1:
        return arrays[0]
    return tuple(arrays)
```
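The following minimal sketch makes the conversion and promotion rules concrete. It approximates `as_series` using only public `jax.numpy` calls, so it runs without jaxquantum installed. `as_series_sketch` is a hypothetical name. Its dtype rule is an assumption: take `jnp.result_type` of the inputs, then promote with `float32`. This should agree with the library's private helper for the inputs below in JAX's default 32-bit mode, but it may pick a different dtype in other cases, such as 64-bit integers.

```python
import jax.numpy as jnp


def as_series_sketch(*arrs):
    """Hypothetical public-API approximation of as_series (not the library code)."""
    # ndmin=1 turns scalars into shape-(1,) arrays; other inputs keep their shape.
    arrays = [jnp.array(a, ndmin=1) for a in arrs]
    # Assumed promotion rule: common dtype of all inputs, then at least float32,
    # so integers become floats and complex inputs stay complex.
    common = jnp.promote_types(jnp.result_type(*arrays), jnp.float32)
    arrays = [a.astype(common) for a in arrays]
    # Same return convention as as_series: one array, or a tuple of arrays.
    if len(arrays) == 1:
        return arrays[0]
    return tuple(arrays)


x = as_series_sketch(3)
a, b = as_series_sketch([1, 2], [0.5j])
print(x.shape, x.dtype)  # (1,) float32
print(a.dtype, b.dtype)  # complex64 complex64
```

A 2-d input such as `[[1, 2], [3, 4]]` comes back with shape (2, 2). It is not split into rows.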

comb(N, k)

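Compute the binomial coefficient "N choose k".

TODO: replace with jsp.special.comb once this issue is closed: https://github.com/google/jax/issues/9709

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` |  | Total number of items. | required |
| `k` |  | Number of items to choose. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `NCk` |  | N choose k. It is computed in floating point through log-gamma functions, so it may differ slightly from the exact integer. |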

Source code in jaxquantum/utils/utils.py
```python
def comb(N, k):
    """
    NCk

    #TODO: replace with jsp.special.comb once issue is closed:
    https://github.com/google/jax/issues/9709

    Args:
        N: total items
        k: # of items to choose

    Returns:
        NCk: N choose k
    """
    # C(N, k) = Gamma(N + 1) / (Gamma(k + 1) * Gamma(N - k + 1)),
    # evaluated in log space with gammaln and exponentiated at the end.
    one = 1
    N_plus_1 = lax.add(N, one)
    k_plus_1 = lax.add(k, one)
    return lax.exp(
        lax.sub(
            gammaln(N_plus_1), lax.add(gammaln(k_plus_1), gammaln(lax.sub(N_plus_1, k)))
        )
    )
```
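`comb` relies on the identity C(N, k) = Γ(N + 1) / (Γ(k + 1) Γ(N - k + 1)). It evaluates this as exp(ln Γ(N + 1) - ln Γ(k + 1) - ln Γ(N - k + 1)), so large factorials are never formed directly. The minimal sketch below performs the same computation with `jax.numpy` and `jax.scipy.special.gammaln` instead of `lax`. The name `comb_sketch` and the explicit float32 cast are illustrative choices, not the library's code. The sketch shows that the result is a floating-point approximation of an integer and that it works elementwise:

```python
import math

import jax.numpy as jnp
from jax.scipy.special import gammaln


def comb_sketch(N, k):
    # Hypothetical stand-in for comb: C(N, k) via log-gamma, in float32.
    N = jnp.asarray(N, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    log_c = gammaln(N + 1) - gammaln(k + 1) - gammaln(N - k + 1)
    return jnp.exp(log_c)


# Broadcasts elementwise, e.g. C(10, k) for k = 0..4.
row = comb_sketch(10, jnp.arange(5))
# Round before converting when an exact integer is needed.
exact = [int(jnp.round(v)) for v in row]
print(exact)  # [1, 10, 45, 120, 210]
print(exact == [math.comb(10, k) for k in range(5)])  # True
```

Rounding is only safe while the floating-point result is accurate to better than 0.5. For large N, 64-bit mode or exact integer arithmetic is the safer choice.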

hermcompanion(c)

Return the scaled companion matrix of c.

The basis polynomials are scaled so that the companion matrix is symmetric when `c` is an Hermite basis polynomial. This gives better eigenvalue estimates than the unscaled case. For basis polynomials, the eigenvalues are also guaranteed to be real if `jax.numpy.linalg.eigvalsh` is used to obtain them.

Parameters

c : array_like
    1-D array of Hermite series coefficients ordered from low to high degree. At least two coefficients are required; otherwise a ValueError is raised.

Returns

mat : ndarray
    Scaled companion matrix of shape (deg, deg), where deg = len(c) - 1 is the degree of the series.
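The off-diagonal entries sqrt(i/2) written by `hermcompanion` follow from the three-term recurrence of the physicists' Hermite polynomials. Below is a short derivation. It assumes the normalization $\tilde H_i = H_i / \sqrt{2^i\, i!}$, which is illustrative notation and is not taken from the source:

$$
H_{i+1}(x) = 2x\,H_i(x) - 2i\,H_{i-1}(x)
\quad\Longrightarrow\quad
x\,H_i(x) = \tfrac{1}{2}\,H_{i+1}(x) + i\,H_{i-1}(x),
$$

$$
x\,\tilde H_i(x) = \sqrt{\tfrac{i+1}{2}}\;\tilde H_{i+1}(x) + \sqrt{\tfrac{i}{2}}\;\tilde H_{i-1}(x).
$$

In the basis $\tilde H_0, \dots, \tilde H_{n-1}$, multiplication by $x$ is therefore a symmetric tridiagonal matrix. Both of its off-diagonals hold $\sqrt{1/2}, \sqrt{2/2}, \dots, \sqrt{(n-1)/2}$. When `c` is the basis polynomial $H_n$, the coefficients `c[:-1]` are zero and the last-column correction vanishes. The eigenvalues of this matrix are then the roots of $H_n$.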

Source code in jaxquantum/utils/hermgauss.py
```python
@jit
def hermcompanion(c):
    """Return the scaled companion matrix of c.

    The basis polynomials are scaled so that the companion matrix is
    symmetric when `c` is an Hermite basis polynomial. This provides
    better eigenvalue estimates than the unscaled case and for basis
    polynomials the eigenvalues are guaranteed to be real if
    `jax.numpy.linalg.eigvalsh` is used to obtain them.

    Parameters
    ----------
    c : array_like
        1-D array of Hermite series coefficients ordered from low to high
        degree.

    Returns
    -------
    mat : ndarray
        Scaled companion matrix of dimensions (deg, deg).

    """
    c = as_series(c)
    if len(c) < 2:
        raise ValueError("Series must have maximum degree of at least 1.")
    if len(c) == 2:
        # Degree 1: the root of c[0] + c[1] * H_1(x), where H_1(x) = 2x.
        return jnp.array([[-0.5 * c[0] / c[1]]])

    n = len(c) - 1
    mat = jnp.zeros((n, n), dtype=c.dtype)
    # Scale factors applied to the coefficient (last) column.
    scl = jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(n - 1, 0, -1))))
    scl = jnp.cumprod(scl)[::-1]
    shp = mat.shape
    mat = mat.flatten()
    # Strided writes into the flattened matrix fill the super-diagonal
    # (start 1) and the sub-diagonal (start n) with sqrt(i / 2), i = 1..n-1.
    mat = mat.at[1 :: n + 1].set(jnp.sqrt(0.5 * jnp.arange(1, n)))
    mat = mat.at[n :: n + 1].set(jnp.sqrt(0.5 * jnp.arange(1, n)))
    mat = mat.reshape(shp)
    # Subtract the scaled, normalized coefficients from the last column.
    mat = mat.at[:, -1].add(-scl * c[:-1] / (2.0 * c[-1]))
    return mat
```
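The sketch below checks the construction by restating `hermcompanion` without `@jit`. It places the off-diagonals with `jnp.diag` instead of strided writes into a flattened array. This equivalent formulation was chosen for readability, and `hermcompanion_sketch` is a hypothetical name. The sketch compares its result with NumPy's `numpy.polynomial.hermite.hermcompanion`. It also confirms that, for the basis polynomial H_3, `jnp.linalg.eigvalsh` returns the Gauss-Hermite nodes from `numpy.polynomial.hermite.hermgauss`.

```python
import jax.numpy as jnp
import numpy as np
from numpy.polynomial import hermite as H


def hermcompanion_sketch(c):
    # Hypothetical non-jitted restatement of hermcompanion, in float32.
    c = jnp.asarray(c, dtype=jnp.float32)
    n = c.shape[0] - 1
    if n < 1:
        raise ValueError("Series must have maximum degree of at least 1.")
    if n == 1:
        return jnp.array([[-0.5 * c[0] / c[1]]])
    # Symmetric tridiagonal part: sqrt(i / 2), i = 1..n-1, on both off-diagonals.
    off = jnp.sqrt(0.5 * jnp.arange(1, n))
    mat = jnp.diag(off, k=1) + jnp.diag(off, k=-1)
    # Same scale factors as hermcompanion, applied to the last column.
    scl = jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(n - 1, 0, -1))))
    scl = jnp.cumprod(scl)[::-1]
    return mat.at[:, -1].add(-scl * c[:-1] / (2.0 * c[-1]))


coeffs = [1.0, 2.0, 3.0, 4.0]
print(np.allclose(hermcompanion_sketch(coeffs), H.hermcompanion(coeffs), atol=1e-6))  # True

# For the basis polynomial H_3 the matrix is symmetric, so eigvalsh applies.
basis = [0.0, 0.0, 0.0, 1.0]
nodes = jnp.linalg.eigvalsh(hermcompanion_sketch(basis))
print(np.allclose(nodes, H.hermgauss(3)[0], atol=1e-5))  # True
```

The `jnp.diag` form avoids the flatten/reshape round trip. The strided version in the source builds the same matrix.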

set_precision(precision)

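Set the floating-point precision of JAX operations by toggling JAX's global `jax_enable_x64` flag. `'double'` enables 64-bit types and `'single'` disables them.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `precision` | `Literal['single', 'double']` | `'single'` or `'double'` | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `precision` is not `'single'` or `'double'`. |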

Source code in jaxquantum/utils/utils.py
```python
def set_precision(precision: Literal["single", "double"]):
    """
    Set the precision of JAX operations.

    Args:
        precision: 'single' or 'double'

    Raises:
        ValueError: if precision is not 'single' or 'double'
    """
    if precision == "single":
        config.update("jax_enable_x64", False)
    elif precision == "double":
        config.update("jax_enable_x64", True)
    else:
        raise ValueError("precision must be 'single' or 'double'")
```
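`set_precision` only flips JAX's global `jax_enable_x64` flag, so its effect is easiest to see on default dtypes. The sketch below reproduces that switch by calling `jax.config.update` directly, so it runs without jaxquantum installed. `set_precision_sketch` is a hypothetical stand-in. JAX's documentation recommends setting this flag once at startup. Toggling it mid-program, as done here for illustration, does not change arrays that already exist.

```python
from typing import Literal

import jax
import jax.numpy as jnp


def set_precision_sketch(precision: Literal["single", "double"]) -> None:
    # Hypothetical stand-in mirroring set_precision's logic.
    if precision not in ("single", "double"):
        raise ValueError("precision must be 'single' or 'double'")
    jax.config.update("jax_enable_x64", precision == "double")


set_precision_sketch("double")
print(jnp.zeros(2).dtype)  # float64
set_precision_sketch("single")
print(jnp.zeros(2).dtype)  # float32
```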