Coverage for jaxquantum/devices/superconducting/flux_base.py: 0%
83 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Flux base device."""
3from abc import abstractmethod
5from flax import struct
6from jax import config
7import jax.numpy as jnp
8import matplotlib.pyplot as plt
10from jaxquantum.devices.common.utils import harm_osc_wavefunction
11from jaxquantum.devices.base.base import Device, BasisTypes
# Enable 64-bit (double-precision) JAX arrays. This runs at import time and is a
# process-wide JAX setting, so it affects all JAX computations, not only this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class FluxDevice(Device):
    """Base class for flux-type devices.

    Subclasses provide the phase zero-point fluctuation (``phi_zpf``) and the
    potential energy (``potential``). This base class supplies helpers that
    evaluate the device's wavefunctions on a grid of phase values and plot
    them, offset by their energies, on top of the potential.
    """

    @abstractmethod
    def phi_zpf(self):
        """Return Phase ZPF."""

    def _calculate_wavefunctions_fock(self, phi_vals):
        """Evaluate wavefunctions on ``phi_vals`` using the Fock (oscillator) basis.

        The first ``self.N_pre_diag`` harmonic-oscillator basis functions are
        evaluated on the grid, stacked one per row, and re-expressed via
        ``self.get_vec_data_in_H_eigenbasis``.

        Args:
            phi_vals: Grid of phase values (anything ``jnp.array`` accepts).

        Returns:
            The array returned by ``get_vec_data_in_H_eigenbasis``; row ``n``
            is used as the ``n``-th wavefunction by ``plot_wavefunctions``.
        """
        # Oscillator length scale: sqrt(2) * phase ZPF. Only its real part is
        # passed to the oscillator wavefunction; hoisted out of the loop.
        phi_osc = jnp.real(self.phi_zpf() * jnp.sqrt(2))
        phi_vals = jnp.array(phi_vals)

        # Row n holds the n-th oscillator basis function sampled on the grid.
        basis_functions = jnp.array(
            [
                harm_osc_wavefunction(n, phi_vals, phi_osc)
                for n in range(self.N_pre_diag)
            ]
        )

        # Transforming into the Hamiltonian eigenbasis already yields the
        # wavefunctions: this is equivalent to
        # evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis, and the
        # eigenvectors written in their own eigenbasis are the identity matrix.
        return self.get_vec_data_in_H_eigenbasis(basis_functions)

    def _calculate_wavefunctions_charge(self, phi_vals):
        """Evaluate wavefunctions on ``phi_vals`` using the charge basis.

        The basis functions are plane waves
        ``exp(1j * n * 2*pi*phi) / sqrt(2*pi)`` for integer ``n`` in
        ``[-n_max, n_max]`` with ``n_max = (N_pre_diag - 1) // 2``. They are
        re-expressed via ``self.get_vec_data_in_H_eigenbasis`` and then
        multiplied row-wise by ``1j ** k``.

        Args:
            phi_vals: Grid of phase values (anything ``jnp.array`` accepts).

        Returns:
            The phase-corrected array; row ``n`` is used as the ``n``-th
            wavefunction by ``plot_wavefunctions``.
        """
        phi_vals = jnp.array(phi_vals)

        # NOTE(review): for even N_pre_diag this yields N_pre_diag - 1 charge
        # states; confirm get_vec_data_in_H_eigenbasis expects that.
        n_max = (self.N_pre_diag - 1) // 2
        # Charges as a column so they broadcast against every grid point:
        # result has shape (num_charges, *phi_vals.shape), one basis function
        # per row, exactly as stacking them one at a time would.
        charges = jnp.arange(-n_max, n_max + 1).reshape(
            (-1,) + (1,) * phi_vals.ndim
        )
        basis_functions = (
            1 / (jnp.sqrt(2 * jnp.pi)) * jnp.exp(1j * charges * (2 * jnp.pi * phi_vals))
        )

        # Transform into the Hamiltonian eigenbasis (equivalent to
        # evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis, since the
        # eigenvectors in their own eigenbasis are the identity matrix).
        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        # TODO: review why these are needed...
        # The (N_pre_diag, 1) shape requires the transformed data to have
        # N_pre_diag rows (or a single row) to broadcast.
        phase_correction_factors = (1j ** (jnp.arange(0, self.N_pre_diag))).reshape(
            self.N_pre_diag, 1
        )
        return basis_functions_in_H_eigenbasis * phase_correction_factors

    @abstractmethod
    def potential(self, phi):
        """Return potential energy as a function of phi."""

    def plot_wavefunctions(self, phi_vals, max_n=None, which=None, ax=None, mode="abs"):
        """Plot wavefunctions at phi_vals, offset by their energies, over the potential.

        Args:
            phi_vals: Grid of Phi/Phi_0 values to evaluate and plot on.
            max_n: Plot levels ``0 .. max_n - 1`` (defaults to ``self.N``).
                Mutually exclusive with ``which``.
            which: Explicit iterable of level indices to plot.
            ax: Matplotlib axes to draw on. A new figure is created if ``None``.
            mode: ``"abs"`` plots ``|psi_n|^2``, ``"real"`` plots ``Re(psi_n)``,
                and ``"imag"`` plots ``Im(psi_n)``.

        Returns:
            The axes that were drawn on.

        Raises:
            NotImplementedError: If ``self.basis`` is neither Fock nor charge.
            ValueError: If both ``max_n`` and ``which`` are given, if ``mode``
                is unknown, or if a requested level index lies outside
                ``[0, len(energy_levels))``.
        """
        if self.basis == BasisTypes.fock:
            calculate_wavefunctions = self._calculate_wavefunctions_fock
        elif self.basis == BasisTypes.charge:
            calculate_wavefunctions = self._calculate_wavefunctions_charge
        else:
            raise NotImplementedError(
                f"The {self.basis} is not yet supported for plotting wavefunctions."
            )

        # mode -> (transform applied to a wavefunction row, plot title).
        mode_options = {
            "abs": (lambda wf: jnp.abs(wf) ** 2, r"$|\psi_n(\Phi)|^2$"),
            "real": (lambda wf: wf.real, r"Re($\psi_n(\Phi)$)"),
            "imag": (lambda wf: wf.imag, r"Im($\psi_n(\Phi)$)"),
        }
        if mode not in mode_options:
            raise ValueError(
                f"mode must be one of {sorted(mode_options)}, got {mode!r}."
            )
        transform, title_str = mode_options[mode]

        # Validate before any computation or figure creation (an assert would
        # be stripped under `python -O`).
        if max_n is not None and which is not None:
            raise ValueError("Can't specify both max_n and which")

        wavefunctions = calculate_wavefunctions(phi_vals)
        energy_levels = self.eig_systems["vals"][: self.N]
        potential = self.potential(phi_vals)

        # Materialize so `which` may be any iterable (it is traversed twice).
        levels = list(range(self.N if max_n is None else max_n) if which is None else which)
        # JAX clamps out-of-bounds indices instead of raising, so an invalid
        # level would otherwise silently plot the wrong data.
        num_levels = len(energy_levels)
        for n in levels:
            if not 0 <= n < num_levels:
                raise ValueError(
                    f"Level index {n} is out of range; only {num_levels} "
                    "energy levels are available."
                )

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=1000)
        else:
            fig = ax.get_figure()

        min_val = None
        max_val = None
        for n in levels:
            # Offset each wavefunction by its energy so it sits on its level.
            wf_vals = transform(wavefunctions[n, :]) + energy_levels[n]
            curr_min_val = float(jnp.min(wf_vals))
            curr_max_val = float(jnp.max(wf_vals))
            min_val = curr_min_val if min_val is None else min(min_val, curr_min_val)
            max_val = curr_max_val if max_val is None else max(max_val, curr_max_val)

            ax.plot(
                phi_vals, wf_vals, label=f"$|${n}$\\rangle$", linestyle="-", linewidth=1
            )
            ax.fill_between(phi_vals, energy_levels[n], wf_vals, alpha=0.5)

        ax.plot(
            phi_vals,
            potential,
            label="potential",
            color="black",
            linestyle="-",
            linewidth=1,
        )

        # Frame the plotted wavefunctions with a margin of 1 on each side;
        # with no levels plotted, leave matplotlib's autoscaled limits.
        if min_val is not None:
            ax.set_ylim([min_val - 1, max_val + 1])

        ax.set_xlabel(r"$\Phi/\Phi_0$")
        ax.set_ylabel(r"Energy [GHz]")
        ax.set_title(f"{title_str}")

        # Attach the legend to the axes we drew on, not pyplot's current axes.
        ax.legend(fontsize=6)
        fig.tight_layout()

        return ax