Skip to content

utils

Utils

Ec_to_inv_pF(Ec)

GHz -> 1/picoFarad

Source code in jaxquantum/utils/units.py
81
82
83
84
85
86
87
88
def Ec_to_inv_pF(Ec):
    """
    Convert a charging energy in GHz to an inverse capacitance in 1/picoFarad.

    The energy is first converted to joules with ``GHz_to_joule`` and then
    divided by ``e**2 / 2``, i.e. the relation ``E_C = e**2 / (2 C)`` is
    solved for ``1 / C``.

    Args:
        Ec: Charging energy in GHz.

    Returns:
        Inverse capacitance in 1/pF.
    """
    half_e_squared = constants.e**2 / 2
    # Dividing the energy by 1e9 makes the quotient below come out per nF.
    scaled_energy = GHz_to_joule(Ec) / 1e9
    per_nanofarad = scaled_energy / half_e_squared
    # 1/nF -> 1/pF
    return per_nanofarad * 1e-3

as_series(*arrs)

Return arguments as a list of 1-d arrays.

The returned list contains array(s) of dtype double, complex double, or object. A 1-d argument of shape (N,) is parsed into N arrays of size one; a 2-d argument of shape (M,N) is parsed into M arrays of size N (i.e., is "parsed by row"); and a higher dimensional array raises a ValueError if it is not first reshaped into either a 1-d or 2-d array.

Parameters

arrs : array_like — 1- or 2-d array_like. trim : boolean, optional — When True, trailing zeros are removed from the inputs. When False, the inputs are passed through intact. (Note: the signature shown above, as_series(*arrs), does not accept a trim argument.)

Returns

a1, a2,... : 1-D arrays A copy of the input data as 1-d arrays.

Source code in jaxquantum/utils/hermgauss.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def as_series(*arrs):
    """Convert the arguments to JAX arrays of at least one dimension.

    Each argument is turned into an array with ``jnp.array(..., ndmin=1)``
    and the whole group is then passed through ``promote_dtypes_inexact``
    (presumably so that all outputs share one inexact dtype -- confirm
    against the JAX helper).

    Parameters
    ----------
    arrs : array_like
        One or more array_like inputs. Scalars become arrays of shape
        ``(1,)``.

    Returns
    -------
    a1, a2,... : arrays
        The single converted array when exactly one argument is given,
        otherwise a tuple holding one converted array per argument.

    Notes
    -----
    This function takes no ``trim`` argument; the inputs are not trimmed or
    split by row.

    """
    promoted = promote_dtypes_inexact(*[jnp.array(item, ndmin=1) for item in arrs])
    # A lone input is handed back directly instead of as a 1-tuple.
    return promoted[0] if len(promoted) == 1 else tuple(promoted)

comb(N, k)

NCk

TODO: replace with jsp.special.comb once issue is closed:

https://github.com/google/jax/issues/9709

Parameters:

Name Type Description Default
N

total items

required
k

Number of items to choose

required

Returns:

Name Type Description
NCk

N choose k

Source code in jaxquantum/utils/utils.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def comb(N, k):
    """
    Binomial coefficient "N choose k", evaluated through log-gamma.

    Computes ``exp(gammaln(N + 1) - (gammaln(k + 1) + gammaln(N + 1 - k)))``,
    so the result is a floating-point value rather than an exact integer.

    #TODO: replace with jsp.special.comb once issue is closed:
    https://github.com/google/jax/issues/9709

    Args:
        N: total items
        k: # of items to choose
            NOTE(review): both are combined with the integer literal 1 via
            ``lax.add``, so they presumably need an integer dtype -- confirm.

    Returns:
        NCk: N choose k
    """
    n_inc = lax.add(N, 1)
    log_numerator = gammaln(n_inc)
    # log(k!) + log((N - k)!)
    log_denominator = lax.add(gammaln(lax.add(k, 1)), gammaln(lax.sub(n_inc, k)))
    return lax.exp(lax.sub(log_numerator, log_denominator))

hermcompanion(c)

Return the scaled companion matrix of c.

The basis polynomials are scaled so that the companion matrix is symmetric when c is an Hermite basis polynomial. This provides better eigenvalue estimates than the unscaled case and for basis polynomials the eigenvalues are guaranteed to be real if jax.numpy.linalg.eigvalsh is used to obtain them.

Parameters

c : array_like 1-D array of Hermite series coefficients ordered from low to high degree.

Returns

mat : ndarray Scaled companion matrix of dimensions (deg, deg).

Source code in jaxquantum/utils/hermgauss.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@jit
def hermcompanion(c):
    """Build the scaled companion matrix for the Hermite series `c`.

    The scaling is chosen so that the matrix is symmetric when `c` is an
    Hermite basis polynomial, which gives better eigenvalue estimates than
    the unscaled form; for basis polynomials the eigenvalues are guaranteed
    to be real when obtained with `jax.numpy.linalg.eigvalsh`.

    Parameters
    ----------
    c : array_like
        1-D array of Hermite series coefficients ordered from low to high
        degree.

    Returns
    -------
    mat : ndarray
        Scaled companion matrix of dimensions (deg, deg).

    Raises
    ------
    ValueError
        If `c` has fewer than two coefficients.

    """
    c = as_series(c)
    num_coef = len(c)
    if num_coef < 2:
        raise ValueError("Series must have maximum degree of at least 1.")
    if num_coef == 2:
        # Degree-1 series: the companion matrix is a single entry.
        return jnp.array([[-0.5 * c[0] / c[1]]])

    n = num_coef - 1
    # Shared values for the super- and sub-diagonal.
    off_diag = jnp.sqrt(0.5 * jnp.arange(1, n))
    scale = jnp.cumprod(
        jnp.hstack((1.0, 1.0 / jnp.sqrt(2.0 * jnp.arange(n - 1, 0, -1))))
    )[::-1]

    # Work on the flattened matrix: a stride of n + 1 walks along a diagonal,
    # starting at offset 1 for the super-diagonal and n for the sub-diagonal.
    flat = jnp.zeros((n, n), dtype=c.dtype).flatten()
    flat = flat.at[1 :: n + 1].set(off_diag)
    flat = flat.at[n :: n + 1].set(off_diag)
    companion = flat.reshape((n, n))
    # The last column carries the scaled, normalised series coefficients.
    return companion.at[:, -1].add(-scale * c[:-1] / (2.0 * c[-1]))

inductance_to_inductive_energy(L)

Convert inductance to inductive energy E_L.

Parameters:

Name Type Description Default
L float

Inductance in nH.

required

Returns:

Name Type Description
float

Inductive energy in GHz.

Source code in jaxquantum/utils/units.py
57
58
59
60
61
62
63
64
65
66
67
68
69
def inductance_to_inductive_energy(L):
    """Convert an inductance to the inductive energy E_L.

    Evaluates ``FLUX_QUANTUM**2 / ((2 * pi)**2 * L)`` in joules, using the
    module-level ``FLUX_QUANTUM`` constant, and converts the result with
    ``joule_to_GHz``.

    Args:
        L (float): Inductance in nH.

    Returns:
        float: Inductive energy in GHz.
    """
    # nH -> 1/H
    inverse_henry = 1e9 / L
    flux_quantum_sq = FLUX_QUANTUM**2
    two_pi_sq = (2 * np.pi) ** 2
    energy_joules = inverse_henry * flux_quantum_sq / two_pi_sq
    return joule_to_GHz(energy_joules)

inductive_energy_to_inductance(El)

Convert inductive energy E_L to inductance.

Parameters:

Name Type Description Default
El float

inductive energy in GHz.

required

Returns:

Name Type Description
float

Inductance in nH.

Source code in jaxquantum/utils/units.py
44
45
46
47
48
49
50
51
52
53
54
55
def inductive_energy_to_inductance(El):
    """Convert an inductive energy E_L to the corresponding inductance.

    Converts ``El`` to joules with ``GHz_to_joule`` and solves
    ``E_L = FLUX_QUANTUM**2 / ((2 * pi)**2 * L)`` for ``L``, using the
    module-level ``FLUX_QUANTUM`` constant.

    Args:
        El (float): Inductive energy in GHz.

    Returns:
        float: Inductance in nH.
    """
    two_pi_sq = (2 * np.pi) ** 2
    flux_quantum_sq = FLUX_QUANTUM**2
    inverse_henry = GHz_to_joule(El) * two_pi_sq / flux_quantum_sq
    # 1/H -> nH
    return 1e9 / inverse_henry

inv_pF_to_Ec(inv_pfarad)

1/picoFarad -> GHz

Source code in jaxquantum/utils/units.py
72
73
74
75
76
77
78
def inv_pF_to_Ec(inv_pfarad):
    """
    Convert an inverse capacitance in 1/picoFarad to a charging energy in GHz.

    Applies ``E_C = e**2 / (2 C)`` and converts the resulting energy in
    joules with ``joule_to_GHz``.

    Args:
        inv_pfarad: Inverse capacitance in 1/pF.

    Returns:
        Charging energy in GHz.
    """
    half_e_squared = constants.e**2 / 2
    # 1/pF -> 1/nF
    per_nanofarad = inv_pfarad * 1e3
    scaled_energy = half_e_squared * per_nanofarad
    # The 1e9 factor undoes the nano prefix so the energy is in joules.
    return joule_to_GHz(scaled_energy * 1e9)

n_thermal(frequency, temperature)

Calculate the average thermal photon number for a given frequency and temperature.

Parameters:

Name Type Description Default
frequency float

Frequency in GHz.

required
temperature float

Temperature in Kelvin.

required

Returns:

Name Type Description
float float

Average thermal photon number.

Source code in jaxquantum/utils/units.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def n_thermal(frequency: float, temperature: float) -> float:
    """Calculate the average thermal photon number for a given frequency and temperature.

    Evaluates the Bose-Einstein occupation ``1 / (exp(h f / (k_B T)) - 1)``.

    Args:
        frequency (float): Frequency in GHz.
        temperature (float): Temperature in Kelvin. Must be nonzero, since it
            appears in the denominator of the exponent.

    Returns:
        float: Average thermal photon number.
    """
    k_B = constants.k  # Boltzmann constant in J/K
    h = constants.h  # Planck constant in J·s

    # Convert GHz to Hz before forming the dimensionless ratio h f / (k_B T).
    exponent = h * (frequency * 1e9) / (k_B * temperature)
    # expm1 evaluates exp(x) - 1 without the cancellation that wipes out
    # significant digits when h f << k_B T (exponent close to zero).
    n_avg = 1 / np.expm1(exponent)
    return n_avg

set_precision(precision)

Set the precision of JAX operations.

Parameters:

Name Type Description Default
precision Literal['single', 'double']

'single' or 'double'

required

Raises:

Type Description
ValueError

if precision is not 'single' or 'double'

Source code in jaxquantum/utils/utils.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def set_precision(precision: Literal["single", "double"]):
    """
    Choose the floating-point precision used by JAX.

    Toggles the ``jax_enable_x64`` configuration flag: it is switched off for
    'single' and on for 'double'.

    Args:
        precision: 'single' or 'double'

    Raises:
        ValueError: if precision is not 'single' or 'double'
    """
    if precision == "single":
        enable_x64 = False
    elif precision == "double":
        enable_x64 = True
    else:
        raise ValueError("precision must be 'single' or 'double'")

    config.update("jax_enable_x64", enable_x64)