Coverage for jaxquantum/devices/superconducting/ats.py: 0%
58 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""ATS."""
3from flax import struct
4from jax import config
6import jax.numpy as jnp
8from jaxquantum.devices.superconducting.flux_base import FluxDevice
9from jaxquantum.core.operators import identity, destroy, create
10from jaxquantum.core.qarray import cosm, sinm
# Turn on 64-bit (double-precision) dtypes in JAX. This is a process-wide
# setting that takes effect as soon as this module is imported.
config.update("jax_enable_x64", True)
@struct.dataclass
class ATS(FluxDevice):
    """
    ATS Device.

    The Hamiltonian assembled by ``get_H_full`` (in the linear basis) is

        H = ω (a† a + 1/2)
            - 2 Ej  cos(φ + 2π φ_Δ)  cos(2π φ_Σ)
            + 2 dEj sin(φ + 2π φ_Δ)  sin(2π φ_Σ)
            + 2 Ej2 cos(2φ + 4π φ_Δ) cos(4π φ_Σ)

    with ω = sqrt(8 El Ec), φ the phase operator, and φ_Σ / φ_Δ given by
    ``params["phi_sum_ext"]`` / ``params["phi_delta_ext"]``.
    """

    def common_ops(self):
        """Build the basic operators, written in the linear basis.

        Returns:
            dict with keys "id", "a", "a_dag", "phi" and "n", each an operator
            of dimension ``self.N_pre_diag``. "phi" and "n" are the phase and
            charge operators expressed through the ladder operators.
        """
        ops = {}

        N = self.N_pre_diag
        ops["id"] = identity(N)
        ops["a"] = destroy(N)
        ops["a_dag"] = create(N)
        ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
        ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])
        return ops

    def phi_zpf(self):
        """Return Phase ZPF, (2 Ec / El) ** (1/4)."""
        return (2 * self.params["Ec"] / self.params["El"]) ** (0.25)

    def n_zpf(self):
        """Return Charge ZPF, (El / (32 Ec)) ** (1/4)."""
        return (self.params["El"] / (32 * self.params["Ec"])) ** (0.25)

    def get_linear_ω(self):
        """Get frequency of linear terms, sqrt(8 El Ec)."""
        return jnp.sqrt(8 * self.params["El"] * self.params["Ec"])

    def get_H_linear(self):
        """Return linear terms in H: ω (a† a + 1/2) with ω = get_linear_ω()."""
        w = self.get_linear_ω()
        return w * (
            self.linear_ops["a_dag"] @ self.linear_ops["a"]
            + 0.5 * self.linear_ops["id"]
        )

    @staticmethod
    def get_H_nonlinear_static(phi_op, Ej, dEj, Ej2, phi_sum, phi_delta):
        """Return the nonlinear (cosine/sine) part of the ATS Hamiltonian.

        Computes

            - 2 Ej  cos(phi_op + 2π phi_delta)   cos(2π phi_sum)
            + 2 dEj sin(phi_op + 2π phi_delta)   sin(2π phi_sum)
            + 2 Ej2 cos(2 phi_op + 4π phi_delta) cos(4π phi_sum)

        The flux-shifted cosines/sines are expanded with the angle-addition
        identities. As a result, only ``cosm(phi_op)`` and ``sinm(phi_op)``
        are evaluated as matrix functions, and all flux dependence enters
        through scalar factors.

        Args:
            phi_op: Phase operator.
            Ej: Prefactor of the cos(phi_op + ...) term.
            dEj: Prefactor of the sin(phi_op + ...) term.
            Ej2: Prefactor of the cos(2 phi_op + ...) term.
            phi_sum: Sum external flux; multiplied by 2π to form a phase.
            phi_delta: Difference external flux; multiplied by 2π to form a
                phase.

        Returns:
            The nonlinear Hamiltonian, built from ``cosm``/``sinm`` of
            ``phi_op``.
        """
        cos_phi = cosm(phi_op)
        sin_phi = sinm(phi_op)

        # Double-angle operators from products. This is valid because
        # cos(phi_op) and sin(phi_op) are functions of the same operator and
        # therefore commute.
        cos_2phi = cos_phi @ cos_phi - sin_phi @ sin_phi
        sin_2phi = 2 * cos_phi @ sin_phi

        # External fluxes converted to phases (hoisted; reused below).
        theta_delta = 2 * jnp.pi * phi_delta
        theta_sum = 2 * jnp.pi * phi_sum

        # cos(phi + θ_Δ), sin(phi + θ_Δ) and cos(2 phi + 2 θ_Δ).
        cos_shifted = (
            cos_phi * jnp.cos(theta_delta) - sin_phi * jnp.sin(theta_delta)
        )
        sin_shifted = (
            sin_phi * jnp.cos(theta_delta) + cos_phi * jnp.sin(theta_delta)
        )
        cos_2shifted = (
            cos_2phi * jnp.cos(2 * theta_delta)
            - sin_2phi * jnp.sin(2 * theta_delta)
        )

        H_nl_Ej = -2 * Ej * cos_shifted * jnp.cos(theta_sum)
        H_nl_dEj = 2 * dEj * sin_shifted * jnp.sin(theta_sum)
        H_nl_Ej2 = 2 * Ej2 * cos_2shifted * jnp.cos(2 * theta_sum)

        return H_nl_Ej + H_nl_dEj + H_nl_Ej2

    def get_H_nonlinear(self, phi_op):
        """Return nonlinear terms in H for ``phi_op``.

        Reads Ej, dEj, Ej2, phi_sum_ext and phi_delta_ext from
        ``self.params`` and delegates to ``get_H_nonlinear_static``.
        """
        Ej = self.params["Ej"]
        dEj = self.params["dEj"]
        Ej2 = self.params["Ej2"]

        phi_sum = self.params["phi_sum_ext"]
        phi_delta = self.params["phi_delta_ext"]

        return ATS.get_H_nonlinear_static(phi_op, Ej, dEj, Ej2, phi_sum, phi_delta)

    def get_H_full(self):
        """Return full H (linear + nonlinear terms) in linear basis."""
        phi_b = self.linear_ops["phi"]
        H_nl = self.get_H_nonlinear(phi_b)
        H = self.get_H_linear() + H_nl
        return H

    def potential(self, phi):
        """Return potential energy for a given phi.

        ``phi`` and the external fluxes are multiplied by 2π to form phases,
        so this is the classical counterpart of the inductive term plus
        ``get_H_nonlinear_static`` with ``phi_op`` replaced by ``2π phi``.
        Uses ``jnp`` throughout, so ``phi`` may be an array (elementwise).
        """
        # Flux-shifted phase and sum phase, hoisted since each is used twice.
        theta = 2 * jnp.pi * (phi + self.params["phi_delta_ext"])
        theta_sum = 2 * jnp.pi * self.params["phi_sum_ext"]

        V = 0.5 * self.params["El"] * (2 * jnp.pi * phi) ** 2
        V += -2 * self.params["Ej"] * jnp.cos(theta) * jnp.cos(theta_sum)
        V += 2 * self.params["dEj"] * jnp.sin(theta) * jnp.sin(theta_sum)
        V += (
            2
            * self.params["Ej2"]
            * jnp.cos(2 * theta)
            * jnp.cos(2 * theta_sum)
        )

        return V