Coverage for jaxquantum/devices/superconducting/kno.py: 0%
32 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Kerr Nonlinear Oscillator"""
3from flax import struct
4from jax import config
6import jax.numpy as jnp
8from jaxquantum.devices.base.base import Device, BasisTypes, HamiltonianTypes
9from jaxquantum.core.operators import identity, destroy, create
11config.update("jax_enable_x64", True)
@struct.dataclass
class KNO(Device):
    """
    Kerr Nonlinear Oscillator Device.

    A single bosonic mode with Hamiltonian

        H = ω a†a + (α / 2) a†a†aa

    defined in the Fock basis. Since a†a†aa |n> = n (n - 1) |n>, the level
    energies are E_n = ω n + (α / 2) n (n - 1), so α is the anharmonicity
    (E_2 - E_1) - (E_1 - E_0).

    Required ``params`` keys:
        "ω": coefficient of the linear (harmonic) term a†a.
        "α": anharmonicity / Kerr coefficient.
    """

    @classmethod
    def param_validation(cls, N, N_pre_diag, params, hamiltonian, basis):
        """Validate construction arguments for a Kerr Nonlinear Oscillator.

        Overrides the base-class validation hook. Only ``params``,
        ``hamiltonian`` and ``basis`` are checked; ``N`` and ``N_pre_diag``
        are accepted to match the base-class signature but are not inspected.

        Args:
            N: Operator dimension used by ``common_ops`` (not checked here).
            N_pre_diag: Accepted for signature compatibility (not checked).
            params: Device parameters; must contain the keys "ω" and "α".
            hamiltonian: Must be ``HamiltonianTypes.full``.
            basis: Must be ``BasisTypes.fock``.

        Raises:
            AssertionError: If any requirement above is not met.
        """
        # Explicit raises rather than ``assert`` statements: asserts are
        # removed under ``python -O``, which would silently skip validation.
        # AssertionError is kept so callers catching the exception type the
        # previous ``assert``-based checks raised keep working.
        if basis != BasisTypes.fock:
            raise AssertionError(
                "Kerr Nonlinear Oscillator must be defined in the Fock basis."
            )
        if hamiltonian != HamiltonianTypes.full:
            raise AssertionError(
                "Kerr Nonlinear Oscillator uses a full Hamiltonian."
            )
        if "ω" not in params or "α" not in params:
            raise AssertionError(
                "Kerr Nonlinear Oscillator requires frequency 'ω' and anharmonicity 'α' as parameters."
            )

    def common_ops(self):
        """Build the single-mode operators of the oscillator.

        All operators have dimension ``self.N``.

        Returns:
            dict with keys:
                "id": identity.
                "a": annihilation operator.
                "a_dag": creation operator.
                "phi": (a + a†) / √2.
                "n": i (a† - a) / √2, the conjugate quadrature of "phi".
        """
        N = self.N
        ops = {}
        ops["id"] = identity(N)
        ops["a"] = destroy(N)
        ops["a_dag"] = create(N)
        ops["phi"] = (ops["a"] + ops["a_dag"]) / jnp.sqrt(2)
        ops["n"] = 1j * (ops["a_dag"] - ops["a"]) / jnp.sqrt(2)
        return ops

    def get_linear_ω(self):
        """Return the coefficient of the linear term a†a (``params["ω"]``)."""
        return self.params["ω"]

    def get_anharm(self):
        """Return the anharmonicity / Kerr coefficient (``params["α"]``)."""
        return self.params["α"]

    def get_H_linear(self):
        """Return the linear (harmonic) part of H: ω a†a.

        Built from ``self.linear_ops``, which is provided by the ``Device``
        base class (not shown here). It presumably holds the ``common_ops``
        operators in the linear basis; confirm against the base class.
        """
        ω = self.get_linear_ω()
        return ω * self.linear_ops["a_dag"] @ self.linear_ops["a"]

    def get_H_full(self):
        """Return the full Hamiltonian in the linear basis.

        H = ω a†a + (α / 2) a†a†aa
        """
        α = self.get_anharm()
        a = self.linear_ops["a"]
        a_dag = self.linear_ops["a_dag"]
        # Normal-ordered Kerr term: a†a†aa |n> = n (n - 1) |n>, so the α/2
        # prefactor makes the 0→1 and 1→2 transitions differ by exactly α.
        kerr = a_dag @ a_dag @ a @ a
        return self.get_H_linear() + (α / 2) * kerr