Coverage for jaxquantum/devices/superconducting/flux_base.py: 0%
86 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Flux base device."""
3from abc import abstractmethod
5from flax import struct
6from jax import config
7import jax.numpy as jnp
8import matplotlib.pyplot as plt
10from jaxquantum.devices.common.utils import harm_osc_wavefunction
11from jaxquantum.devices.base.base import Device, BasisTypes
# Enable 64-bit (double-precision) dtypes in JAX. `config` is JAX's global
# configuration object, so this applies process-wide on import of this module,
# not only to the code below.
config.update("jax_enable_x64", True)
@struct.dataclass
class FluxDevice(Device):
    """Device whose states are described on a phase (flux) coordinate.

    Subclasses provide the phase zero-point fluctuation (:meth:`phi_zpf`) and
    the potential (:meth:`potential`). This base class adds helpers that
    evaluate eigen-wavefunctions on a grid of phase values, in either the Fock
    or the charge basis, and plot them on top of the potential.
    """

    @abstractmethod
    def phi_zpf(self):
        """Return the phase zero-point fluctuation (phase ZPF)."""

    def _calculate_wavefunctions_fock(self, phi_vals):
        """Evaluate eigen-wavefunctions on ``phi_vals`` using the Fock basis.

        Each of the ``self.N_pre_diag`` harmonic-oscillator basis functions is
        evaluated on the grid, and the stack is transformed into the
        Hamiltonian eigenbasis via ``self.get_vec_data_in_H_eigenbasis``.

        Args:
            phi_vals: Phase values at which to evaluate the wavefunctions.

        Returns:
            The transformed basis functions. :meth:`plot_wavefunctions` reads
            eigenstate ``n`` from this array as ``[n, :]``.
        """
        phi_vals = jnp.array(phi_vals)
        # Oscillator length scale, phi_osc = sqrt(2) * phi_zpf. Only its real
        # part is passed to the basis-function helper; it is loop-invariant,
        # so compute it once.
        phi_osc = jnp.real(self.phi_zpf() * jnp.sqrt(2))

        basis_functions = jnp.array(
            [
                harm_osc_wavefunction(n, phi_vals, phi_osc)
                for n in range(self.N_pre_diag)
            ]
        )

        # evecs_in_H_eigenbasis is the identity matrix, so
        # evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis is simply the
        # transformed basis functions themselves.
        return self.get_vec_data_in_H_eigenbasis(basis_functions)

    def _calculate_wavefunctions_charge(self, phi_vals):
        """Evaluate eigen-wavefunctions on ``phi_vals`` using the charge basis.

        Each charge state ``n`` (the diagonal entries of
        ``self.original_ops["n"].data``) is represented by the plane wave
        ``exp(-2*pi*i*n*phi) / sqrt(2*pi)``. The stack is transformed into the
        Hamiltonian eigenbasis, and entry ``k`` along axis 0 is multiplied by
        the phase factor ``1j**k``.

        Args:
            phi_vals: Phase values at which to evaluate the wavefunctions.

        Returns:
            The phase-corrected transformed basis functions.
            :meth:`plot_wavefunctions` reads eigenstate ``n`` as ``[n, :]``.
        """
        phi_vals = jnp.array(phi_vals)
        n_labels = jnp.diag(self.original_ops["n"].data)
        normalization = 1 / (jnp.sqrt(2 * jnp.pi))

        # The -1 in the exponent flips the sign convention; per the original
        # author it was added to work with the SNAIL.
        basis_functions = jnp.array(
            [
                normalization * jnp.exp(1j * n * (2 * jnp.pi * -1 * phi_vals))
                for n in n_labels
            ]
        )

        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        # As in the Fock case, evecs_in_H_eigenbasis is the identity, so no
        # further matrix product is needed. Each eigenstate k is additionally
        # multiplied by 1j**k (reshaped to broadcast along axis 0).
        num_eigenstates = basis_functions_in_H_eigenbasis.shape[0]
        phase_correction_factors = (1j ** (jnp.arange(0, num_eigenstates))).reshape(
            num_eigenstates, 1
        )  # TODO: review why these are needed...
        return basis_functions_in_H_eigenbasis * phase_correction_factors

    @abstractmethod
    def potential(self, phi):
        """Return potential energy as a function of phi.

        :meth:`plot_wavefunctions` calls this with the same grid it receives
        as ``phi_vals`` and draws the result on the same energy axis as the
        eigenvalues.
        """

    def plot_wavefunctions(
        self,
        phi_vals,
        max_n=None,
        which=None,
        ax=None,
        mode="abs",
        ylim=None,
        y_scale_factor=1,
        zero_potential=False,
    ):
        """Plot eigen-wavefunctions, each offset by its energy, over the potential.

        Args:
            phi_vals: Phase grid (x axis) on which the wavefunctions and the
                potential are evaluated.
            max_n: Plot levels ``0 .. max_n - 1``. Defaults to ``self.N``.
                Mutually exclusive with ``which``.
            which: Explicit iterable of level indices to plot.
            ax: Matplotlib axes to draw on; a new figure is created if None.
            mode: ``"abs"`` plots ``|psi|**2``; ``"real"`` / ``"imag"`` plot
                the real / imaginary part of the wavefunction.
            ylim: Explicit y-limits. Computed from the plotted curves if None.
            y_scale_factor: Multiplier applied to every plotted y value.
            zero_potential: If True, shift every curve so that the minimum of
                the potential sits at zero.

        Returns:
            The matplotlib axes containing the plot.

        Raises:
            NotImplementedError: If ``self.basis`` is neither Fock nor charge.
            ValueError: If both ``max_n`` and ``which`` are given, or if
                ``mode`` is not ``"abs"``, ``"real"`` or ``"imag"``.
        """
        if self.basis == BasisTypes.fock:
            calculate_wavefunctions = self._calculate_wavefunctions_fock
        elif self.basis == BasisTypes.charge:
            calculate_wavefunctions = self._calculate_wavefunctions_charge
        else:
            raise NotImplementedError(
                f"The {self.basis} is not yet supported for plotting wavefunctions."
            )

        # mode -> (quantity derived from one wavefunction row, axes title).
        mode_handlers = {
            "abs": (lambda psi: jnp.abs(psi) ** 2, r"$|\psi_n(\Phi)|^2$"),
            "real": (lambda psi: psi.real, r"Re($\psi_n(\Phi)$)"),
            "imag": (lambda psi: psi.imag, r"Im($\psi_n(\Phi)$)"),
        }
        # Validate arguments before doing any expensive work. An unknown mode
        # previously surfaced as an UnboundLocalError.
        if mode not in mode_handlers:
            raise ValueError(
                f"mode must be one of {list(mode_handlers)}, got {mode!r}"
            )
        if max_n is not None and which is not None:
            raise ValueError("Can't specify both max_n and which")
        transform, title_str = mode_handlers[mode]

        wavefunctions = calculate_wavefunctions(phi_vals)
        energy_levels = self.eig_systems["vals"][: self.N]
        potential = self.potential(phi_vals)

        # Reference energy subtracted from every plotted curve.
        min_potential = jnp.min(potential) if zero_potential else 0

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=1000)
        else:
            fig = ax.get_figure()

        max_n = self.N if max_n is None else max_n
        levels = range(max_n) if which is None else which

        min_val = None
        max_val = None
        for n in levels:
            # Offset each curve by its eigenenergy so it sits at its level.
            wf_vals = transform(wavefunctions[n, :]) + energy_levels[n]
            curr_min_val = jnp.min(wf_vals)
            curr_max_val = jnp.max(wf_vals)
            if min_val is None or curr_min_val < min_val:
                min_val = curr_min_val
            if max_val is None or curr_max_val > max_val:
                max_val = curr_max_val

            ax.plot(
                phi_vals,
                (wf_vals - min_potential) * y_scale_factor,
                label=f"$|${n}$\\rangle$",
                linestyle="-",
                linewidth=1,
            )
            ax.fill_between(
                phi_vals,
                (energy_levels[n] - min_potential) * y_scale_factor,
                (wf_vals - min_potential) * y_scale_factor,
                alpha=0.5,
            )

        ax.plot(
            phi_vals,
            (potential - min_potential) * y_scale_factor,
            label="potential",
            color="black",
            linestyle="-",
            linewidth=1,
        )

        if ylim is None:
            if min_val is None:
                # No levels were plotted (e.g. ``which=[]``). Size the axis
                # from the potential alone instead of failing on None.
                min_val = jnp.min(potential)
                max_val = jnp.max(potential)
            lower = jnp.min(
                jnp.array(
                    [min_val - 1 - min_potential, jnp.min(potential) - min_potential]
                )
            )
            ylim = [
                lower * y_scale_factor,
                (max_val + 1 - min_potential) * y_scale_factor,
            ]
        ax.set_ylim(ylim)
        ax.set_xlabel(r"$\varphi/2\pi$")
        ax.set_ylabel(r"Energy [GHz]")
        ax.set_title(f"{title_str}")

        ax.legend(fontsize='xx-small')
        fig.tight_layout()

        return ax