Coverage for jaxquantum/devices/base/system.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""System.""" 

2 

from typing import List, Optional, Dict, Any, Union, Tuple
import math

from flax import struct
from jax import vmap, Array
from jax import config

import jax.numpy as jnp

from jaxquantum.devices.base.base import Device
from jaxquantum.devices.superconducting.drive import Drive
from jaxquantum.core.qarray import Qarray, tensor
from jaxquantum.core.operators import identity

16 

# Module-level side effect: importing this module switches JAX into 64-bit
# mode (float64/int64 defaults) for the whole process, not just this file.
# NOTE(review): presumably required for eigendecomposition precision — confirm.
config.update("jax_enable_x64", True)

18 

19 

def calculate_eig(Ns, H: Qarray):
    """Diagonalize ``H`` and order its eigenstates by bare product-basis label.

    Each eigenvector is labelled by the index of its largest-magnitude
    component in the flattened product basis defined by ``Ns``. For every flat
    bare index ``k`` the eigenstate whose label is closest to ``k`` is chosen,
    so the outputs are indexed like the bare product states.

    Args:
        Ns: Sequence of subsystem dimensions; their product must equal the
            dimension of ``H``.
        H: Hermitian operator; ``H.data`` is passed to ``jnp.linalg.eigh``.

    Returns:
        A tuple ``(Es, kets)``:
            Es: eigenvalues reshaped to ``Ns``, so ``Es[i, j, ...]`` is the
                energy of the eigenstate labelled by bare indices ``(i, j, ...)``.
            kets: the matching eigenvectors as a ``Qarray`` of column kets whose
                batch and quantum dimensions are both reshaped to ``Ns``.

    Raises:
        ValueError: if the last dimension of ``H.data`` is not ``prod(Ns)``.

    Note:
        If two eigenstates share the same dominant basis index, some bare index
        has no exact match and receives the eigenstate whose label is nearest
        in flat-index distance (the first one on ties). An eigenstate can then
        appear more than once while another is dropped.
    """
    N_tot = math.prod(Ns)
    dim = H.data.shape[-1]
    if dim != N_tot:
        raise ValueError(f"H has dimension {dim}, but prod(Ns) = {N_tot}.")

    vals, vecs = jnp.linalg.eigh(H.data)
    # eigh returns eigenvectors as columns; transpose so row i is eigenvector i.
    vecs = vecs.T

    # Label each eigenvector by its dominant component in the bare basis.
    labels = jnp.argmax(jnp.abs(vecs), axis=1)

    # For every bare index, pick the eigenstate with the nearest label.
    # argmin returns the first match, i.e. the lowest-energy one, because
    # eigh returns eigenvalues in ascending order.
    bare_idxs = jnp.arange(N_tot)
    order = jnp.argmin(jnp.abs(bare_idxs[:, None] - labels[None, :]), axis=1)

    Es = vals[order]
    kets = Qarray.create(jnp.reshape(vecs[order], (N_tot, N_tot, 1)))
    kets = kets.reshape_qdims(*Ns)
    kets = kets.reshape_bdims(*Ns)

    return jnp.reshape(Es, Ns), kets

49 

50 

def promote(op: Qarray, device_num, Ns):
    """Embed a single-subsystem operator into the full tensor-product space.

    Args:
        op: Operator acting on subsystem ``device_num``. Presumably of
            dimension ``Ns[device_num]``; this is not checked here.
        device_num: 0-based position of the target subsystem in ``Ns``.
        Ns: Sequence of subsystem dimensions.

    Returns:
        ``I(Ns[0]) x ... x op x ... x I(Ns[-1])`` as built by ``tensor``, with
        identities on every subsystem except ``device_num``.

    Raises:
        IndexError: if ``device_num`` is not in ``range(len(Ns))``.
    """
    if not 0 <= device_num < len(Ns):
        # Without this check, slicing would build an operator on the wrong
        # space with no error (e.g. -1 duplicates subsystems, len(Ns) adds one).
        raise IndexError(
            f"device_num {device_num} out of range for {len(Ns)} subsystems."
        )
    before = [identity(N) for N in Ns[:device_num]]
    after = [identity(N) for N in Ns[device_num + 1 :]]
    return tensor(*before, op, *after)

54 

55 

@struct.dataclass
class System:
    """A set of devices plus coupling terms forming a composite quantum system.

    Registered as a flax ``struct.dataclass``. ``Ns`` is declared with
    ``pytree_node=False``, so it is static metadata rather than a pytree leaf;
    the remaining fields are pytree nodes.
    """

    # Hilbert-space dimension of each device, in device order. ``create``
    # stores a tuple. Keep it a tuple when constructing directly, because
    # static pytree metadata is expected to be hashable.
    # NOTE(review): confirm for the flax/JAX versions in use.
    Ns: Tuple[int, ...] = struct.field(pytree_node=False)
    # Subsystems, in the same order as ``Ns``.
    devices: List[Union[Device, Drive]]
    # Terms summed into the Hamiltonian as-is by ``get_H_couplings``. They are
    # presumably full-space operators (annotated as Array) — verify with callers.
    couplings: List[Array]
    # Free-form extra parameters; not read by any method of this class.
    params: Dict[str, Any]

    @classmethod
    def create(
        cls,
        devices: List[Union[Device, Drive]],
        couplings: Optional[List[Array]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Build a ``System`` from devices, checking that labels are unique.

        Args:
            devices: Devices in subsystem order; each must expose ``label``
                and ``N``.
            couplings: Terms added to the bare Hamiltonian; defaults to none.
            params: Extra parameters; defaults to an empty dict.

        Returns:
            A new ``System`` whose ``Ns`` is the tuple of each device's ``N``.

        Raises:
            ValueError: if two devices share a label.
        """
        labels = [device.label for device in devices]
        if len(labels) != len(set(labels)):
            raise ValueError("Devices must have unique labels.")

        Ns = tuple(device.N for device in devices)
        # Fresh containers per call, so instances never share mutable defaults.
        couplings = couplings if couplings is not None else []
        params = params if params is not None else {}
        return cls(Ns, devices, couplings, params)

    def promote(self, op, device_num):
        """Embed ``op``, acting on device ``device_num``, into the full space."""
        # The bare name resolves to the module-level ``promote`` function;
        # class scope is not searched from inside methods.
        return promote(op, device_num, self.Ns)

    def get_H_bare(self):
        """Return the sum of every device Hamiltonian promoted to the full space.

        Returns the integer ``0`` when there are no devices.
        """
        H = 0
        for j, device in enumerate(self.devices):
            H += self.promote(device.get_H(), j)
        return H

    def get_H_couplings(self):
        """Return the sum of all coupling terms (``0`` when there are none)."""
        H = 0
        for coupling in self.couplings:
            H += coupling
        return H

    def get_H(self):
        """Return the full Hamiltonian: bare device terms plus couplings."""
        H_bare = self.get_H_bare()
        H_couplings = self.get_H_couplings()
        return H_bare + H_couplings

    def calculate_eig(self):
        """Diagonalize ``get_H()``; return ``(Es, kets)`` ordered by bare index.

        The labelling scheme is that of the module-level ``calculate_eig``.
        """
        H = self.get_H()
        # Resolves to the module-level function, not this method.
        return calculate_eig(self.Ns, H)