
utils

Utility functions

harm_osc_wavefunction(n, x, l_osc)

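Adapted from scqubits. Not jit-able: it calls scipy.special.pbdv, which operates on concrete NumPy arrays rather than traced JAX values.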

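For a given quantum number n = 0, 1, 2, ..., return the value of the harmonic oscillator wave function $\psi_n(x) = N\, H_n(x/l_\text{osc}) \exp\!\left(-x^2/(2 l_\text{osc}^2)\right)$, where N is the normalization factor. Note that the implementation multiplies the input by 2π before evaluating, so it returns $\psi_n(2\pi x)$.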
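For reference, N follows from the Hermite orthogonality relation $\int H_n(y)^2 e^{-y^2}\,dy = 2^n n! \sqrt{\pi}$. Writing $\psi_n(x) = N H_n(x/l_\text{osc}) \exp(-x^2/(2 l_\text{osc}^2))$ and substituting $y = x/l_\text{osc}$:

```latex
1 = \int_{-\infty}^{\infty} |\psi_n(x)|^2\,dx
  = N^2\, l_\text{osc} \int_{-\infty}^{\infty} H_n(y)^2 e^{-y^2}\,dy
  = N^2\, l_\text{osc}\, 2^n\, n!\, \sqrt{\pi}
\quad\Longrightarrow\quad
N = \frac{1}{\sqrt{2^n\, n!\, \sqrt{\pi}\, l_\text{osc}}}
```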

Calls scipy.special.pbdv (the parabolic cylinder function $D_n$) directly. This avoids the numerical instability of the more common form, a Gaussian times a Hermite polynomial, in which H_n grows rapidly while the Gaussian factor underflows for large n or |x|.
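The pbdv route relies on the identity $D_n(z) = 2^{-n/2} e^{-z^2/4} H_n(z/\sqrt{2})$: with $z = \sqrt{2}\,x/l_\text{osc}$, dividing $D_n(z)$ by $\sqrt{n!\,\sqrt{\pi}\,l_\text{osc}}$ reproduces the normalized Hermite-Gaussian form. The sketch below checks this numerically for small n. It uses scipy.special.eval_hermite and math.factorial for illustration (the listed function uses the module's own factorial_approx instead), and the helper names psi_pbdv and psi_hermite are hypothetical.

```python
import math

import numpy as np
from scipy.special import eval_hermite, pbdv


def psi_pbdv(n, x, l_osc):
    """Normalized oscillator wave function via the parabolic cylinder function D_n."""
    d_n, _ = pbdv(n, np.sqrt(2.0) * x / l_osc)  # pbdv returns (D_n, dD_n/dz)
    return d_n / np.sqrt(l_osc * np.sqrt(np.pi) * math.factorial(n))


def psi_hermite(n, x, l_osc):
    """Same function in the textbook form: normalization * H_n(y) * exp(-y^2 / 2)."""
    y = x / l_osc
    norm = 1.0 / np.sqrt(2.0**n * math.factorial(n) * np.sqrt(np.pi) * l_osc)
    return norm * eval_hermite(n, y) * np.exp(-(y**2) / 2.0)


# Compare both forms on a grid for the first few states
x = np.linspace(-3.0, 3.0, 121)
max_diff = max(
    np.max(np.abs(psi_pbdv(n, x, 0.8) - psi_hermite(n, x, 0.8))) for n in range(7)
)
```

For these small n both forms agree to floating-point precision; the advantage of pbdv only shows up at large n or |x|, where the Hermite polynomial and the Gaussian individually overflow or underflow.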

Parameters

n: index of the wave function; n=0 is the ground state
x: coordinate(s) at which the wave function is evaluated
l_osc: oscillator length, defined via <0|x^2|0> = l_osc^2/2
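As a sanity check on the definition of l_osc, the sketch below integrates the normalized ground state $\psi_0(x) = (\sqrt{\pi}\,l_\text{osc})^{-1/2} \exp(-x^2/(2 l_\text{osc}^2))$ with a plain Riemann sum in NumPy. The value l_osc = 0.7 and the grid are arbitrary illustrative choices; the result recovers both unit norm and $\langle 0|x^2|0\rangle = l_\text{osc}^2/2$.

```python
import numpy as np

l_osc = 0.7  # arbitrary oscillator length for illustration
# Grid wide enough that the Gaussian has decayed to ~exp(-100) at the edges
x = np.linspace(-10 * l_osc, 10 * l_osc, 20001)
dx = x[1] - x[0]

# Normalized ground state, i.e. N * H_0(x / l_osc) * exp(-x^2 / (2 l_osc^2)) with H_0 = 1
psi0 = np.exp(-(x**2) / (2 * l_osc**2)) / np.sqrt(np.sqrt(np.pi) * l_osc)

norm = np.sum(psi0**2) * dx          # should be 1
x2 = np.sum(x**2 * psi0**2) * dx     # should be l_osc**2 / 2
```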

Returns

value of harmonic oscillator wave function
Source code in jaxquantum/devices/common/utils.py
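import jax.numpy as jnp
from scipy.special import pbdv  # NumPy-based, which is why this function cannot be jitted

# factorial_approx(n): factorial helper provided by the surrounding module (not shown in this listing)


def harm_osc_wavefunction(n, x, l_osc):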
    r"""
    Adapted from scqubits; not jit-able (uses scipy.special.pbdv).

    For given quantum number n=0,1,2,... return the value of the harmonic
    oscillator wave function
    :math:`\psi_n(x) = N H_n(x/l_\text{osc}) \exp(-x^2/(2 l_\text{osc}^2))`,
    N being the proper normalization factor.

    Directly uses `scipy.special.pbdv` (implementation of the parabolic cylinder
    function) to mitigate numerical stability issues with the more commonly used
    expression in terms of a Gaussian and a Hermite polynomial factor.

    Parameters
    ----------
    n:
        index of wave function, n=0 is ground state
    x:
        coordinate(s) where wave function is evaluated
    l_osc:
        oscillator length, defined via <0|x^2|0> = l_osc^2/2

    Returns
    -------
        value of harmonic oscillator wave function
    """
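    # The input coordinate is rescaled by 2*pi, so this returns psi_n(2*pi*x)
    x = 2 * jnp.pi * x
    # D_n(sqrt(2) x / l_osc) = 2^(-n/2) H_n(x / l_osc) exp(-x^2 / (2 l_osc^2));
    # pbdv returns (D_n, dD_n/dz), so keep only the function value
    result = pbdv(n, jnp.sqrt(2.0) * x / l_osc)[0]
    # Divide by sqrt(n! sqrt(pi) l_osc) for unit norm; the 2^(-n/2) inside D_n
    # already supplies the 2^n part of the usual Hermite normalization
    result = result / jnp.sqrt(l_osc * jnp.sqrt(jnp.pi) * factorial_approx(n))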
    return result
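Because the function above is not jit-able, one possible JAX-native alternative replaces pbdv with the standard three-term recurrence for normalized Hermite functions, $\psi_{k+1} = \sqrt{2/(k+1)}\,y\,\psi_k - \sqrt{k/(k+1)}\,\psi_{k-1}$ with $y = x/l_\text{osc}$. This is a sketch, not part of jaxquantum: the name harm_osc_wavefunction_jax is hypothetical, it does not apply the 2π rescaling of x (pass 2πx to mimic the listed function), and n must be marked static under jax.jit because it fixes the loop length.

```python
import jax
import jax.numpy as jnp


def harm_osc_wavefunction_jax(n: int, x, l_osc):
    """Hypothetical jit-able sketch: normalized oscillator wave function psi_n(x).

    Uses the three-term recurrence for normalized Hermite functions instead of
    scipy.special.pbdv. Unlike the listed function, x is NOT rescaled by 2*pi.
    """
    y = jnp.asarray(x) / l_osc
    # psi_0, normalized with respect to dx (hence the 1 / sqrt(l_osc) factor)
    psi_prev = jnp.exp(-(y**2) / 2.0) / jnp.sqrt(jnp.sqrt(jnp.pi) * l_osc)
    if n == 0:  # n is a Python int, so this branch is resolved at trace time
        return psi_prev
    psi_curr = jnp.sqrt(2.0) * y * psi_prev  # psi_1

    def step(k, carry):
        # psi_{k+1} = sqrt(2/(k+1)) * y * psi_k - sqrt(k/(k+1)) * psi_{k-1}
        prev, curr = carry
        nxt = jnp.sqrt(2.0 / (k + 1)) * y * curr - jnp.sqrt(k / (k + 1)) * prev
        return curr, nxt

    # Iterate k = 1, ..., n-1 so the second carry entry ends up as psi_n
    _, psi_n = jax.lax.fori_loop(1, n, step, (psi_prev, psi_curr))
    return psi_n


# Usage: n is static under jit because it fixes the branch and the loop bounds
psi_jit = jax.jit(harm_osc_wavefunction_jax, static_argnums=0)
xs = jnp.linspace(-3.0, 3.0, 61)
psi3 = psi_jit(3, xs, 0.8)
```

The recurrence works directly with normalized functions, so intermediate values stay of order one instead of multiplying a large Hermite polynomial by a tiny Gaussian; each distinct n triggers a separate compilation.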