Coverage for jaxquantum/devices/superconducting/transmon.py: 0%
105 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Transmon."""
3from flax import struct
4from jax import config
6import jax.numpy as jnp
8from jaxquantum.devices.base.base import BasisTypes, HamiltonianTypes
9from jaxquantum.devices.superconducting.flux_base import FluxDevice
10from jaxquantum.core.operators import identity, destroy, create
11from jaxquantum.core.conversions import jnp2jqt
# Enable 64-bit (double-precision) arrays in JAX. Note this is a process-wide
# JAX setting applied as a side effect of importing this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class Transmon(FluxDevice):
    """
    Transmon Device.

    Supported (Hamiltonian, basis) combinations, as enforced by
    ``param_validation``:

    * ``HamiltonianTypes.linear`` / ``HamiltonianTypes.truncated`` require
      ``BasisTypes.fock``.
    * ``HamiltonianTypes.full`` requires ``BasisTypes.charge`` (Cooper-pair
      charge) or ``BasisTypes.single_charge`` (single-electron charge).

    Parameters read from ``self.params``: ``"Ec"`` (charging energy),
    ``"Ej"`` (Josephson energy) and ``"ng"`` (gate offset charge, defaulted
    to 0.0 by ``param_validation``).
    """

    DEFAULT_BASIS = BasisTypes.charge
    DEFAULT_HAMILTONIAN = HamiltonianTypes.full

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate the Hamiltonian/basis pairing and fill in parameter defaults.

        This can be overridden by subclasses.

        Args:
            N: Not used by this validation.
            N_pre_diag: Dimension of the basis before diagonalization; must be odd.
            params: Device parameter dict. Mutated in place: ``"ng"`` is set to
                0.0 when it is absent.
            hamiltonian: A ``HamiltonianTypes`` member.
            basis: A ``BasisTypes`` member.

        Raises:
            AssertionError: If the Hamiltonian/basis combination is unsupported
                or ``N_pre_diag`` is even.
        """
        if hamiltonian == HamiltonianTypes.linear:
            assert basis == BasisTypes.fock, (
                "Linear Hamiltonian only works with Fock basis."
            )
        elif hamiltonian == HamiltonianTypes.truncated:
            assert basis == BasisTypes.fock, (
                "Truncated Hamiltonian only works with Fock basis."
            )
        elif hamiltonian == HamiltonianTypes.full:
            assert basis in [BasisTypes.charge, BasisTypes.single_charge], (
                "Full Hamiltonian only works with Cooper pair charge or single-electron charge bases."
            )

        # Set the gate offset charge to zero if not provided
        if "ng" not in params:
            params["ng"] = 0.0

        # The charge bases are symmetric around n = 0, so they need 2*n_max + 1 states.
        assert (N_pre_diag - 1) % 2 == 0, "N_pre_diag must be odd."

    def common_ops(self):
        """Build the operator dictionary, written in ``self.basis``.

        Keys produced per basis:

        * Fock: ``"id"``, ``"a"``, ``"a_dag"``, ``"phi"``, ``"n"``.
        * Cooper-pair charge: ``"id"``, ``"cos(φ)"``, ``"sin(φ)"``, ``"n"``,
          ``"H_charge"``.
        * Single-electron charge: ``"id"``, ``"cos(φ)"``, ``"sin(φ)"``,
          ``"cos(φ/2)"``, ``"sin(φ/2)"``, ``"n"``, ``"H_charge"``.

        Returns:
            dict mapping operator names to operators of dimension
            ``self.N_pre_diag``.
        """
        ops = {}
        N = self.N_pre_diag

        if self.basis == BasisTypes.fock:
            ops["id"] = identity(N)
            ops["a"] = destroy(N)
            ops["a_dag"] = create(N)
            ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
            ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        elif self.basis == BasisTypes.charge:
            # Cooper-pair charge basis, states |n> with n in [-n_max, n_max]:
            #   H = 4 Ec (n - ng)^2 - Ej cos(φ)
            # cos(φ) and sin(φ) shift n by ±1, so they live on the first
            # off-diagonals.
            n_max = (N - 1) // 2
            n_vals = jnp.arange(-n_max, n_max + 1)

            ops["id"] = identity(N)
            ops["cos(φ)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=1) + jnp.eye(N, k=-1)))
            ops["sin(φ)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=1) - jnp.eye(N, k=-1)))
            ops["n"] = jnp2jqt(jnp.diag(n_vals))

            n_minus_ng = n_vals - self.params["ng"]
            ops["H_charge"] = jnp2jqt(jnp.diag(4 * self.params["Ec"] * n_minus_ng**2))

        elif self.basis == BasisTypes.single_charge:
            # Single-electron charge basis: n counts electrons, not Cooper pairs,
            # with n in [-n_max, n_max]:
            #   H = Ec (n - 2 ng)^2 - Ej cos(φ)
            # Using Eq. (5.36) of Kyle Serniak's thesis,
            #   H = Ec Σ_n (n - 2 ng)^2 |n><n| - (Ej / 2) Σ_n |n><n+2| + h.c.
            # so cos(φ) / sin(φ) shift n by ±2, while cos(φ/2) / sin(φ/2)
            # shift n by ±1. We use 2 ng instead of ng to match the gate offset
            # charge convention of the transmon (as done in Kyle's thesis).
            n_max = (N - 1) // 2
            n_vals = jnp.arange(-n_max, n_max + 1)

            ops["id"] = identity(N)
            ops["cos(φ)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=2) + jnp.eye(N, k=-2)))
            ops["sin(φ)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=2) - jnp.eye(N, k=-2)))
            ops["cos(φ/2)"] = 0.5 * (jnp2jqt(jnp.eye(N, k=1) + jnp.eye(N, k=-1)))
            ops["sin(φ/2)"] = 0.5j * (jnp2jqt(jnp.eye(N, k=1) - jnp.eye(N, k=-1)))
            ops["n"] = jnp2jqt(jnp.diag(n_vals))

            n_minus_2ng = n_vals - 2 * self.params["ng"]
            ops["H_charge"] = jnp2jqt(jnp.diag(self.params["Ec"] * n_minus_2ng**2))

        return ops

    @property
    def Ej(self):
        """Josephson energy, read from ``params["Ej"]``."""
        return self.params["Ej"]

    def phi_zpf(self):
        """Return the phase zero-point fluctuation, (2 Ec / Ej) ** (1/4)."""
        return (2 * self.params["Ec"] / self.Ej) ** (0.25)

    def n_zpf(self):
        """Return the charge zero-point fluctuation, (Ej / (32 Ec)) ** (1/4).

        Together with ``phi_zpf`` this gives phi_zpf * n_zpf = 1/2.
        """
        return (self.Ej / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Return the frequency of the linear (harmonic) part, sqrt(8 Ec Ej)."""
        return jnp.sqrt(8 * self.params["Ec"] * self.Ej)

    def get_H_linear(self):
        """Return the linear Hamiltonian ω a†a (requires the Fock basis ops)."""
        w = self.get_linear_ω()
        return w * self.original_ops["a_dag"] @ self.original_ops["a"]

    def get_H_full(self):
        """Return the full Hamiltonian H_charge - Ej cos(φ) in the charge bases."""
        return self.original_ops["H_charge"] - self.Ej * self.original_ops["cos(φ)"]

    def get_H_truncated(self):
        """Return H truncated to sixth order in φ (requires the Fock basis ops).

        H = ω a†a - (Ej / 24) φ⁴ + (Ej / 720) φ⁶, with φ = phi_zpf (a + a†).
        """
        phi_op = self.original_ops["phi"]
        # Reuse lower powers: three matrix products instead of eight.
        phi2 = phi_op @ phi_op
        phi4 = phi2 @ phi2
        phi6 = phi4 @ phi2
        fourth_order_term = -(1 / 24) * self.Ej * phi4
        sixth_order_term = (1 / 720) * self.Ej * phi6
        return self.get_H_linear() + fourth_order_term + sixth_order_term

    def _get_H_in_original_basis(self):
        """Return the Hamiltonian in the original specified basis.

        This can be overridden by subclasses.

        Raises:
            ValueError: If ``self.hamiltonian`` is not linear, full or truncated.
        """
        if self.hamiltonian == HamiltonianTypes.linear:
            return self.get_H_linear()
        elif self.hamiltonian == HamiltonianTypes.full:
            return self.get_H_full()
        elif self.hamiltonian == HamiltonianTypes.truncated:
            return self.get_H_truncated()
        raise ValueError(f"Unsupported Hamiltonian type: {self.hamiltonian}")

    def potential(self, phi):
        """Return the potential energy at ``phi`` for the selected Hamiltonian.

        ``phi`` is multiplied by 2π before use. It is presumably in units of
        the flux quantum (φ / 2π); confirm against callers. The constant -Ej
        offset of -Ej cos is omitted in the linear and truncated forms.

        Raises:
            ValueError: If ``self.hamiltonian`` is not linear, full or truncated.
        """
        phi_scaled = 2 * jnp.pi * phi
        if self.hamiltonian == HamiltonianTypes.linear:
            return 0.5 * self.Ej * phi_scaled**2
        elif self.hamiltonian == HamiltonianTypes.full:
            return -self.Ej * jnp.cos(phi_scaled)
        elif self.hamiltonian == HamiltonianTypes.truncated:
            second_order = 0.5 * self.Ej * phi_scaled**2
            fourth_order = -(1 / 24) * self.Ej * phi_scaled**4
            sixth_order = (1 / 720) * self.Ej * phi_scaled**6
            return second_order + fourth_order + sixth_order
        raise ValueError(f"Unsupported Hamiltonian type: {self.hamiltonian}")

    def calculate_wavefunctions(self, phi_vals):
        """Calculate the eigenstate wavefunctions at the phase points ``phi_vals``.

        * Fock basis: delegates to ``FluxDevice.calculate_wavefunctions``.
        * Charge basis: eigenvector j with coefficients c_n is evaluated as
          ψ_j(φ) = i^j / sqrt(2π) * Σ_n c_n exp(i n φ), where n runs over the
          diagonal of ``original_ops["n"]``.
        * Single-electron charge basis: not implemented.

        NOTE(review): in the charge basis ``phi_vals`` enters exp(i n φ)
        directly (radians), while ``potential`` scales its argument by 2π.
        Confirm the intended units when plotting the two together.

        Args:
            phi_vals: 1-D array-like of phase values.

        Returns:
            Array of shape (N_pre_diag, len(phi_vals)). Row j is state j.

        Raises:
            NotImplementedError: For the single-electron charge basis.
            ValueError: For any other unsupported basis.
        """
        if self.basis == BasisTypes.fock:
            return super().calculate_wavefunctions(phi_vals)
        elif self.basis == BasisTypes.single_charge:
            raise NotImplementedError(
                "Wavefunctions for single charge basis not yet implemented."
            )
        elif self.basis == BasisTypes.charge:
            phi_vals = jnp.array(phi_vals)
            num_states = self.N_pre_diag

            n_labels = jnp.diag(self.original_ops["n"].data)
            vecs = self.eig_systems["vecs"][:, :num_states]

            # Plane waves exp(i n φ): shape (len(phi_vals), N_pre_diag).
            plane_waves = jnp.exp(1j * phi_vals[:, None] * n_labels[None, :])
            # Exact per-state global phases i^j = (1, i, -1, -i)[j % 4].
            phases = jnp.array([1, 1j, -1, -1j])[jnp.arange(num_states) % 4]

            # (plane_waves @ vecs)[p, j] = Σ_n c_n^(j) exp(i n φ_p), then transpose
            # so that rows index eigenstates.
            amplitudes = (plane_waves @ vecs).T
            return (phases[:, None] / jnp.sqrt(2 * jnp.pi)) * amplitudes
        raise ValueError(f"Unsupported basis type: {self.basis}")