Skip to content

snail

SNAIL device.

SNAIL

Bases: FluxDevice

SNAIL Device.

Source code in jaxquantum/devices/superconducting/snail.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@struct.dataclass
class SNAIL(FluxDevice):
    """
    SNAIL Device.

    Reads the following entries of ``self.params``: ``Ec``, ``Ej``, ``m``
    (integer >= 2), ``alpha``, ``phi_ext`` and ``ng`` (set to 0.0 by
    ``param_validation`` when absent).
    """

    DEFAULT_BASIS = BasisTypes.charge
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate SNAIL parameters; this can be overridden by subclasses.

        Checks that ``m`` is an integer >= 2 and that the Hamiltonian/basis
        combination is supported (linear and truncated require the Fock
        basis, full requires the charge basis). For the full Hamiltonian,
        ``N_pre_diag - 1`` must be a multiple of ``2 * m``. This makes
        ``N_pre_diag`` odd, so the charge grid built in ``common_ops`` has
        exactly ``N_pre_diag`` points. It also makes the grid endpoints
        ``±n_max / m`` whole numbers.

        Mutates ``params`` in place by setting ``params["ng"] = 0.0`` when
        it is missing.

        Raises:
            AssertionError: if any check fails.
        """

        assert params["m"] % 1 == 0, "m must be an integer."
        assert params["m"] >= 2, "m must be greater than or equal to 2."

        if hamiltonian == HamiltonianTypes.linear:
            assert basis == BasisTypes.fock, "Linear Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.truncated:
            assert basis == BasisTypes.fock, "Truncated Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.full:
            charge_basis_types = [
                BasisTypes.charge
            ]
            assert basis in charge_basis_types, "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."

            # The modulus must be parenthesized: ``x % 2 * m`` parses as
            # ``(x % 2) * m``, which only tested that N_pre_diag is odd.
            assert (N_pre_diag - 1) % (2 * params["m"]) == 0, "(N_pre_diag - 1)/2 must be divisible by m."

        # Set the gate offset charge to zero if not provided
        if "ng" not in params:
            params["ng"] = 0.0

    def common_ops(self):
        """Build the operator dictionary, written in the specified basis.

        Fock basis: ``id``, ``a``, ``a_dag``, ``phi`` and ``n``. The last two
        are built from the ladder operators scaled by ``phi_zpf`` / ``n_zpf``.

        Charge basis: ``id``, ``cos(φ/m)``, ``sin(φ/m)``, ``cos(φ)``,
        ``sin(φ)``, ``n`` and ``H_charge``. Basis states are indexed by
        ``k`` in ``-n_max..n_max`` with charge ``n = k / m``. Therefore
        ``cos/sin(φ/m)`` couple neighbouring states (offset 1), and
        ``cos/sin(φ)`` couple states ``m`` apart.
        """

        ops = {}

        N = self.N_pre_diag

        if self.basis == BasisTypes.fock:
            ops["id"] = identity(N)
            ops["a"] = destroy(N)
            ops["a_dag"] = create(N)
            ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
            ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        elif self.basis == BasisTypes.charge:
            # Operators for the full SNAIL Hamiltonian (see get_H_full):
            #   H = 4 Ec (n - ng)² - α Ej cos(φ) - m Ej cos((φ - 2π φ_ext) / m)
            # jnp.eye needs an integer diagonal offset, but m may arrive as a
            # float such as 3.0 (param_validation only checks m % 1 == 0).
            m = int(self.params["m"])
            ops["id"] = identity(N)
            ops["cos(φ/m)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=1) + jnp.eye(N, k=-1)))
            ops["sin(φ/m)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=1) - jnp.eye(N, k=-1)))
            ops["cos(φ)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=m) + jnp.eye(N, k=-m)))
            ops["sin(φ)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=m) - jnp.eye(N, k=-m)))

            # Symmetric grid of 2*n_max + 1 states (== N for odd N), in units
            # of 1/m.
            n_max = (N - 1) // 2
            n_array = jnp.arange(-n_max, n_max + 1) / m
            ops["n"] = jnp2jqt(jnp.diag(n_array))

            n_minus_ng_array = n_array - self.params["ng"]
            ops["H_charge"] = jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2))

        return ops

    @property
    def Ej(self):
        """Josephson energy, read from ``params["Ej"]``."""
        return self.params["Ej"]

    def phi_zpf(self):
        """Return Phase ZPF, ``(2 Ec / Ej) ** (1/4)``."""
        return (2 * self.params["Ec"] / self.Ej) ** (0.25)

    def n_zpf(self):
        """Return Charge ZPF, ``(Ej / (32 Ec)) ** (1/4)``."""
        return (self.Ej / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Get frequency of linear terms, ``sqrt(8 Ec Ej)``.

        NOTE(review): this uses only Ec and Ej; alpha, m and phi_ext do not
        enter. Confirm that this approximation is intended for the SNAIL.
        """
        return jnp.sqrt(8 * self.params["Ec"] * self.Ej)

    def get_H_linear(self):
        """Return linear terms in H: ``ω a† a`` with ω from ``get_linear_ω``."""
        w = self.get_linear_ω()
        return w * self.original_ops["a_dag"] @ self.original_ops["a"]

    def get_H_full(self):
        """Return full H in specified basis.

        ``H = H_charge - α Ej cos(φ) - m Ej [cos(θ) cos(φ/m) + sin(θ) sin(φ/m)]``
        with ``θ = 2π phi_ext / m``. The bracket equals
        ``cos((φ - 2π phi_ext) / m)``.
        """

        α = self.params["alpha"]
        m = self.params["m"]
        phi_ext = self.params["phi_ext"]
        Ej = self.Ej

        H_charge = self.original_ops["H_charge"]
        H_inductive = - α * Ej * self.original_ops["cos(φ)"] - m * Ej * (
            jnp.cos(2 * jnp.pi * phi_ext/m) * self.original_ops["cos(φ/m)"] + jnp.sin(2 * jnp.pi * phi_ext/m) * self.original_ops["sin(φ/m)"]
        )
        return H_charge + H_inductive

    def get_H_truncated(self):
        """Return truncated H in specified basis.

        Raises:
            NotImplementedError: always; no truncated SNAIL model exists yet.
        """
        # TODO: implement a truncated (Taylor-expanded) SNAIL Hamiltonian.
        raise NotImplementedError("Truncated Hamiltonian not implemented for SNAIL.")

    def _get_H_in_original_basis(self):
        """ This returns the Hamiltonian in the original specified basis. This can be overridden by subclasses.

        Dispatches on ``self.hamiltonian``; returns None for any other type.
        """

        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        elif self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        elif self.hamiltonian == HamiltonianTypes.truncated:
            return self.get_H_truncated()

    def potential(self, phi):
        """Return potential energy for a given phi.

        ``phi`` is multiplied by 2π before entering the cosines. Returns
        None for Hamiltonian types other than linear, full and truncated.

        Raises:
            NotImplementedError: for the truncated Hamiltonian.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return 0.5 * self.Ej * (2 * jnp.pi * phi) ** 2
        elif self.hamiltonian == HamiltonianTypes.full:

            α = self.params["alpha"]
            m = self.params["m"]
            phi_ext = self.params["phi_ext"]

            return - α * self.Ej * jnp.cos(2 * jnp.pi * phi) - (
                m * self.Ej * jnp.cos(2 * jnp.pi * (phi_ext - phi) / m)
            )

        elif self.hamiltonian == HamiltonianTypes.truncated:
            raise NotImplementedError("Truncated potential not implemented for SNAIL.")

common_ops()

Written in the specified basis.

Source code in jaxquantum/devices/superconducting/snail.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def common_ops(self):
    """Build the operator dictionary, written in the specified basis.

    Fock basis: ``id``, ``a``, ``a_dag``, ``phi`` and ``n``. The last two
    are built from the ladder operators scaled by ``phi_zpf`` / ``n_zpf``.

    Charge basis: ``id``, ``cos(φ/m)``, ``sin(φ/m)``, ``cos(φ)``,
    ``sin(φ)``, ``n`` and ``H_charge``. Basis states are indexed by ``k``
    in ``-n_max..n_max`` with charge ``n = k / m``. Therefore
    ``cos/sin(φ/m)`` couple neighbouring states (offset 1), and
    ``cos/sin(φ)`` couple states ``m`` apart.
    """

    ops = {}

    N = self.N_pre_diag

    if self.basis == BasisTypes.fock:
        ops["id"] = identity(N)
        ops["a"] = destroy(N)
        ops["a_dag"] = create(N)
        ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
        ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

    elif self.basis == BasisTypes.charge:
        # Operators for the full SNAIL Hamiltonian:
        #   H = 4 Ec (n - ng)² - α Ej cos(φ) - m Ej cos((φ - 2π φ_ext) / m)
        # jnp.eye needs an integer diagonal offset, but m may arrive as a
        # float such as 3.0 (validation only checks m % 1 == 0).
        m = int(self.params["m"])
        ops["id"] = identity(N)
        ops["cos(φ/m)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=1) + jnp.eye(N, k=-1)))
        ops["sin(φ/m)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=1) - jnp.eye(N, k=-1)))
        ops["cos(φ)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=m) + jnp.eye(N, k=-m)))
        ops["sin(φ)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=m) - jnp.eye(N, k=-m)))

        # Symmetric grid of 2*n_max + 1 states (== N for odd N), in units of 1/m.
        n_max = (N - 1) // 2
        n_array = jnp.arange(-n_max, n_max + 1) / m
        ops["n"] = jnp2jqt(jnp.diag(n_array))

        n_minus_ng_array = n_array - self.params["ng"]
        ops["H_charge"] = jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2))

    return ops

get_H_full()

Return full H in specified basis.

Source code in jaxquantum/devices/superconducting/snail.py
103
104
105
106
107
108
109
110
111
112
113
114
115
def get_H_full(self):
    """Assemble the full SNAIL Hamiltonian from ``self.original_ops``.

    Computes ``H_charge - alpha*Ej*cos(φ) - m*Ej*[cos(θ) cos(φ/m) + sin(θ) sin(φ/m)]``
    where ``θ = 2π * phi_ext / m``.
    """
    ops = self.original_ops
    alpha = self.params["alpha"]
    m = self.params["m"]
    flux = self.params["phi_ext"]
    ej = self.Ej

    theta = 2 * jnp.pi * flux / m
    mixed = jnp.cos(theta) * ops["cos(φ/m)"] + jnp.sin(theta) * ops["sin(φ/m)"]
    inductive = -alpha * ej * ops["cos(φ)"] - m * ej * mixed
    return ops["H_charge"] + inductive

get_H_linear()

Return linear terms in H.

Source code in jaxquantum/devices/superconducting/snail.py
 98
 99
100
101
def get_H_linear(self):
    """Return the harmonic part ``ω a† a``, with ω taken from ``get_linear_ω``."""
    omega = self.get_linear_ω()
    raising = self.original_ops["a_dag"]
    lowering = self.original_ops["a"]
    return (omega * raising) @ lowering

get_H_truncated()

Return truncated H in specified basis.

Source code in jaxquantum/devices/superconducting/snail.py
117
118
119
def get_H_truncated(self):
    """Truncated SNAIL Hamiltonian; not available.

    Raises:
        NotImplementedError: always.
    """
    message = "Truncated Hamiltonian not implemented for SNAIL."
    raise NotImplementedError(message)

get_linear_ω()

Get frequency of linear terms.

Source code in jaxquantum/devices/superconducting/snail.py
94
95
96
def get_linear_ω(self):
    """Return the linear (harmonic) frequency ``sqrt(8 * Ec * Ej)``."""
    energy_product = 8 * self.params["Ec"] * self.Ej
    return jnp.sqrt(energy_product)

n_zpf()

Return Charge ZPF.

Source code in jaxquantum/devices/superconducting/snail.py
90
91
92
def n_zpf(self):
    """Return the charge zero-point fluctuation ``(Ej / (32 Ec)) ** (1/4)``."""
    ratio = self.Ej / (32 * self.params["Ec"])
    return ratio ** 0.25

param_validation(N, N_pre_diag, params, hamiltonian, basis) classmethod

Validate the SNAIL parameters and fill in defaults. This can be overridden by subclasses.

Source code in jaxquantum/devices/superconducting/snail.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@classmethod
def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
    """Validate SNAIL parameters; this can be overridden by subclasses.

    Checks that ``m`` is an integer >= 2 and that the Hamiltonian/basis
    combination is supported (linear and truncated require the Fock basis,
    full requires the charge basis). For the full Hamiltonian,
    ``N_pre_diag - 1`` must be a multiple of ``2 * m``, i.e.
    ``(N_pre_diag - 1) / 2`` is divisible by ``m``; in particular
    ``N_pre_diag`` is odd.

    Mutates ``params`` in place by setting ``params["ng"] = 0.0`` when it is
    missing.

    Raises:
        AssertionError: if any check fails.
    """

    assert params["m"] % 1 == 0, "m must be an integer."
    assert params["m"] >= 2, "m must be greater than or equal to 2."

    if hamiltonian == HamiltonianTypes.linear:
        assert basis == BasisTypes.fock, "Linear Hamiltonian only works with Fock basis."
    elif hamiltonian == HamiltonianTypes.truncated:
        assert basis == BasisTypes.fock, "Truncated Hamiltonian only works with Fock basis."
    elif hamiltonian == HamiltonianTypes.full:
        charge_basis_types = [
            BasisTypes.charge
        ]
        assert basis in charge_basis_types, "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."

        # The modulus must be parenthesized: ``x % 2 * m`` parses as
        # ``(x % 2) * m``, which only tested that N_pre_diag is odd.
        assert (N_pre_diag - 1) % (2 * params["m"]) == 0, "(N_pre_diag - 1)/2 must be divisible by m."

    # Set the gate offset charge to zero if not provided
    if "ng" not in params:
        params["ng"] = 0.0

phi_zpf()

Return Phase ZPF.

Source code in jaxquantum/devices/superconducting/snail.py
86
87
88
def phi_zpf(self):
    """Return the phase zero-point fluctuation ``(2 Ec / Ej) ** (1/4)``."""
    ratio = 2 * self.params["Ec"] / self.Ej
    return ratio ** 0.25

potential(phi)

Return potential energy for a given phi.

Source code in jaxquantum/devices/superconducting/snail.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def potential(self, phi):
    """Evaluate the potential energy at ``phi``.

    ``phi`` is multiplied by 2π before use. The linear model gives
    ``0.5 * Ej * (2π phi)**2``. The full model gives
    ``-alpha*Ej*cos(2π phi) - m*Ej*cos(2π (phi_ext - phi) / m)``. Any other
    Hamiltonian type falls through and yields None.

    Raises:
        NotImplementedError: for the truncated Hamiltonian.
    """
    kind = self.hamiltonian

    if kind == HamiltonianTypes.linear:
        angle = 2 * jnp.pi * phi
        return 0.5 * self.Ej * angle ** 2

    if kind == HamiltonianTypes.full:
        alpha = self.params["alpha"]
        m = self.params["m"]
        flux = self.params["phi_ext"]

        alpha_term = -alpha * self.Ej * jnp.cos(2 * jnp.pi * phi)
        m_term = m * self.Ej * jnp.cos(2 * jnp.pi * (flux - phi) / m)
        return alpha_term - m_term

    if kind == HamiltonianTypes.truncated:
        raise NotImplementedError("Truncated potential not implemented for SNAIL.")

    return None