Coverage for jaxquantum/devices/superconducting/flux_base.py: 0%

83 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Flux base device.""" 

2 

3from abc import abstractmethod 

4 

5from flax import struct 

6from jax import config 

7import jax.numpy as jnp 

8import matplotlib.pyplot as plt 

9 

10from jaxquantum.devices.common.utils import harm_osc_wavefunction 

11from jaxquantum.devices.base.base import Device, BasisTypes 

12 

# Turn on 64-bit (double-precision) floating/complex types in JAX. This is a
# process-wide JAX setting applied as a side effect of importing this module.
config.update("jax_enable_x64", True)

14 

15 

@struct.dataclass
class FluxDevice(Device):
    """Device whose states are described by wavefunctions of the phase variable.

    Subclasses must implement :meth:`phi_zpf` and :meth:`potential`. This base
    class adds helpers that evaluate eigen-wavefunctions in the Fock or charge
    basis and plot them on top of the potential.
    """

    @abstractmethod
    def phi_zpf(self):
        """Return Phase ZPF."""

    def _calculate_wavefunctions_fock(self, phi_vals):
        """Calculate wavefunctions at ``phi_vals`` using the oscillator (Fock) basis.

        The first ``self.N_pre_diag`` harmonic-oscillator basis functions are
        sampled on ``phi_vals``. They are then expressed in the Hamiltonian
        eigenbasis via ``self.get_vec_data_in_H_eigenbasis``.

        Args:
            phi_vals: Phase points at which to evaluate the wavefunctions.

        Returns:
            Array indexed as ``[level, point]``, which is how
            ``plot_wavefunctions`` consumes it.
        """
        phi_osc = self.phi_zpf() * jnp.sqrt(2)  # length of oscillator
        osc_length = jnp.real(phi_osc)
        phi_vals = jnp.array(phi_vals)

        # Row n holds oscillator basis function |n> sampled on phi_vals.
        basis_functions = jnp.array(
            [
                harm_osc_wavefunction(n, phi_vals, osc_length)
                for n in range(self.N_pre_diag)
            ]
        )

        # In the H eigenbasis the eigenvectors are the identity matrix, so
        # evecs_in_H_eigenbasis @ basis_functions_in_H_eigenbasis reduces to
        # the transformed basis functions themselves.
        return self.get_vec_data_in_H_eigenbasis(basis_functions)

    def _calculate_wavefunctions_charge(self, phi_vals):
        """Calculate wavefunctions at ``phi_vals`` using the charge basis.

        The basis function for charge number ``n`` is
        ``exp(1j * n * 2*pi*phi) / sqrt(2*pi)``, for ``n`` in
        ``[-n_max, n_max]`` with ``n_max = (N_pre_diag - 1) // 2``.

        Args:
            phi_vals: Phase points at which to evaluate the wavefunctions.

        Returns:
            Array indexed as ``[level, point]`` in the Hamiltonian eigenbasis,
            multiplied row-wise by the phase factors ``1j**level``.
        """
        phi_vals = jnp.array(phi_vals)

        n_max = (self.N_pre_diag - 1) // 2
        # NOTE(review): this yields 2*n_max + 1 charge states, which equals
        # N_pre_diag only when N_pre_diag is odd - confirm callers ensure that.
        charges = jnp.arange(-n_max, n_max + 1)
        # One trailing singleton axis per axis of phi_vals, so the result has
        # shape (n_charges, *phi_vals.shape).
        charges = charges.reshape(charges.shape + (1,) * phi_vals.ndim)
        basis_functions = (
            1 / (jnp.sqrt(2 * jnp.pi)) * jnp.exp(1j * charges * (2 * jnp.pi * phi_vals))
        )

        # Express the basis functions in the Hamiltonian eigenbasis (the
        # eigenvectors are the identity there).
        basis_functions_in_H_eigenbasis = self.get_vec_data_in_H_eigenbasis(
            basis_functions
        )

        # TODO: review why these are needed...
        phase_correction_factors = (1j ** (jnp.arange(0, self.N_pre_diag))).reshape(
            self.N_pre_diag, 1
        )
        return basis_functions_in_H_eigenbasis * phase_correction_factors

    @abstractmethod
    def potential(self, phi):
        """Return potential energy as a function of phi."""

    def plot_wavefunctions(self, phi_vals, max_n=None, which=None, ax=None, mode="abs"):
        """Plot wavefunctions at ``phi_vals``, each offset by its energy, over the potential.

        Args:
            phi_vals: Phase points used as the x-axis (labelled Phi/Phi_0).
            max_n: Plot levels ``0 .. max_n - 1``. Defaults to ``self.N``.
                Mutually exclusive with ``which``.
            which: Explicit iterable of level indices to plot.
            ax: Matplotlib axes to draw on. A new figure is created when ``None``.
            mode: ``"abs"`` plots ``|psi|^2``. ``"real"`` and ``"imag"`` plot
                the real and imaginary parts.

        Returns:
            The matplotlib axes that was drawn on.

        Raises:
            ValueError: If both ``max_n`` and ``which`` are given, or ``mode``
                is not one of ``"abs"``, ``"real"``, ``"imag"``.
            NotImplementedError: If ``self.basis`` is neither Fock nor charge.
        """
        # mode -> (transform applied to a wavefunction row, plot title).
        mode_options = {
            "abs": (lambda wf: jnp.abs(wf) ** 2, r"$|\psi_n(\Phi)|^2$"),
            "real": (lambda wf: wf.real, r"Re($\psi_n(\Phi)$)"),
            "imag": (lambda wf: wf.imag, r"Im($\psi_n(\Phi)$)"),
        }
        if mode not in mode_options:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {sorted(mode_options)}."
            )
        transform, title_str = mode_options[mode]

        # Validate before creating any figure so errors don't leave one behind.
        if max_n is not None and which is not None:
            raise ValueError("Can't specify both max_n and which")

        if self.basis == BasisTypes.fock:
            _calculate_wavefunctions = self._calculate_wavefunctions_fock
        elif self.basis == BasisTypes.charge:
            _calculate_wavefunctions = self._calculate_wavefunctions_charge
        else:
            raise NotImplementedError(
                f"The {self.basis} is not yet supported for plotting wavefunctions."
            )

        wavefunctions = _calculate_wavefunctions(phi_vals)
        # NOTE(review): only the first self.N energies are kept, and JAX clamps
        # out-of-bounds indices instead of raising, so a level >= self.N would
        # silently reuse the last energy - confirm callers stay below self.N.
        energy_levels = self.eig_systems["vals"][: self.N]
        potential = self.potential(phi_vals)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5), dpi=1000)
        else:
            fig = ax.get_figure()

        max_n = self.N if max_n is None else max_n
        levels = range(max_n) if which is None else which

        # Running y-range over all plotted curves, used for the axis limits.
        min_val = None
        max_val = None
        for n in levels:
            # Each curve is drawn relative to its own energy level.
            wf_vals = transform(wavefunctions[n, :]) + energy_levels[n]
            curr_min_val = jnp.min(wf_vals)
            curr_max_val = jnp.max(wf_vals)

            if min_val is None or curr_min_val < min_val:
                min_val = curr_min_val
            if max_val is None or curr_max_val > max_val:
                max_val = curr_max_val

            ax.plot(
                phi_vals, wf_vals, label=f"$|${n}$\\rangle$", linestyle="-", linewidth=1
            )
            ax.fill_between(phi_vals, energy_levels[n], wf_vals, alpha=0.5)

        ax.plot(
            phi_vals,
            potential,
            label="potential",
            color="black",
            linestyle="-",
            linewidth=1,
        )

        # With no levels plotted there is no wavefunction range; keep the
        # autoscaled limits instead of failing on None arithmetic.
        if min_val is not None:
            ax.set_ylim([min_val - 1, max_val + 1])

        ax.set_xlabel(r"$\Phi/\Phi_0$")
        ax.set_ylabel(r"Energy [GHz]")
        ax.set_title(f"{title_str}")

        # Draw the legend on the target axes, not pyplot's "current" axes,
        # which may differ when the caller passes ``ax``.
        ax.legend(fontsize=6)
        fig.tight_layout()

        return ax