Skip to content

flux_base

Flux base device.

FluxDevice

Bases: Device

Source code in jaxquantum/devices/superconducting/flux_base.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
@struct.dataclass
class FluxDevice(Device):
    @abstractmethod
    def phi_zpf(self):
        """Return Phase ZPF."""

    def _calculate_wavefunctions_fock(self, phi_vals):
        """Calculate wavefunctions at phi_exts."""
        phi_osc = self.phi_zpf() * jnp.sqrt(2)  # length of oscillator
        phi_vals = jnp.array(phi_vals)

        # calculate basis functions
        basis_functions = []
        for n in range(self.N_pre_diag):
            basis_functions.append(
                harm_osc_wavefunction(n, phi_vals, jnp.real(phi_osc))
            )
        basis_functions = jnp.array(basis_functions)

        # transform to better diagonal basis
        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        # the below is equivalent to evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis
        # since evecs in H_eigenbasis is diagonal, i.e. the identity matrix
        wavefunctions = basis_functions_in_H_eigenbasis
        return wavefunctions

    def _calculate_wavefunctions_charge(self, phi_vals):
        phi_vals = jnp.array(phi_vals)

        # calculate basis functions
        basis_functions = []
        n_max = (self.N_pre_diag - 1) // 2
        for n in jnp.arange(-n_max, n_max + 1):
            basis_functions.append(
                1 / (jnp.sqrt(2 * jnp.pi)) * jnp.exp(1j * n * (2 * jnp.pi * phi_vals))
            )
        basis_functions = jnp.array(basis_functions)

        # transform to better diagonal basis
        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        # the below is equivalent to evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis
        # since evecs in H_eigenbasis is diagonal, i.e. the identity matrix
        phase_correction_factors = (1j ** (jnp.arange(0, self.N_pre_diag))).reshape(
            self.N_pre_diag, 1
        )  # TODO: review why these are needed...
        wavefunctions = basis_functions_in_H_eigenbasis * phase_correction_factors
        return wavefunctions

    @abstractmethod
    def potential(self, phi):
        """Return potential energy as a function of phi."""

    def plot_wavefunctions(self, phi_vals, max_n=None, which=None, ax=None, mode="abs"):
        if self.basis == BasisTypes.fock:
            _calculate_wavefunctions = self._calculate_wavefunctions_fock
        elif self.basis == BasisTypes.charge:
            _calculate_wavefunctions = self._calculate_wavefunctions_charge
        else:
            raise NotImplementedError(
                f"The {self.basis} is not yet supported for plotting wavefunctions."
            )

        """Plot wavefunctions at phi_exts."""
        wavefunctions = _calculate_wavefunctions(phi_vals)
        energy_levels = self.eig_systems["vals"][: self.N]

        potential = self.potential(phi_vals)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=1000)
        else:
            fig = ax.get_figure()

        min_val = None
        max_val = None

        assert max_n is None or which is None, "Can't specify both max_n and which"

        max_n = self.N if max_n is None else max_n
        levels = range(max_n) if which is None else which

        for n in levels:
            if mode == "abs":
                wf_vals = jnp.abs(wavefunctions[n, :]) ** 2
            elif mode == "real":
                wf_vals = wavefunctions[n, :].real
            elif mode == "imag":
                wf_vals = wavefunctions[n, :].imag

            wf_vals += energy_levels[n]
            curr_min_val = min(wf_vals)
            curr_max_val = max(wf_vals)

            if min_val is None or curr_min_val < min_val:
                min_val = curr_min_val

            if max_val is None or curr_max_val > max_val:
                max_val = curr_max_val

            ax.plot(
                phi_vals, wf_vals, label=f"$|${n}$\\rangle$", linestyle="-", linewidth=1
            )
            ax.fill_between(phi_vals, energy_levels[n], wf_vals, alpha=0.5)

        ax.plot(
            phi_vals,
            potential,
            label="potential",
            color="black",
            linestyle="-",
            linewidth=1,
        )

        ax.set_ylim([min_val - 1, max_val + 1])

        ax.set_xlabel(r"$\Phi/\Phi_0$")
        ax.set_ylabel(r"Energy [GHz]")

        if mode == "abs":
            title_str = r"$|\psi_n(\Phi)|^2$"
        elif mode == "real":
            title_str = r"Re($\psi_n(\Phi)$)"
        elif mode == "imag":
            title_str = r"Im($\psi_n(\Phi)$)"

        ax.set_title(f"{title_str}")

        plt.legend(fontsize=6)
        fig.tight_layout()

        return ax

phi_zpf() abstractmethod

Return Phase ZPF.

Source code in jaxquantum/devices/superconducting/flux_base.py
18
19
20
@abstractmethod
def phi_zpf(self):
    """Return Phase ZPF."""

potential(phi) abstractmethod

Return potential energy as a function of phi.

Source code in jaxquantum/devices/superconducting/flux_base.py
70
71
72
@abstractmethod
def potential(self, phi):
    """Return potential energy as a function of phi."""