Coverage for jaxquantum/devices/superconducting/fluxonium.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Fluxonium.""" 

2 

3from flax import struct 

4from jax import config 

5import jax.numpy as jnp 

6 

7from jaxquantum.devices.superconducting.flux_base import FluxDevice 

8from jaxquantum.devices.base.base import HamiltonianTypes 

9from jaxquantum.core.operators import identity, destroy, create 

10from jaxquantum.core import cosm, sinm 

11 

# Turn on JAX 64-bit (double-precision) types globally. This is a
# module-level side effect that runs as soon as this module is imported.
config.update("jax_enable_x64", True)

13 

14 

@struct.dataclass
class Fluxonium(FluxDevice):
    """Fluxonium device.

    The Hamiltonian is assembled in the "linear" (harmonic-oscillator) basis
    of the charging + inductive terms:

        H = sqrt(8 Ec El) (a†a + 1/2) - Ej cos(φ - 2π phi_ext)

    Parameters read from ``self.params``:
        Ec: charging energy.
        El: inductive energy.
        Ej: Josephson energy.
        phi_ext: external flux. It always enters as ``2π * phi_ext``, so it is
            presumably expressed in units of the flux quantum (confirm).

    NOTE(review): ``N_pre_diag``, ``params``, ``hamiltonian`` and
    ``linear_ops`` are not defined here; they are presumably supplied by
    ``FluxDevice`` (``linear_ops`` likely built from ``common_ops``) — confirm.
    """

    def common_ops(self):
        """Return the standard operators, written in the linear basis.

        Returns:
            dict of operators of dimension ``self.N_pre_diag`` with keys
            ``"id"``, ``"a"``, ``"a_dag"``, the phase ``"phi"``, the charge
            ``"n"``, and the half-phase matrix functions ``"cos(φ/2)"`` and
            ``"sin(φ/2)"``.
        """
        ops = {}

        N = self.N_pre_diag
        ops["id"] = identity(N)
        ops["a"] = destroy(N)
        ops["a_dag"] = create(N)
        # phi_zpf() * n_zpf() == 1/2, so these quadratures satisfy [phi, n] = i.
        ops["phi"] = self.phi_zpf() * (ops["a"] + ops["a_dag"])
        ops["n"] = 1j * self.n_zpf() * (ops["a_dag"] - ops["a"])

        # Matrix cosine/sine of phi/2 (not element-wise functions).
        ops["cos(φ/2)"] = cosm(ops["phi"] / 2)
        ops["sin(φ/2)"] = sinm(ops["phi"] / 2)

        return ops

    def n_zpf(self):
        """Return charge ZPF, ``(El / (32 Ec)) ** (1/4)``."""
        return (self.params["El"] / (32.0 * self.params["Ec"])) ** (0.25)

    def phi_zpf(self):
        """Return phase ZPF, ``(2 Ec / El) ** (1/4)``."""
        return (2 * self.params["Ec"] / self.params["El"]) ** (0.25)

    def get_linear_ω(self):
        """Return the linear-mode frequency ``sqrt(8 Ec El)``."""
        return jnp.sqrt(8 * self.params["Ec"] * self.params["El"])

    def get_H_linear(self):
        """Return the linear part of H: ``ω (a†a + 1/2)`` with ``ω = sqrt(8 Ec El)``."""
        w = self.get_linear_ω()
        return w * (
            self.linear_ops["a_dag"] @ self.linear_ops["a"]
            + 0.5 * self.linear_ops["id"]
        )

    def get_H_full(self):
        """Return the full H (linear + Josephson terms) in the linear basis."""
        phi_op = self.linear_ops["phi"]
        return self.get_H_linear() + self.get_H_nonlinear(phi_op)

    def get_H_nonlinear(self, phi_op):
        """Return the Josephson term ``-Ej cos(phi_op - 2π phi_ext)``.

        Args:
            phi_op: phase operator (``get_H_full`` passes
                ``self.linear_ops["phi"]``).

        Returns:
            The operator ``-Ej (cos(φ) cos(2π phi_ext) + sin(φ) sin(2π phi_ext))``.
        """
        op_cos_phi = cosm(phi_op)
        op_sin_phi = sinm(phi_op)

        # cos(φ - θ) = cos φ cos θ + sin φ sin θ: the external-flux shift θ
        # becomes scalar coefficients on the matrix cos/sin of phi_op.
        phi_ext = self.params["phi_ext"]
        theta = 2.0 * jnp.pi * phi_ext
        Hcos = op_cos_phi * jnp.cos(theta) + op_sin_phi * jnp.sin(theta)
        H_nl = -self.params["Ej"] * Hcos
        return H_nl

    def potential(self, phi):
        """Return the potential energy at flux ``phi``.

        ``phi`` is scaled by 2π, matching the phase used in the operator
        Hamiltonian (so ``phi`` is presumably in units of the flux quantum —
        confirm).

        Args:
            phi: scalar or array of flux values.

        Returns:
            ``0.5 * El * (2π phi)**2`` for ``HamiltonianTypes.linear``; that
            plus ``-Ej * cos(2π (phi - phi_ext))`` for
            ``HamiltonianTypes.full``.

        Raises:
            NotImplementedError: for any other Hamiltonian type, instead of
                implicitly returning ``None``.
        """
        phi_ext = self.params["phi_ext"]
        V_linear = 0.5 * self.params["El"] * (2 * jnp.pi * phi) ** 2

        if self.hamiltonian == HamiltonianTypes.linear:
            return V_linear

        if self.hamiltonian == HamiltonianTypes.full:
            V_nonlinear = -self.params["Ej"] * jnp.cos(2.0 * jnp.pi * (phi - phi_ext))
            return V_linear + V_nonlinear

        raise NotImplementedError(
            f"Fluxonium.potential does not support hamiltonian={self.hamiltonian!r}"
        )
40 

41 def phi_zpf(self): 

42 """Return Phase ZPF.""" 

43 return (2 * self.params["Ec"] / self.params["El"]) ** (0.25) 

44 

45 def get_linear_ω(self): 

46 """Get frequency of linear terms.""" 

47 return jnp.sqrt(8 * self.params["Ec"] * self.params["El"]) 

48 

49 def get_H_linear(self): 

50 """Return linear terms in H.""" 

51 w = self.get_linear_ω() 

52 return w * ( 

53 self.linear_ops["a_dag"] @ self.linear_ops["a"] 

54 + 0.5 * self.linear_ops["id"] 

55 ) 

56 

57 def get_H_full(self): 

58 """Return full H in linear basis.""" 

59 

60 phi_op = self.linear_ops["phi"] 

61 return self.get_H_linear() + self.get_H_nonlinear(phi_op) 

62 

63 def get_H_nonlinear(self, phi_op): 

64 op_cos_phi = cosm(phi_op) 

65 op_sin_phi = sinm(phi_op) 

66 

67 phi_ext = self.params["phi_ext"] 

68 Hcos = op_cos_phi * jnp.cos(2.0 * jnp.pi * phi_ext) + op_sin_phi * jnp.sin( 

69 2.0 * jnp.pi * phi_ext 

70 ) 

71 H_nl = -self.params["Ej"] * Hcos 

72 return H_nl 

73 

74 def potential(self, phi): 

75 """Return potential energy for a given phi.""" 

76 phi_ext = self.params["phi_ext"] 

77 V_linear = 0.5 * self.params["El"] * (2 * jnp.pi * phi) ** 2 

78 

79 if self.hamiltonian == HamiltonianTypes.linear: 

80 return V_linear 

81 

82 V_nonlinear = -self.params["Ej"] * jnp.cos(2.0 * jnp.pi * (phi - phi_ext)) 

83 if self.hamiltonian == HamiltonianTypes.full: 

84 return V_linear + V_nonlinear