Coverage for jaxquantum/devices/superconducting/snail.py: 0%
88 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
"""SNAIL superconducting device model."""
3from flax import struct
4from jax import config
6import jax.numpy as jnp
8from jaxquantum.devices.base.base import BasisTypes, HamiltonianTypes
9from jaxquantum.devices.superconducting.flux_base import FluxDevice
10from jaxquantum.core.operators import identity, destroy, create
11from jaxquantum.core.conversions import jnp2jqt
# Enable 64-bit (double-precision) arrays in JAX. This is a process-wide JAX
# setting, applied as a side effect of importing this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class SNAIL(FluxDevice):
    """
    SNAIL device.

    Supported Hamiltonian / basis combinations (enforced in
    ``param_validation``):

    * ``HamiltonianTypes.linear`` in the Fock basis,
    * ``HamiltonianTypes.full`` in the charge basis (the default),
    * ``HamiltonianTypes.truncated`` in the Fock basis (not implemented;
      raises ``NotImplementedError``).

    Parameters read from ``self.params``: ``"Ec"``, ``"Ej"``, ``"m"``
    (integer >= 2), ``"alpha"`` and ``"phi_ext"`` (full Hamiltonian and
    potential), and ``"ng"`` (gate offset charge, defaulted to 0.0).
    """

    DEFAULT_BASIS = BasisTypes.charge
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate SNAIL parameters and fill in defaults.

        This can be overridden by subclasses.

        Args:
            N: Truncated dimension (not used by this check).
            N_pre_diag: Dimension of the operator matrices built in
                ``common_ops``.
            params: Device parameter dict. Must contain ``"m"``. ``"ng"`` is
                added in place (set to 0.0) when missing.
            hamiltonian: A ``HamiltonianTypes`` member.
            basis: A ``BasisTypes`` member.

        Raises:
            AssertionError: If ``m`` is not an integer >= 2, if the
                Hamiltonian/basis combination is unsupported, or (full
                Hamiltonian) if ``(N_pre_diag - 1) / 2`` is not a multiple
                of ``m``.
        """
        m = params["m"]
        assert m % 1 == 0, "m must be an integer."
        assert m >= 2, "m must be greater than or equal to 2."

        if hamiltonian == HamiltonianTypes.linear:
            assert basis == BasisTypes.fock, "Linear Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.truncated:
            assert basis == BasisTypes.fock, "Truncated Hamiltonian only works with Fock basis."
        elif hamiltonian == HamiltonianTypes.full:
            charge_basis_types = [BasisTypes.charge]
            assert basis in charge_basis_types, "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."

            # ``common_ops`` builds the charge grid n = k / m for integer
            # k in [-(N_pre_diag - 1)/2, (N_pre_diag - 1)/2]. N_pre_diag must
            # therefore be odd, and the grid edge (N_pre_diag - 1)/2 must be a
            # multiple of m. ``2 * m`` must be parenthesized: ``x % 2 * m``
            # parses as ``(x % 2) * m`` and would only check oddness.
            assert (N_pre_diag - 1) % (2 * m) == 0, "(N_pre_diag - 1)/2 must be divisible by m."

        # Set the gate offset charge to zero if not provided (mutates params).
        params.setdefault("ng", 0.0)

    def common_ops(self):
        """Build the operators used to assemble the Hamiltonian.

        All operators are built with dimension ``self.N_pre_diag`` and are
        written in ``self.basis``:

        * Fock basis: ``"id"``, ``"a"``, ``"a_dag"``, ``"phi"``, ``"n"``.
        * Charge basis: ``"id"``, ``"cos(φ/m)"``, ``"sin(φ/m)"``,
          ``"cos(φ)"``, ``"sin(φ)"``, ``"n"``, ``"H_charge"``.

        Returns:
            dict mapping operator names to operators. It is empty for any
            other basis.
        """
        ops = {}

        N = self.N_pre_diag

        if self.basis == BasisTypes.fock:
            ops["id"] = identity(N)
            ops["a"] = destroy(N)
            ops["a_dag"] = create(N)
            ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
            ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        elif self.basis == BasisTypes.charge:
            # Charge basis: states |k>, k = -n_max..n_max, with n = k / m.
            # exp(±iφ/m) shifts k by ±1 (first off-diagonals), so
            # exp(±iφ) = exp(±iφ/m)**m shifts k by ±m (m-th off-diagonals).
            # The charging term 4 Ec (n - ng)^2 is diagonal here.
            # m may be an integer-valued float (validation only checks
            # m % 1 == 0); jnp.eye needs an integer diagonal offset.
            m = int(self.params["m"])
            ops["id"] = identity(N)
            ops["cos(φ/m)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=1) + jnp.eye(N, k=-1)))
            ops["sin(φ/m)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=1) - jnp.eye(N, k=-1)))
            ops["cos(φ)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=m) + jnp.eye(N, k=-m)))
            ops["sin(φ)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=m) - jnp.eye(N, k=-m)))

            n_max = (N - 1) // 2
            n_array = jnp.arange(-n_max, n_max + 1) / m
            ops["n"] = jnp2jqt(jnp.diag(n_array))

            n_minus_ng_array = n_array - self.params["ng"]
            ops["H_charge"] = jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng_array**2))

        return ops

    @property
    def Ej(self):
        """Josephson energy ``params["Ej"]``."""
        return self.params["Ej"]

    def phi_zpf(self):
        """Return Phase ZPF, ``(2 * Ec / Ej) ** (1/4)``."""
        # NOTE(review): transmon-style expression using the bare Ej. For a
        # SNAIL the effective linear inductance presumably depends on alpha,
        # m and phi_ext; confirm this approximation is intended.
        return (2 * self.params["Ec"] / self.Ej) ** (0.25)

    def n_zpf(self):
        """Return Charge ZPF, ``(Ej / (32 * Ec)) ** (1/4)``."""
        return (self.Ej / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Get frequency of linear terms, ``sqrt(8 * Ec * Ej)``."""
        # NOTE(review): same transmon-style approximation as ``phi_zpf``.
        return jnp.sqrt(8 * self.params["Ec"] * self.Ej)

    def get_H_linear(self):
        """Return linear terms in H: ``ω a† a`` (Fock basis operators)."""
        w = self.get_linear_ω()
        return w * self.original_ops["a_dag"] @ self.original_ops["a"]

    def get_H_full(self):
        """Return full H in the charge basis.

        H = H_charge - alpha * Ej * cos(φ) - m * Ej * cos((φ - 2π phi_ext) / m).
        The last term is expanded as
        cos(2π phi_ext/m) cos(φ/m) + sin(2π phi_ext/m) sin(φ/m)
        so it can be built from the ``cos(φ/m)`` and ``sin(φ/m)`` operators.
        """
        α = self.params["alpha"]
        m = self.params["m"]
        phi_ext = self.params["phi_ext"]
        Ej = self.Ej

        H_charge = self.original_ops["H_charge"]
        H_inductive = - α * Ej * self.original_ops["cos(φ)"] - m * Ej * (
            jnp.cos(2 * jnp.pi * phi_ext / m) * self.original_ops["cos(φ/m)"]
            + jnp.sin(2 * jnp.pi * phi_ext / m) * self.original_ops["sin(φ/m)"]
        )
        return H_charge + H_inductive

    def get_H_truncated(self):
        """Return truncated H in specified basis (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement a SNAIL-specific expansion. For general phi_ext the
        # potential is not even in phi, so odd-order terms are needed as well
        # as the even-order ones used for a transmon.
        raise NotImplementedError("Truncated Hamiltonian not implemented for SNAIL.")

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original specified basis.

        Dispatches on ``self.hamiltonian``. This can be overridden by
        subclasses. An unrecognized Hamiltonian type falls through and
        returns ``None``.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        elif self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        elif self.hamiltonian == HamiltonianTypes.truncated:
            return self.get_H_truncated()

    def potential(self, phi):
        """Return potential energy for a given phi.

        ``phi`` is multiplied by 2π before use, so it is presumably expressed
        in units of 2π (reduced flux); confirm against callers.

        * linear: ``0.5 * Ej * (2π phi)**2``
        * full: ``-alpha * Ej * cos(2π phi) - m * Ej * cos(2π (phi_ext - phi) / m)``
        * truncated: raises ``NotImplementedError``.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return 0.5 * self.Ej * (2 * jnp.pi * phi) ** 2
        elif self.hamiltonian == HamiltonianTypes.full:
            α = self.params["alpha"]
            m = self.params["m"]
            phi_ext = self.params["phi_ext"]

            return - α * self.Ej * jnp.cos(2 * jnp.pi * phi) - (
                m * self.Ej * jnp.cos(2 * jnp.pi * (phi_ext - phi) / m)
            )
        elif self.hamiltonian == HamiltonianTypes.truncated:
            # TODO: see ``get_H_truncated``; an expansion must include
            # odd-order terms for general phi_ext.
            raise NotImplementedError("Truncated potential not implemented for SNAIL.")