Coverage for jaxquantum/utils/hermgauss.py: 16%

63 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1import jax.numpy as jnp 

2from jax import jit, lax 

3from jax._src.numpy.util import promote_dtypes_inexact 

4 

5""" 

6The following code is sourced from https://github.com/f0uriest/orthax/ 

7and is licensed under the MIT license. 

8 

9Copyright (c) 2024 Rory Conlin 

10 

11Permission is hereby granted, free of charge, to any person obtaining a copy 

12of this software and associated documentation files (the "Software"), to deal 

13in the Software without restriction, including without limitation the rights 

14to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

15copies of the Software, and to permit persons to whom the Software is 

16furnished to do so, subject to the following conditions: 

17 

18The above copyright notice and this permission notice shall be included in all 

19copies or substantial portions of the Software. 

20 

21THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

22IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

23FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

24AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

25LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

26OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

27SOFTWARE. 

28""" 

29 

30 

def as_series(*arrs):
    """Convert the arguments to JAX arrays sharing a common inexact dtype.

    Each argument is converted with ``jnp.array(a, ndmin=1)``, so scalars
    become shape ``(1,)`` arrays while 1-d, 2-d and higher-dimensional inputs
    keep their shape. All converted arrays are then promoted together to one
    shared inexact (floating or complex) dtype; for example, integer inputs
    become floats.

    This helper does NOT split 2-d inputs into rows, does NOT trim trailing
    zeros, and does NOT reject inputs with more than two dimensions.

    Parameters
    ----------
    *arrs : array_like
        One or more array-like inputs.

    Returns
    -------
    jax.Array or tuple of jax.Array
        A single array when exactly one argument is given; otherwise a tuple
        with one array per argument, in the same order.

    """
    # ndmin=1 turns scalars into 1-d arrays so callers can always len()/index them.
    arrays = tuple(jnp.array(a, ndmin=1) for a in arrs)
    # Promote jointly so mixed int/float/complex inputs end up with one dtype.
    arrays = promote_dtypes_inexact(*arrays)
    if len(arrays) == 1:
        return arrays[0]
    return tuple(arrays)

60 

61 

@jit
def hermcompanion(c):
    """Build the symmetric, scaled companion matrix of a Hermite series.

    The Hermite basis is rescaled so that, when ``c`` is a pure basis
    polynomial (every coefficient zero except the last), the matrix is
    symmetric. Its eigenvalues can then be computed with
    `jax.numpy.linalg.eigvalsh`, which guarantees they are real, and the
    estimates are better than with the unscaled matrix.

    Parameters
    ----------
    c : array_like
        1-D array of Hermite series coefficients, lowest degree first.

    Returns
    -------
    mat : jax.Array
        Scaled companion matrix of shape ``(deg, deg)``, where
        ``deg = len(c) - 1``.

    Raises
    ------
    ValueError
        If ``c`` has fewer than two coefficients.

    """
    coeffs = as_series(c)
    ncoef = len(coeffs)
    if ncoef < 2:
        raise ValueError("Series must have maximum degree of at least 1.")
    if ncoef == 2:
        # Degree-1 series c0 + c1*H_1(x) = c0 + 2*c1*x has the single root -c0/(2*c1).
        return jnp.array([[-0.5 * coeffs[0] / coeffs[1]]])

    deg = ncoef - 1

    # Symmetric tridiagonal part: sqrt(k/2), k = 1..deg-1, on the first
    # super- and sub-diagonals; cast to the series dtype (may be complex).
    offdiag = jnp.sqrt(0.5 * jnp.arange(1, deg))
    mat = (jnp.diag(offdiag, k=1) + jnp.diag(offdiag, k=-1)).astype(coeffs.dtype)

    # Basis scale factors: reversed running product of
    # [1, 1/sqrt(2*(deg-1)), ..., 1/sqrt(2)].
    factors = jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(deg - 1, 0, -1))))
    scale = jnp.cumprod(factors)[::-1]

    # The last column absorbs the scaled series coefficients, normalized by
    # the leading coefficient.
    return mat.at[:, -1].add(-scale * coeffs[:-1] / (2.0 * coeffs[-1]))

101 

102 

103@jit 

104def _normed_hermite_n(x, n): 

105 """ 

106 Evaluate a normalized Hermite polynomial. 

107 

108 Compute the value of the normalized Hermite polynomial of degree ``n`` 

109 at the points ``x``. 

110 

111 

112 Parameters 

113 ---------- 

114 x : ndarray of double. 

115 Points at which to evaluate the function 

116 n : int 

117 Degree of the normalized Hermite function to be evaluated. 

118 

119 Returns 

120 ------- 

121 values : ndarray 

122 The shape of the return value is described above. 

123 

124 Notes 

125 ----- 

126 This function is needed for finding the Gauss points and integration 

127 weights for high degrees. The values of the standard Hermite functions 

128 overflow when n >= 207. 

129 

130 """ 

131 

132 def truefun(): 

133 return jnp.full(x.shape, 1 / jnp.sqrt(jnp.sqrt(jnp.pi))) 

134 

135 def falsefun(): 

136 c0 = jnp.zeros_like(x) 

137 c1 = jnp.ones_like(x) / jnp.sqrt(jnp.sqrt(jnp.pi)) 

138 nd = jnp.array(n).astype(float) 

139 

140 def body(i, val): 

141 c0, c1, nd = val 

142 tmp = c0 

143 c0 = -c1 * jnp.sqrt((nd - 1.0) / nd) 

144 c1 = tmp + c1 * x * jnp.sqrt(2.0 / nd) 

145 nd = nd - 1.0 

146 return c0, c1, nd 

147 

148 c0, c1, _ = lax.fori_loop(0, n - 1, body, (c0, c1, nd)) 

149 return c0 + c1 * x * jnp.sqrt(2) 

150 

151 return lax.cond(n == 0, truefun, falsefun) 

152 

153 

def hermgauss(deg):
    r"""Gauss-Hermite quadrature nodes and weights.

    Return the sample points and weights of the ``deg``-point Gauss-Hermite
    rule. The rule integrates polynomials of degree :math:`2*deg - 1` or less
    correctly over the interval :math:`[-\inf, \inf]` with the weight
    function :math:`f(x) = \exp(-x^2)`.

    Parameters
    ----------
    deg : int
        Number of sample points and weights. It must be >= 1. The value is
        converted with ``int()``.

    Returns
    -------
    x : ndarray
        1-D ndarray containing the sample points.
    y : ndarray
        1-D ndarray containing the weights.

    Raises
    ------
    ValueError
        If ``int(deg)`` is not positive.

    Notes
    -----
    The results have only been tested up to degree 100; higher degrees may
    be problematic. The weights are determined by using the fact that

    .. math:: w_k = c / (H'_n(x_k) * H_{n-1}(x_k))

    where :math:`c` is a constant independent of :math:`k` and :math:`x_k`
    is the k'th root of :math:`H_n`. The results are then scaled so that
    integrating 1 gives the right value (the weights sum to
    :math:`\sqrt{\pi}`).

    """
    npts = int(deg)
    if npts <= 0:
        raise ValueError("deg must be a positive integer")

    # Initial node estimates: eigenvalues of the scaled companion matrix of
    # the basis polynomial H_npts (coefficients 0, ..., 0, 1). That matrix is
    # symmetric, so eigvalsh gives better (real) estimates.
    basis = jnp.zeros(npts + 1).at[-1].set(1)
    nodes = jnp.linalg.eigvalsh(hermcompanion(basis))

    # One Newton correction. For the normalized functions h_n the
    # derivative is h_n' = sqrt(2 n) * h_{n-1}.
    value = _normed_hermite_n(nodes, npts)
    slope = _normed_hermite_n(nodes, npts - 1) * jnp.sqrt(2 * npts)
    nodes = nodes - value / slope

    # Unnormalized weights 1 / h_{n-1}(x_k)^2. Dividing by the largest
    # |h_{n-1}| first keeps the reciprocal from overflowing.
    h_prev = _normed_hermite_n(nodes, npts - 1)
    h_prev = h_prev / jnp.abs(h_prev).max()
    weights = 1 / (h_prev * h_prev)

    # The rule is symmetric about 0, so average each entry with its mirror.
    weights = (weights + weights[::-1]) / 2
    nodes = (nodes - nodes[::-1]) / 2

    # Fix the overall constant so the weights integrate 1 exactly.
    weights = weights * (jnp.sqrt(jnp.pi) / weights.sum())

    return nodes, weights