Coverage for jaxquantum/devices/superconducting/drive.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Base Drive.""" 

2 

3from abc import ABC 

4from typing import Dict 

5 

6from flax import struct 

7from jax import config 

8import jax.numpy as jnp 

9 

10from jaxquantum.core.qarray import Qarray 

11from jaxquantum.core.conversions import jnp2jqt 

12 

13config.update("jax_enable_x64", True) 

14 

15 

@struct.dataclass
class Drive(ABC):
    """Drive degree of freedom in a truncated charge basis.

    The basis states are |m> for m = -M_max, ..., M_max, so the Hilbert-space
    dimension is N = 2 * M_max + 1. Instances are normally built with
    :meth:`create`.

    Attributes:
        N: Dimension of the truncated drive Hilbert space. Declared with
            ``pytree_node=False`` (static metadata, not a pytree leaf).
        ωd: Drive frequency; the coefficient of the charge operator M in
            :meth:`get_H`.
        _label: Integer suffix used to build :attr:`label`. Declared with
            ``pytree_node=False``.
    """

    N: int = struct.field(pytree_node=False)
    ωd: float
    _label: int = struct.field(pytree_node=False)

    @classmethod
    def create(cls, M_max, ωd, label=0):
        """Build a drive truncated to charges -M_max, ..., M_max.

        Args:
            M_max: Largest charge kept in the basis; the resulting Hilbert
                space has dimension ``2 * M_max + 1``.
            ωd: Drive frequency.
            label: Integer suffix appended to the class name in :attr:`label`.

        Returns:
            A new instance of ``cls``.
        """
        # M_max is deliberately NOT assigned to the class. The previous
        # ``cls.M_max = M_max`` made it shared by every instance, so creating a
        # second drive with a different truncation silently changed the
        # operators of the first. It is now derived per instance from ``N``
        # (see the ``M_max`` property).
        N = 2 * M_max + 1
        return cls(N, ωd, label)

    @property
    def M_max(self) -> int:
        """Charge truncation recovered from ``N = 2 * M_max + 1``.

        Assumes ``N`` is odd, which :meth:`create` guarantees.
        """
        return (self.N - 1) // 2

    @property
    def label(self):
        """Class name followed by the integer ``_label``, e.g. ``"Drive0"``."""
        return self.__class__.__name__ + str(self._label)

    @property
    def ops(self):
        """Operator dictionary from :meth:`common_ops` (rebuilt on each access)."""
        return self.common_ops()

    def common_ops(self) -> Dict[str, Qarray]:
        """Build the standard operators of the drive charge basis.

        Returns:
            Dict with keys ``"M"``, ``"id"``, ``"M+"``, ``"M-"``,
            ``"cos(θ)"``, ``"sin(θ)"`` and, for every 2 <= k <= M_max,
            ``"M_+{k}"``, ``"M_-{k}"``, ``"cos({k}θ)"``, ``"sin({k}θ)"``.
        """
        ops = {}

        M_max = self.M_max
        dim = 2 * M_max + 1

        # M = ∑ₘ m|m><m|; basis index j corresponds to charge m = j - M_max.
        ops["M"] = jnp2jqt(jnp.diag(jnp.arange(-M_max, M_max + 1)))

        # Id = ∑ₘ|m><m| in the drive charge basis.
        ops["id"] = jnp2jqt(jnp.identity(dim))

        # M₊ ≡ exp(iθ) = ∑ₘ|m+1><m| (ones on the sub-diagonal, k=-1) and
        # M₋ ≡ exp(-iθ) = ∑ₘ|m><m+1| (ones on the super-diagonal, k=1).
        ops["M-"] = jnp2jqt(jnp.eye(dim, k=1))
        ops["M+"] = jnp2jqt(jnp.eye(dim, k=-1))

        # cos(θ) ≡ 1/2 * [M₊ + M₋] = 1/2 * ∑ₘ|m+1><m| + h.c.
        ops["cos(θ)"] = 0.5 * (ops["M+"] + ops["M-"])

        # sin(θ) ≡ -i/2 * [M₊ - M₋] = -i/2 * ∑ₘ|m+1><m| + h.c.
        ops["sin(θ)"] = -0.5j * (ops["M+"] - ops["M-"])

        # Higher harmonics: M₊ᵏ shifts charge by +k (k-th sub-diagonal) and
        # M₋ᵏ by -k; cos(kθ) and sin(kθ) are built from them as above.
        for k in range(2, M_max + 1):
            ops[f"M_+{k}"] = jnp2jqt(jnp.eye(dim, k=-k))
            ops[f"M_-{k}"] = jnp2jqt(jnp.eye(dim, k=k))
            ops[f"cos({k}θ)"] = 0.5 * (ops[f"M_+{k}"] + ops[f"M_-{k}"])
            ops[f"sin({k}θ)"] = -0.5j * (ops[f"M_+{k}"] - ops[f"M_-{k}"])

        return ops

    #############################################################

    def get_H(self):
        """
        Bare "drive" Hamiltonian (ωd * M) in the extended Hilbert space.
        """
        return self.ωd * self.ops["M"]