Coverage for jaxquantum/utils/hermgauss.py: 16%
63 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1import jax.numpy as jnp
2from jax import jit, lax
3from jax._src.numpy.util import promote_dtypes_inexact
5"""
6The following code is sourced from https://github.com/f0uriest/orthax/
7and is licensed under the MIT license.
9Copyright (c) 2024 Rory Conlin
11Permission is hereby granted, free of charge, to any person obtaining a copy
12of this software and associated documentation files (the "Software"), to deal
13in the Software without restriction, including without limitation the rights
14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15copies of the Software, and to permit persons to whom the Software is
16furnished to do so, subject to the following conditions:
18The above copyright notice and this permission notice shall be included in all
19copies or substantial portions of the Software.
21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27SOFTWARE.
28"""
def as_series(*arrs):
    """Convert the arguments to JAX arrays sharing one inexact dtype.

    Every argument is turned into a JAX array with at least one dimension
    (scalars become arrays of shape ``(1,)``), and all of the resulting
    arrays are then promoted together to a common inexact dtype.

    Parameters
    ----------
    arrs : array_like
        One or more scalars or array_likes to convert.

    Returns
    -------
    a1, a2,... : arrays
        The promoted array itself when exactly one argument is passed,
        otherwise a tuple holding one promoted array per argument, in the
        order the arguments were given.

    """
    # ndmin=1 guarantees that scalar inputs still support len() and slicing.
    at_least_1d = tuple(jnp.array(item, ndmin=1) for item in arrs)
    promoted = promote_dtypes_inexact(*at_least_1d)
    # A lone argument is returned bare so callers can use it directly.
    return promoted[0] if len(promoted) == 1 else tuple(promoted)
@jit
def hermcompanion(c):
    """Build the scaled companion matrix of a Hermite series.

    The basis polynomials are scaled so that the companion matrix is
    symmetric when `c` is an Hermite basis polynomial. This provides
    better eigenvalue estimates than the unscaled case and for basis
    polynomials the eigenvalues are guaranteed to be real if
    `jax.numpy.linalg.eigvalsh` is used to obtain them.

    Parameters
    ----------
    c : array_like
        1-D array of Hermite series coefficients ordered from low to high
        degree.

    Returns
    -------
    mat : ndarray
        Scaled companion matrix of dimensions (deg, deg).

    Raises
    ------
    ValueError
        If fewer than two coefficients are given.

    """
    coeffs = as_series(c)
    ncoef = len(coeffs)
    if ncoef < 2:
        raise ValueError("Series must have maximum degree of at least 1.")

    deg = ncoef - 1
    if deg == 1:
        # Degree one: the matrix is 1x1 and holds the single root.
        return jnp.array([[-0.5 * coeffs[0] / coeffs[1]]])

    # Identical values go on the super- and sub-diagonal.
    off_diag = jnp.sqrt(0.5 * jnp.arange(1, deg))
    idx = jnp.arange(deg - 1)
    # Allocate in the coefficient dtype so the indexed updates below are
    # cast to that dtype.
    mat = jnp.zeros((deg, deg), dtype=coeffs.dtype)
    mat = mat.at[idx, idx + 1].set(off_diag)
    mat = mat.at[idx + 1, idx].set(off_diag)

    # Reversed cumulative product of 1, 1/sqrt(2*(deg-1)), ..., 1/sqrt(2).
    factors = jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(deg - 1, 0, -1))))
    scl = jnp.cumprod(factors)[::-1]

    # The last column additionally carries the scaled, normalised
    # lower-order coefficients.
    last_column = -scl * coeffs[:-1] / (2.0 * coeffs[-1])
    return mat.at[:, -1].add(last_column)
@jit
def _normed_hermite_n(x, n):
    """
    Evaluate a normalized Hermite polynomial.

    Compute the value of the normalized Hermite polynomial of degree ``n``
    at the points ``x``.

    Parameters
    ----------
    x : ndarray of double.
        Points at which to evaluate the function
    n : int
        Degree of the normalized Hermite function to be evaluated.

    Returns
    -------
    values : ndarray
        Array with the same shape as ``x`` holding the function values.

    Notes
    -----
    This function is needed for finding the Gauss points and integration
    weights for high degrees. The values of the standard Hermite functions
    overflow when n >= 207.

    """
    # pi**0.25; its reciprocal is the value of the degree-0 function.
    quartic_root_pi = jnp.sqrt(jnp.sqrt(jnp.pi))

    def degree_zero():
        return jnp.full(x.shape, 1 / quartic_root_pi)

    def positive_degree():
        lower = jnp.zeros_like(x)
        upper = jnp.ones_like(x) / quartic_root_pi
        order = jnp.array(n).astype(float)

        def step(_, carry):
            # One recurrence step; ``k`` counts down by one per iteration.
            prev, curr, k = carry
            new_prev = -curr * jnp.sqrt((k - 1.0) / k)
            new_curr = prev + curr * x * jnp.sqrt(2.0 / k)
            return new_prev, new_curr, k - 1.0

        lower, upper, _ = lax.fori_loop(0, n - 1, step, (lower, upper, order))
        return lower + upper * x * jnp.sqrt(2)

    # ``n`` is a traced value under jit, so the branch is chosen by lax.cond.
    return lax.cond(n == 0, degree_zero, positive_degree)
def hermgauss(deg):
    r"""Gauss-Hermite quadrature.

    Computes the sample points and weights for Gauss-Hermite quadrature.
    These sample points and weights will correctly integrate polynomials of
    degree :math:`2*deg - 1` or less over the interval
    :math:`[-\infty, \infty]` with the weight function
    :math:`f(x) = \exp(-x^2)`.

    Parameters
    ----------
    deg : int
        Number of sample points and weights. It must be >= 1.

    Returns
    -------
    x : ndarray
        1-D ndarray containing the sample points.
    y : ndarray
        1-D ndarray containing the weights.

    Raises
    ------
    ValueError
        If ``deg`` is not positive after conversion to ``int``.

    Notes
    -----
    The results have only been tested up to degree 100, higher degrees may
    be problematic. The weights are determined by using the fact that

    .. math:: w_k = c / (H'_n(x_k) * H_{n-1}(x_k))

    where :math:`c` is a constant independent of :math:`k` and :math:`x_k`
    is the k'th root of :math:`H_n`, and then scaling the results to get
    the right value when integrating 1.

    """
    npts = int(deg)
    if npts <= 0:
        raise ValueError("deg must be a positive integer")

    # Initial root estimates: eigenvalues of the companion matrix of the
    # pure degree-``npts`` basis polynomial, which is symmetric here.
    basis = jnp.zeros(npts + 1).at[-1].set(1)
    nodes = jnp.linalg.eigvalsh(hermcompanion(basis))

    # A single Newton step polishes the roots; the slope is taken as
    # sqrt(2 * npts) times the degree npts-1 value.
    value = _normed_hermite_n(nodes, npts)
    slope = _normed_hermite_n(nodes, npts - 1) * jnp.sqrt(2 * npts)
    nodes = nodes - value / slope

    # Weights from the degree npts-1 function at the refined nodes,
    # normalised by its largest magnitude to avoid numerical overflow.
    lower = _normed_hermite_n(nodes, npts - 1)
    lower = lower / jnp.abs(lower).max()
    weights = 1 / (lower * lower)

    # Symmetrise: weights are averaged with their mirror image and nodes
    # are made antisymmetric about zero.
    weights = (weights + weights[::-1]) / 2
    nodes = (nodes - nodes[::-1]) / 2

    # Rescale so that the weights sum to sqrt(pi).
    weights = weights * (jnp.sqrt(jnp.pi) / weights.sum())

    return nodes, weights