Coverage for jaxquantum/devices/superconducting/drive.py: 0%
41 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Base Drive."""
3from abc import ABC
4from typing import Dict
6from flax import struct
7from jax import config
8import jax.numpy as jnp
10from jaxquantum.core.qarray import Qarray
11from jaxquantum.core.conversions import jnp2jqt
13config.update("jax_enable_x64", True)
@struct.dataclass
class Drive(ABC):
    """Drive mode represented in a truncated drive charge basis.

    The basis states are |m> for m = -M_max, ..., M_max, so the Hilbert
    space dimension is ``N = 2 * M_max + 1``.

    Fields:
        N: Hilbert space dimension (static, not a pytree leaf).
        ωd: Drive frequency; scales the ``M`` operator in ``get_H``.
        _label: Integer suffix used to build ``label`` (static).
    """

    N: int = struct.field(pytree_node=False)
    ωd: float
    _label: int = struct.field(pytree_node=False)

    @classmethod
    def create(cls, M_max, ωd, label=0):
        """Build a drive whose charge basis runs from -M_max to M_max.

        Args:
            M_max: Largest charge index kept in the truncated basis.
            ωd: Drive frequency.
            label: Integer suffix appended to the class name in ``label``.

        Returns:
            A new instance with ``N = 2 * M_max + 1``.
        """
        # M_max is deliberately NOT stored on the class: doing so would make
        # every existing instance pick up the most recently requested cutoff.
        # It is recovered per instance from N by the ``M_max`` property.
        N = 2 * M_max + 1
        return cls(N, ωd, label)

    @property
    def M_max(self):
        """Largest charge index, recovered from the dimension N."""
        return (self.N - 1) // 2

    @property
    def label(self):
        """Class name followed by the integer label, e.g. ``Drive0``."""
        return self.__class__.__name__ + str(self._label)

    @property
    def ops(self):
        """Dictionary of operators; rebuilt by ``common_ops`` on each access."""
        return self.common_ops()

    def common_ops(self) -> Dict[str, Qarray]:
        """Construct the operators of the drive in its charge basis.

        Returns:
            A dict with keys ``"M"``, ``"id"``, ``"M-"``, ``"M+"``,
            ``"cos(θ)"``, ``"sin(θ)"`` and, for each k in 2..M_max,
            ``"M_+k"``, ``"M_-k"``, ``"cos(kθ)"`` and ``"sin(kθ)"``.
        """
        ops = {}

        M_max = self.M_max
        dim = 2 * M_max + 1

        # Construct M = ∑ₘ m|m><m| operator in drive charge basis
        ops["M"] = jnp2jqt(jnp.diag(jnp.arange(-M_max, M_max + 1)))

        # Construct Id = ∑ₘ|m><m| in the drive charge basis
        ops["id"] = jnp2jqt(jnp.identity(dim))

        # Construct M₊ ≡ exp(iθ) and M₋ ≡ exp(-iθ) operators for drive
        ops["M-"] = jnp2jqt(jnp.eye(dim, k=1))
        ops["M+"] = jnp2jqt(jnp.eye(dim, k=-1))

        # Construct cos(θ) ≡ 1/2 * [M₊ + M₋] = 1/2 * ∑ₘ|m+1><m| + h.c
        ops["cos(θ)"] = 0.5 * (ops["M+"] + ops["M-"])

        # Construct sin(θ) ≡ -i/2 * [M₊ - M₋] = -i/2 * ∑ₘ|m+1><m| + h.c
        ops["sin(θ)"] = -0.5j * (ops["M+"] - ops["M-"])

        # Construct more general drive operators cos(kθ) and sin(kθ)
        for k in range(2, M_max + 1):
            ops[f"M_+{k}"] = jnp2jqt(jnp.eye(dim, k=-k))
            ops[f"M_-{k}"] = jnp2jqt(jnp.eye(dim, k=k))
            ops[f"cos({k}θ)"] = 0.5 * (ops[f"M_+{k}"] + ops[f"M_-{k}"])
            ops[f"sin({k}θ)"] = -0.5j * (ops[f"M_+{k}"] - ops[f"M_-{k}"])

        return ops

    #############################################################

    def get_H(self):
        """
        Bare "drive" Hamiltonian (ωd * M) in the extended Hilbert space.
        """
        return self.ωd * self.ops["M"]