Coverage for jaxquantum/devices/superconducting/fluxonium.py: 0%

48 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

"""Fluxonium.

Defines the ``Fluxonium`` device: operators, Hamiltonian (linear and full)
and potential energy, expressed in the linear (harmonic-oscillator) basis.
"""

2 

3from flax import struct 

4from jax import config 

5import jaxquantum as jqt 

6import jax.numpy as jnp 

7 

8from jaxquantum.devices.superconducting.flux_base import FluxDevice 

9from jaxquantum.devices.base.base import HamiltonianTypes 

10 

# Enable 64-bit (double precision) floats in JAX. NOTE(review): this runs at
# import time and jax.config is process-wide, so importing this module changes
# the default precision for all JAX code in the process.
config.update("jax_enable_x64", True)

12 

13 

@struct.dataclass
class Fluxonium(FluxDevice):
    """
    Fluxonium Device.

    Circuit parameters are read from ``self.params``:

    - ``Ec``: charging energy.
    - ``El``: inductive energy.
    - ``Ej``: Josephson energy.
    - ``phi_ext``: external flux. It is multiplied by 2π before being used
      as a phase offset, i.e. it is expressed as a fraction of the flux
      quantum.
    """

    def common_ops(self):
        """Build the device operators, written in the linear basis.

        The linear basis is the Fock basis of ``self.N_pre_diag`` levels in
        which ``get_H_linear`` is ``ω (a†a + 1/2)``.

        Returns:
            dict mapping operator names to ``jqt`` operators: ``"id"``,
            ``"a"``, ``"a_dag"``, ``"phi"``, ``"n"``, ``"cos(φ/2)"`` and
            ``"sin(φ/2)"``.
        """
        ops = {}

        N = self.N_pre_diag
        ops["id"] = jqt.identity(N)
        ops["a"] = jqt.destroy(N)
        ops["a_dag"] = jqt.create(N)
        # phi_zpf * n_zpf == 1/2, which gives [phi, n] = i (up to truncation
        # of the Fock space).
        ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
        ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        # Half-phase operators; not referenced elsewhere in this class.
        ops["cos(φ/2)"] = jqt.cosm(ops["phi"] / 2)
        ops["sin(φ/2)"] = jqt.sinm(ops["phi"] / 2)

        return ops

    def n_zpf(self):
        """Return the charge zero-point fluctuation, (El / (32 Ec)) ** (1/4)."""
        n_zpf = (self.params["El"] / (32.0 * self.params["Ec"])) ** (0.25)
        return n_zpf

    def phi_zpf(self):
        """Return the phase zero-point fluctuation, (2 Ec / El) ** (1/4)."""
        return (2 * self.params["Ec"] / self.params["El"]) ** (0.25)

    def get_linear_ω(self):
        """Return the frequency of the linear part, sqrt(8 Ec El)."""
        return jnp.sqrt(8 * self.params["Ec"] * self.params["El"])

    def get_H_linear(self):
        """Return the linear (harmonic) part of H: ω (a†a + 1/2)."""
        w = self.get_linear_ω()
        return w * (
            self.linear_ops["a_dag"] @ self.linear_ops["a"]
            + 0.5 * self.linear_ops["id"]
        )

    def get_H_full(self):
        """Return the full H in the linear basis.

        H = ω (a†a + 1/2) - Ej cos(φ - 2π phi_ext), with ω = sqrt(8 Ec El).
        """
        phi_op = self.linear_ops["phi"]
        return self.get_H_linear() + self.get_H_nonlinear(phi_op)

    def get_H_nonlinear(self, phi_op):
        """Return the Josephson term -Ej cos(phi_op - 2π phi_ext).

        The shifted cosine is expanded with the angle-addition identity
        cos(φ - θ) = cos φ cos θ + sin φ sin θ. This is valid here because
        θ is a scalar (it commutes with φ). As a result, only the matrix
        functions cos(phi_op) and sin(phi_op) have to be evaluated.

        Args:
            phi_op: phase operator, e.g. ``self.linear_ops["phi"]``.

        Returns:
            The nonlinear part of the Hamiltonian as an operator.
        """
        op_cos_phi = jqt.cosm(phi_op)
        op_sin_phi = jqt.sinm(phi_op)

        # External flux as a phase offset (computed once, used twice).
        theta_ext = 2.0 * jnp.pi * self.params["phi_ext"]
        Hcos = op_cos_phi * jnp.cos(theta_ext) + op_sin_phi * jnp.sin(theta_ext)
        H_nl = -self.params["Ej"] * Hcos
        return H_nl

    def potential(self, phi):
        """Return the potential energy at ``phi``.

        The phase is 2π·phi (same convention as ``phi_ext``). Scalars and
        ``jnp`` arrays are accepted; the arithmetic is elementwise.

        Args:
            phi: phase in units of 2π.

        Returns:
            For ``HamiltonianTypes.linear``: El/2 (2π phi)**2.
            For ``HamiltonianTypes.full``: the linear term plus
            -Ej cos(2π (phi - phi_ext)).

        Raises:
            NotImplementedError: if ``self.hamiltonian`` is any other type.
        """
        phi_ext = self.params["phi_ext"]
        V_linear = 0.5 * self.params["El"] * (2 * jnp.pi * phi) ** 2

        if self.hamiltonian == HamiltonianTypes.linear:
            return V_linear

        if self.hamiltonian == HamiltonianTypes.full:
            V_nonlinear = -self.params["Ej"] * jnp.cos(2.0 * jnp.pi * (phi - phi_ext))
            return V_linear + V_nonlinear

        # Fail loudly rather than silently returning None for Hamiltonian
        # types this device does not define a potential for.
        raise NotImplementedError(
            f"Fluxonium.potential is not implemented for hamiltonian={self.hamiltonian!r}"
        )