Coverage for jaxquantum/devices/superconducting/flux_base.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Flux base device.""" 

2 

3from abc import abstractmethod 

4 

5from flax import struct 

6from jax import config 

7import jax.numpy as jnp 

8import matplotlib.pyplot as plt 

9 

10from jaxquantum.devices.common.utils import harm_osc_wavefunction 

11from jaxquantum.devices.base.base import Device, BasisTypes 

12 

# Module-level side effect: switch JAX to 64-bit (double-precision) numerics.
# This runs at import time and updates JAX's shared `config` object, so it also
# affects any other code using JAX in the same process.
config.update("jax_enable_x64", True)

14 

15 

@struct.dataclass
class FluxDevice(Device):
    """Base device whose states are described as wavefunctions of a phase variable.

    Subclasses must implement ``phi_zpf`` (phase zero-point fluctuation) and
    ``potential(phi)``. This class adds helpers that evaluate the eigenstate
    wavefunctions on a grid of phase values (Fock or charge basis) and plot
    them, offset by their energies, on top of the potential.
    """

    @abstractmethod
    def phi_zpf(self):
        """Return Phase ZPF."""

    def _calculate_wavefunctions_fock(self, phi_vals):
        """Evaluate eigenstate wavefunctions at ``phi_vals`` using the Fock basis.

        One harmonic-oscillator basis function is built for each
        ``n < self.N_pre_diag``. These are then expressed in the Hamiltonian
        eigenbasis via ``self.get_vec_data_in_H_eigenbasis``.

        Args:
            phi_vals: Phase values at which to evaluate the wavefunctions.

        Returns:
            Array of wavefunction samples. ``plot_wavefunctions`` reads row
            ``n`` as the ``n``-th eigenstate sampled over ``phi_vals``.
        """
        # Length scale of the oscillator basis functions.
        phi_osc = self.phi_zpf() * jnp.sqrt(2)
        phi_vals = jnp.array(phi_vals)

        basis_functions = jnp.array(
            [
                harm_osc_wavefunction(n, phi_vals, jnp.real(phi_osc))
                for n in range(self.N_pre_diag)
            ]
        )

        # Transform to the Hamiltonian eigenbasis. This is equivalent to
        # evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis, because the
        # eigenvectors are the identity matrix in their own eigenbasis.
        return self.get_vec_data_in_H_eigenbasis(basis_functions)

    def _calculate_wavefunctions_charge(self, phi_vals):
        """Evaluate eigenstate wavefunctions at ``phi_vals`` using the charge basis.

        One plane wave ``exp(-2*pi*i*n*phi) / sqrt(2*pi)`` is built for each
        diagonal entry ``n`` of ``self.original_ops["n"].data``. The plane
        waves are transformed to the Hamiltonian eigenbasis, and row ``k`` is
        then multiplied by ``1j**k``.

        Args:
            phi_vals: Phase values at which to evaluate the wavefunctions.

        Returns:
            Array of wavefunction samples. ``plot_wavefunctions`` reads row
            ``n`` as the ``n``-th eigenstate sampled over ``phi_vals``.
        """
        phi_vals = jnp.array(phi_vals)

        n_labels = jnp.diag(self.original_ops["n"].data)
        basis_functions = jnp.array(
            [
                # The extra -1 in the exponent was added to work with the SNAIL.
                1 / (jnp.sqrt(2 * jnp.pi)) * jnp.exp(1j * n * (2 * jnp.pi * -1 * phi_vals))
                for n in n_labels
            ]
        )

        # Transform to the Hamiltonian eigenbasis (the eigenvectors are the
        # identity matrix there, so no further product is needed).
        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        num_eigenstates = basis_functions_in_H_eigenbasis.shape[0]
        # TODO: review why these per-state phase factors (1j**k) are needed.
        phase_correction_factors = (1j ** (jnp.arange(0, num_eigenstates))).reshape(
            num_eigenstates, 1
        )
        return basis_functions_in_H_eigenbasis * phase_correction_factors

    @abstractmethod
    def potential(self, phi):
        """Return potential energy as a function of phi."""

    def _wavefunction_calculator(self):
        """Return the wavefunction routine matching ``self.basis``.

        Raises:
            NotImplementedError: If ``self.basis`` is neither Fock nor charge.
        """
        if self.basis == BasisTypes.fock:
            return self._calculate_wavefunctions_fock
        if self.basis == BasisTypes.charge:
            return self._calculate_wavefunctions_charge
        raise NotImplementedError(
            f"The {self.basis} is not yet supported for plotting wavefunctions."
        )

    @staticmethod
    def _wavefunction_component(wavefunction, mode):
        """Return |psi|^2, Re(psi) or Im(psi) of ``wavefunction`` for ``mode``."""
        if mode == "abs":
            return jnp.abs(wavefunction) ** 2
        if mode == "real":
            return wavefunction.real
        if mode == "imag":
            return wavefunction.imag
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of 'abs', 'real', 'imag'."
        )

    def plot_wavefunctions(
        self,
        phi_vals,
        max_n=None,
        which=None,
        ax=None,
        mode="abs",
        ylim=None,
        y_scale_factor=1,
        zero_potential=False,
    ):
        """Plot eigenstate wavefunctions at ``phi_vals`` over the potential.

        Each plotted wavefunction is offset vertically by its eigenenergy
        (``self.eig_systems["vals"]``). The region between that energy and the
        curve is shaded, and the potential is drawn in black.

        Args:
            phi_vals: Phase grid on which wavefunctions and potential are evaluated.
            max_n: Plot levels ``0 .. max_n - 1``. Defaults to ``self.N``.
                Cannot be combined with ``which``.
            which: Explicit iterable of level indices to plot.
            ax: Matplotlib axes to draw into. If None, a new figure is created.
            mode: ``"abs"`` plots |psi|^2, ``"real"`` plots Re(psi) and
                ``"imag"`` plots Im(psi).
            ylim: y-axis limits. If None, they are derived from the plotted
                curves and the potential.
            y_scale_factor: Factor applied to every plotted y value.
            zero_potential: If True, shift all curves so the potential minimum
                sits at zero.

        Returns:
            The axes that were drawn into.

        Raises:
            NotImplementedError: If ``self.basis`` is not supported.
            ValueError: If ``mode`` is not one of ``"abs"``, ``"real"``, ``"imag"``.
            AssertionError: If both ``max_n`` and ``which`` are given.
        """
        calculate_wavefunctions = self._wavefunction_calculator()

        titles = {
            "abs": r"$|\psi_n(\Phi)|^2$",
            "real": r"Re($\psi_n(\Phi)$)",
            "imag": r"Im($\psi_n(\Phi)$)",
        }
        if mode not in titles:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of 'abs', 'real', 'imag'."
            )
        # Validate before doing any work so a bad call does not leave a stray figure.
        assert max_n is None or which is None, "Can't specify both max_n and which"

        wavefunctions = calculate_wavefunctions(phi_vals)
        energy_levels = self.eig_systems["vals"][: self.N]
        potential = self.potential(phi_vals)
        min_potential = jnp.min(potential) if zero_potential else 0

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=1000)
        else:
            fig = ax.get_figure()

        max_n = self.N if max_n is None else max_n
        levels = range(max_n) if which is None else which

        min_val = None
        max_val = None
        for n in levels:
            # Non-mutating add: never alter the wavefunction data in place.
            wf_vals = self._wavefunction_component(wavefunctions[n, :], mode) + energy_levels[n]

            curr_min_val = jnp.min(wf_vals)
            curr_max_val = jnp.max(wf_vals)
            min_val = curr_min_val if min_val is None else min(min_val, curr_min_val)
            max_val = curr_max_val if max_val is None else max(max_val, curr_max_val)

            ax.plot(
                phi_vals,
                (wf_vals - min_potential) * y_scale_factor,
                label=f"$|${n}$\\rangle$",
                linestyle="-",
                linewidth=1,
            )
            ax.fill_between(
                phi_vals,
                (energy_levels[n] - min_potential) * y_scale_factor,
                (wf_vals - min_potential) * y_scale_factor,
                alpha=0.5,
            )

        ax.plot(
            phi_vals,
            (potential - min_potential) * y_scale_factor,
            label="potential",
            color="black",
            linestyle="-",
            linewidth=1,
        )

        if ylim is None:
            potential_min = jnp.min(potential)
            if min_val is None:
                # No levels were plotted: frame the potential alone.
                lower = potential_min - min_potential
                upper = jnp.max(potential) - min_potential
            else:
                # Leave one unit of headroom around the plotted levels and
                # keep the bottom of the potential in view.
                lower = jnp.minimum(min_val - 1, potential_min) - min_potential
                upper = max_val + 1 - min_potential
            ylim = [lower * y_scale_factor, upper * y_scale_factor]
        ax.set_ylim(ylim)
        ax.set_xlabel(r"$\varphi/2\pi$")
        ax.set_ylabel(r"Energy [GHz]")
        ax.set_title(titles[mode])

        ax.legend(fontsize='xx-small')
        fig.tight_layout()

        return ax