Coverage for jaxquantum/devices/superconducting/truncated_transmon.py: 0%
33 statements
« prev ^ index » next — coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Transmon."""
3from flax import struct
4from jax import config
6import jax.numpy as jnp
7import jax.scipy as jsp
9from jaxquantum.devices.superconducting.flux_base import FluxDevice
10from jaxquantum.core.operators import identity, destroy, create
# Enable 64-bit floating-point (and 128-bit complex) arrays in JAX. This runs
# at import time and changes JAX's global default precision for the whole
# process, not just for this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class TruncatedTransmon(FluxDevice):
    """
    Transmon device modelled in a truncated harmonic-oscillator basis.

    Operators are built from ``N_pre_diag`` levels of the ladder operators
    ``a`` / ``a_dag`` (the "linear basis"). The linear Hamiltonian is the
    harmonic term ``ω a_dag a`` with ``ω = sqrt(8 Ec Ej)``. The full
    Hamiltonian adds the exact cosine nonlinearity and subtracts the
    quadratic part of it that the harmonic term already contains.

    Reads ``self.params["Ec"]`` and ``self.params["Ej"]``. These are presumably
    the charging and Josephson energies, positive and in the same units;
    TODO confirm units against the base class. It also reads
    ``self.N_pre_diag`` and ``self.linear_ops``, which the base class
    presumably fills from :meth:`common_ops`.
    """

    def common_ops(self):
        """Build the basic operators in the linear (oscillator) basis.

        Returns:
            dict with keys ``"id"``, ``"a"``, ``"a_dag"``, ``"phi"`` and
            ``"n"``, each an ``N_pre_diag``-level operator, where
            ``phi = phi_zpf * (a + a_dag)`` and
            ``n = 1j * n_zpf * (a_dag - a)``.
        """
        N = self.N_pre_diag
        a = destroy(N)
        a_dag = create(N)
        return {
            "id": identity(N),
            "a": a,
            "a_dag": a_dag,
            "phi": self.phi_zpf() * (a + a_dag),
            "n": 1j * self.n_zpf() * (a_dag - a),
        }

    def phi_zpf(self):
        """Return the phase zero-point fluctuation ``(2 Ec / Ej) ** (1/4)``."""
        return (2 * self.params["Ec"] / self.params["Ej"]) ** (0.25)

    def n_zpf(self):
        """Return the charge zero-point fluctuation ``(Ej / (32 Ec)) ** (1/4)``."""
        return (self.params["Ej"] / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Return the harmonic frequency ``sqrt(8 Ec Ej)`` of the linear terms."""
        return jnp.sqrt(8 * self.params["Ec"] * self.params["Ej"])

    def get_H_linear(self):
        """Return the harmonic part of H, ``ω a_dag a``.

        The constant zero-point offset ``ω / 2`` is not included.
        """
        w = self.get_linear_ω()
        # Evaluates as (w * a_dag) @ a: `*` and `@` share precedence.
        return w * self.linear_ops["a_dag"] @ self.linear_ops["a"]

    def get_H_full(self):
        """Return the full Hamiltonian in the linear basis.

        ``H = ω a_dag a - Ej cos(phi) - (Ej / 2) phi^2``. The cosine is taken
        exactly as a matrix function. Its quadratic part is already in the
        harmonic term, so it is subtracted to avoid double-counting.
        """
        phi = self.linear_ops["phi"]
        Ej = self.params["Ej"]

        # cos(phi) = (e^{i phi} + e^{-i phi}) / 2. phi = phi_zpf (a + a_dag)
        # is Hermitian for positive Ec and Ej, so e^{-i phi} is the conjugate
        # transpose of e^{i phi}. One matrix exponential therefore suffices,
        # and the sum comes out exactly Hermitian.
        exp_iphi = jsp.linalg.expm(1j * phi)
        cos_phi_op = (exp_iphi + jnp.conj(jnp.swapaxes(exp_iphi, -1, -2))) / 2

        H_nl = -Ej * cos_phi_op - Ej / 2 * jnp.linalg.matrix_power(phi, 2)
        return self.get_H_linear() + H_nl

    def potential(self, phi):
        """Return the Josephson potential energy ``-Ej cos(2π phi)``.

        ``phi`` is multiplied by 2π before the cosine, so it is presumably
        expressed in units of 2π (flux quanta); TODO confirm against the
        FluxDevice convention. It is applied elementwise to array input.
        """
        return -self.params["Ej"] * jnp.cos(2 * jnp.pi * phi)