
hermgauss

as_series(*arrs)

Return arguments as arrays with at least one dimension.

Each argument is converted with jnp.array(a, ndmin=1): a scalar becomes an array of size one, while 1-d and higher-dimensional inputs keep their shape. All results are then promoted to a common inexact dtype (floating point or complex). A single argument is returned as one array; several arguments are returned as a tuple. No dimensionality check is performed, and there is no trim option.

Parameters

arrs : array_like
    One or more inputs to convert (scalars, sequences, or arrays).

Returns

a1, a2,... : ndarray or tuple of ndarray
    A copy of each input as an array with at least one dimension.

Source code in jaxquantum/utils/hermgauss.py
def as_series(*arrs):
    """Return arguments as a list of 1-d arrays.

    The returned list contains array(s) of dtype double, complex double, or
    object.  A 1-d argument of shape ``(N,)`` is parsed into ``N`` arrays of
    size one; a 2-d argument of shape ``(M,N)`` is parsed into ``M`` arrays
    of size ``N`` (i.e., is "parsed by row"); and a higher dimensional array
    raises a Value Error if it is not first reshaped into either a 1-d or 2-d
    array.

    Parameters
    ----------
    arrs : array_like
        1- or 2-d array_like
    trim : boolean, optional
        When True, trailing zeros are removed from the inputs.
        When False, the inputs are passed through intact.

    Returns
    -------
    a1, a2,... : 1-D arrays
        A copy of the input data as 1-d arrays.

    """
    arrays = tuple(jnp.array(a, ndmin=1) for a in arrs)
    arrays = promote_dtypes_inexact(*arrays)
    if len(arrays) == 1:
        return arrays[0]
    return tuple(arrays)
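Reading the source above, the conversion rules are: force at least one dimension with ndmin=1, promote everything to a common inexact dtype, and return a bare array when there is only one argument. The sketch below mimics those rules using only public `jax.numpy` calls. `as_series_sketch` is a hypothetical name, and because it replaces the private `promote_dtypes_inexact` helper with `jnp.result_type` plus a `float32` fallback, its dtypes may differ from the original in edge cases.

```python
import jax.numpy as jnp


def as_series_sketch(*arrs):
    """Hypothetical stand-in for as_series built from public jax.numpy calls."""
    # Every argument becomes an array with at least one dimension.
    arrays = [jnp.array(a, ndmin=1) for a in arrs]
    # Find a common dtype, then force it to be inexact (float or complex).
    # Assumption: float32 fallback instead of promote_dtypes_inexact.
    dtype = jnp.result_type(*arrays)
    if not jnp.issubdtype(dtype, jnp.inexact):
        dtype = jnp.promote_types(dtype, jnp.float32)
    arrays = [a.astype(dtype) for a in arrays]
    # One argument -> bare array; several -> tuple, as in the source above.
    return arrays[0] if len(arrays) == 1 else tuple(arrays)


single = as_series_sketch(3)              # scalar -> shape (1,)
pair = as_series_sketch([1, 2], [1.5j])   # ints promoted to complex
print(single.shape, single.dtype)
print([p.dtype for p in pair])
```

Note that a 2-d input such as `[[1, 2], [3, 4]]` comes back with shape `(2, 2)`; it is not split into rows.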

hermcompanion(c)

Return the scaled companion matrix of c.

The basis polynomials are scaled so that the companion matrix is symmetric when c is a Hermite basis polynomial. This provides better eigenvalue estimates than the unscaled case, and for basis polynomials the eigenvalues are guaranteed to be real if jax.numpy.linalg.eigvalsh is used to obtain them.

Parameters

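c : array_like
    1-D array of Hermite series coefficients ordered from low to high degree. It must contain at least two coefficients; otherwise a ValueError is raised.

Returns

mat : ndarray
    Scaled companion matrix of dimensions (deg, deg), where deg = len(c) - 1.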

Source code in jaxquantum/utils/hermgauss.py
@jit
def hermcompanion(c):
    """Return the scaled companion matrix of c.

    The basis polynomials are scaled so that the companion matrix is
    symmetric when `c` is an Hermite basis polynomial. This provides
    better eigenvalue estimates than the unscaled case and for basis
    polynomials the eigenvalues are guaranteed to be real if
    `jax.numpy.linalg.eigvalsh` is used to obtain them.

    Parameters
    ----------
    c : array_like
        1-D array of Hermite series coefficients ordered from low to high
        degree.

    Returns
    -------
    mat : ndarray
        Scaled companion matrix of dimensions (deg, deg).

    """
    c = as_series(c)
    if len(c) < 2:
        raise ValueError("Series must have maximum degree of at least 1.")
    if len(c) == 2:
        return jnp.array([[-0.5 * c[0] / c[1]]])

    n = len(c) - 1
    mat = jnp.zeros((n, n), dtype=c.dtype)
    scl = jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(n - 1, 0, -1))))
    scl = jnp.cumprod(scl)[::-1]
    shp = mat.shape
    mat = mat.flatten()
    mat = mat.at[1 :: n + 1].set(jnp.sqrt(0.5 * jnp.arange(1, n)))
    mat = mat.at[n :: n + 1].set(jnp.sqrt(0.5 * jnp.arange(1, n)))
    mat = mat.reshape(shp)
    mat = mat.at[:, -1].add(-scl * c[:-1] / (2.0 * c[-1]))
    return mat
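For the basis polynomial $H_n$ (coefficients $[0, \dots, 0, 1]$), the correction added to the last column vanishes because `c[:-1]` is all zeros, so the scaled companion matrix reduces to a symmetric tridiagonal matrix with off-diagonal entries $\sqrt{k/2}$, $k = 1, \dots, n-1$. The sketch below builds that matrix directly in `jax.numpy` and checks that `jnp.linalg.eigvalsh` recovers the roots of $H_n$. `hermite_jacobi` is a hypothetical helper name, and NumPy's `numpy.polynomial.hermite.hermroots` serves only as an independent reference.

```python
import jax.numpy as jnp
import numpy as np
from numpy.polynomial import hermite as H


def hermite_jacobi(n):
    """Hypothetical helper: scaled companion matrix of the basis polynomial H_n."""
    off = jnp.sqrt(0.5 * jnp.arange(1, n))  # sqrt(k/2) for k = 1..n-1
    # Same super- and sub-diagonal as hermcompanion; the last-column
    # correction is zero because every coefficient except c[n] is 0.
    return jnp.diag(off, k=1) + jnp.diag(off, k=-1)


n = 6
m = hermite_jacobi(n)
roots = jnp.linalg.eigvalsh(m)            # real, sorted ascending
reference = H.hermroots([0] * n + [1])    # NumPy roots of H_6, sorted
print(np.max(np.abs(np.asarray(roots) - reference)))
```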

hermgauss(deg)

Gauss-Hermite quadrature.

Computes the sample points and weights for Gauss-Hermite quadrature. These sample points and weights integrate polynomials of degree $2 \cdot \mathrm{deg} - 1$ or less exactly over the interval $(-\infty, \infty)$ with the weight function $f(x) = \exp(-x^2)$.
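A quick numerical check of the exactness claim, using NumPy's `numpy.polynomial.hermite.hermgauss` as a stand-in (it returns the same `(x, w)` pair, so the sketch runs without jaxquantum installed). With `deg = 3` the rule should reproduce $\int x^4 e^{-x^2}\,dx = \tfrac{3}{4}\sqrt{\pi}$, since $4 \le 2 \cdot 3 - 1$, but not $\int x^6 e^{-x^2}\,dx = \tfrac{15}{8}\sqrt{\pi}$.

```python
import math

import numpy as np
from numpy.polynomial.hermite import hermgauss

x, w = hermgauss(3)  # 3 nodes: exact up to polynomial degree 2*3 - 1 = 5

exact_x4 = 3 * math.sqrt(math.pi) / 4   # integral of x^4 exp(-x^2) over the real line
exact_x6 = 15 * math.sqrt(math.pi) / 8  # integral of x^6 exp(-x^2)

approx_x4 = np.sum(w * x**4)  # degree 4 <= 5: reproduced up to rounding error
approx_x6 = np.sum(w * x**6)  # degree 6 > 5: the 3-point rule gives 9/8 * sqrt(pi)

print(approx_x4 - exact_x4)
print(approx_x6 - exact_x6)
```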

Parameters

deg : int
    Number of sample points and weights. It must be >= 1.

Returns

x : ndarray
    1-D ndarray containing the sample points.
y : ndarray
    1-D ndarray containing the weights.
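The returned pair is commonly used through a change of variables: for $X \sim \mathcal{N}(\mu, \sigma^2)$, $\mathbb{E}[g(X)] \approx \pi^{-1/2} \sum_k w_k\, g(\mu + \sqrt{2}\,\sigma x_k)$. The sketch below takes the nodes from NumPy's `hermgauss` (a stand-in with the same `(x, w)` return), evaluates the sum in `jax.numpy`, and differentiates it with `jax.grad`. `gaussian_expectation` is a hypothetical helper, not part of jaxquantum.

```python
import jax
import jax.numpy as jnp
from numpy.polynomial.hermite import hermgauss

# Nodes and weights for the weight function exp(-x^2); constants w.r.t. mu, sigma.
x_np, w_np = hermgauss(10)
x, w = jnp.asarray(x_np), jnp.asarray(w_np)


def gaussian_expectation(g, mu, sigma):
    """Hypothetical helper: E[g(X)] for X ~ N(mu, sigma^2) via Gauss-Hermite."""
    # Substituting X = mu + sqrt(2) * sigma * t turns the Gaussian density
    # into exp(-t^2) / sqrt(pi).
    samples = mu + jnp.sqrt(2.0) * sigma * x
    return jnp.sum(w * g(samples)) / jnp.sqrt(jnp.pi)


second_moment = gaussian_expectation(lambda s: s**2, 1.0, 2.0)  # mu^2 + sigma^2
d_dmu = jax.grad(lambda mu: gaussian_expectation(lambda s: s**2, mu, 2.0))(1.0)  # 2 * mu
print(second_moment, d_dmu)
```

Because the nodes and weights are fixed arrays, the whole expectation is an ordinary differentiable JAX expression in `mu` and `sigma`.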

Notes

The results have only been tested up to degree 100; higher degrees may be problematic. The weights are determined by using the fact that

$$w_k = \frac{c}{H'_n(x_k)\, H_{n-1}(x_k)},$$

where $n = \mathrm{deg}$, $c$ is a constant independent of $k$, and $x_k$ is the $k$-th root of $H_n$. The weights are then rescaled so that integrating 1 gives the correct value, $\sum_k w_k = \int_{-\infty}^{\infty} e^{-x^2}\,dx = \sqrt{\pi}$.
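The weight formula can be checked directly with NumPy's Hermite utilities (`hermroots`, `hermder`, `hermval`, and `hermgauss` from `numpy.polynomial.hermite`, used here as independent reference code rather than jaxquantum's routines): evaluate $1 / (H'_n(x_k) H_{n-1}(x_k))$ at the roots, rescale so the weights sum to $\sqrt{\pi}$, and compare with the library weights. The unknown constant $c$ drops out in the rescaling.

```python
import math

import numpy as np
from numpy.polynomial import hermite as H

n = 5
hn = [0] * n + [1]                # H_n in the Hermite basis
hn_minus_1 = [0] * (n - 1) + [1]  # H_{n-1} in the Hermite basis

xk = H.hermroots(hn)                # roots x_k of H_n, sorted ascending
dHn = H.hermval(xk, H.hermder(hn))  # H'_n(x_k)
Hnm1 = H.hermval(xk, hn_minus_1)    # H_{n-1}(x_k)

wk = 1.0 / (dHn * Hnm1)               # w_k up to the constant c
wk *= math.sqrt(math.pi) / wk.sum()   # choose c so the weights sum to sqrt(pi)

x_ref, w_ref = H.hermgauss(n)
print(np.max(np.abs(wk - w_ref)))
```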

Source code in jaxquantum/utils/hermgauss.py
def hermgauss(deg):
    r"""Gauss-Hermite quadrature.

    Computes the sample points and weights for Gauss-Hermite quadrature.
    These sample points and weights will correctly integrate polynomials of
    degree :math:`2*deg - 1` or less over the interval :math:`[-\inf, \inf]`
    with the weight function :math:`f(x) = \exp(-x^2)`.

    Parameters
    ----------
    deg : int
        Number of sample points and weights. It must be >= 1.

    Returns
    -------
    x : ndarray
        1-D ndarray containing the sample points.
    y : ndarray
        1-D ndarray containing the weights.

    Notes
    -----
    The results have only been tested up to degree 100, higher degrees may
    be problematic. The weights are determined by using the fact that

    .. math:: w_k = c / (H'_n(x_k) * H_{n-1}(x_k))

    where :math:`c` is a constant independent of :math:`k` and :math:`x_k`
    is the k'th root of :math:`H_n`, and then scaling the results to get
    the right value when integrating 1.

    """
    deg = int(deg)
    if deg <= 0:
        raise ValueError("deg must be a positive integer")

    # first approximation of roots. We use the fact that the companion
    # matrix is symmetric in this case in order to obtain better zeros.
    c = jnp.zeros(deg + 1).at[-1].set(1)
    m = hermcompanion(c)
    x = jnp.linalg.eigvalsh(m)

    # improve roots by one application of Newton
    dy = _normed_hermite_n(x, deg)
    df = _normed_hermite_n(x, deg - 1) * jnp.sqrt(2 * deg)
    x -= dy / df

    # compute the weights. We scale the factor to avoid possible numerical
    # overflow.
    fm = _normed_hermite_n(x, deg - 1)
    fm /= jnp.abs(fm).max()
    w = 1 / (fm * fm)

    # for Hermite we can also symmetrize
    w = (w + w[::-1]) / 2
    x = (x - x[::-1]) / 2

    # scale w to get the right value
    w *= jnp.sqrt(jnp.pi) / w.sum()

    return x, w
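The Newton step in the source divides `_normed_hermite_n(x, deg)` by `_normed_hermite_n(x, deg - 1) * jnp.sqrt(2 * deg)`. That helper is not shown on this page; the sketch below assumes it evaluates the orthonormal polynomial $h_n(x) = H_n(x) / \sqrt{2^n\, n!\, \sqrt{\pi}}$, for which $H'_n = 2n H_{n-1}$ gives $h'_n = \sqrt{2n}\, h_{n-1}$, so the denominator is exactly the derivative. `normed_hermite` is a hypothetical stand-in built on NumPy's `hermval`; the sketch perturbs the true roots and shows one Newton step pulling them back.

```python
import math

import numpy as np
from numpy.polynomial import hermite as H


def normed_hermite(x, n):
    """Hypothetical stand-in for _normed_hermite_n: H_n(x) / sqrt(2^n n! sqrt(pi))."""
    coeffs = [0] * n + [1]
    norm = math.sqrt(2.0**n * math.factorial(n) * math.sqrt(math.pi))
    return H.hermval(x, coeffs) / norm


deg = 8
exact = H.hermroots([0] * deg + [1])
x = exact + 1e-6  # perturbed roots, standing in for eigensolver output

# One Newton step, as in hermgauss: the derivative of the normalized H_deg
# equals sqrt(2 * deg) times the normalized H_{deg - 1}.
dy = normed_hermite(x, deg)
df = normed_hermite(x, deg - 1) * math.sqrt(2 * deg)
x_new = x - dy / df

print(np.max(np.abs(x - exact)), np.max(np.abs(x_new - exact)))
```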