Coverage for jaxquantum/devices/base/system.py: 0%
68 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""System."""
3from typing import List, Optional, Dict, Any, Union
4import math
6from flax import struct
7from jax import vmap, Array
8from jax import config
10import jax.numpy as jnp
12from jaxquantum.devices.base.base import Device
13from jaxquantum.devices.superconducting.drive import Drive
14from jaxquantum.core.qarray import Qarray, tensor
15from jaxquantum.core.operators import identity
# Enable 64-bit dtypes (float64 / complex128 / int64) in JAX. This executes at
# import time and modifies JAX's global configuration for the whole process,
# not just for this module.
config.update("jax_enable_x64", True)
def calculate_eig(Ns, H: Qarray):
    """Diagonalize ``H`` and order its eigenpairs by dominant bare state.

    Each eigenvector is tagged with the computational-basis index where its
    amplitude has the largest magnitude. Then, for every bare index ``i`` in
    ``range(prod(Ns))``, the eigenpair whose tag is nearest to ``i`` is
    selected. Ties go to the lowest eigen-index, i.e. the lowest eigenvalue,
    since ``eigh`` returns eigenvalues in ascending order.

    NOTE(review): if two eigenvectors share a dominant index, the selection
    is not a permutation. Some eigenpair is then repeated and another one
    dropped.

    Args:
        Ns: Subsystem dimensions. Their product must equal the dimension of
            ``H.data`` (the selected eigenvectors are reshaped to
            ``(prod(Ns), prod(Ns), 1)``).
        H: Operator whose ``.data`` is passed to ``jnp.linalg.eigh``. It is
            assumed to be a single, unbatched Hermitian matrix.

    Returns:
        A tuple ``(energies, states)``:
        - ``energies`` is the selected eigenvalues reshaped to ``Ns``.
        - ``states`` is a ``Qarray`` of the selected eigenvectors, reshaped
          via ``reshape_qdims(*Ns)`` and then ``reshape_bdims(*Ns)``.
    """
    dim = math.prod(Ns)

    evals, evecs = jnp.linalg.eigh(H.data)
    evecs = evecs.T  # eigh stores eigenvectors as columns; make row k = k-th eigenvector

    # Bare-basis index carrying the largest |amplitude| of each eigenvector.
    dominant = jnp.argmax(jnp.abs(evecs), axis=1)

    # For each bare index, choose the eigenvector whose dominant index is closest.
    bare_idx = jnp.arange(dim)
    distance = jnp.abs(bare_idx[:, None] - dominant[None, :])
    pick = jnp.argmin(distance, axis=1)

    energies = evals[pick]
    states = Qarray.create(jnp.reshape(evecs[pick], (dim, dim, 1)))
    states = states.reshape_qdims(*Ns).reshape_bdims(*Ns)

    return jnp.reshape(energies, Ns), states
def promote(op: Qarray, device_num, Ns):
    """Embed a single-subsystem operator into the full tensor-product space.

    Args:
        op: Operator acting only on subsystem ``device_num``.
        device_num: Position of the target subsystem in ``Ns``. Negative
            values count from the end, as with Python sequence indexing.
        Ns: Dimension of every subsystem, in tensor-product order.

    Returns:
        ``tensor(I_{N_0}, ..., op, ..., I_{N_last})``, with ``op`` placed at
        position ``device_num`` and an ``identity`` on every other subsystem.

    Raises:
        IndexError: if ``device_num`` does not address an entry of ``Ns``.
    """
    dims = tuple(Ns)
    n_sub = len(dims)
    if not -n_sub <= device_num < n_sub:
        # Previously an out-of-range index silently appended an extra tensor
        # factor, producing an operator on the wrong Hilbert space.
        raise IndexError(
            f"device_num {device_num} is out of range for {n_sub} subsystem(s)."
        )
    # Normalize negative indices. Otherwise `[:device_num]` and
    # `[device_num + 1:]` overlap and too many identity factors are emitted.
    device_num %= n_sub

    # Only build the identities actually used; the target slot holds `op`.
    left = [identity(N) for N in dims[:device_num]]
    right = [identity(N) for N in dims[device_num + 1 :]]
    return tensor(*left, op, *right)
@struct.dataclass
class System:
    """A collection of devices together with the coupling terms between them.

    Prefer :meth:`create` over the raw constructor. It derives ``Ns`` from the
    devices, checks label uniqueness and fills in empty defaults.
    """

    # Hilbert-space dimension of each device, in device order. Declared as a
    # non-pytree (static) field. NOTE: `create` stores a tuple here even though
    # the annotation says List[int].
    Ns: List[int] = struct.field(pytree_node=False)
    # Subsystems, in the same order used for tensor products.
    devices: List[Union[Device, Drive]]
    # Terms summed by `get_H_couplings` and added to the bare Hamiltonian.
    couplings: List[Array]
    # Extra parameters; not read by any method of this class.
    params: Dict[str, Any]

    @classmethod
    def create(
        cls,
        devices: List[Union[Device, Drive]],
        couplings: Optional[List[Array]] = None,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Construct a ``System`` from its devices.

        Args:
            devices: Subsystems; each must expose ``label`` and ``N``.
            couplings: Coupling terms; an empty list when omitted.
            params: Extra parameters; an empty dict when omitted.

        Raises:
            ValueError: if two devices share the same ``label``.
        """
        labels = [dev.label for dev in devices]
        if len(set(labels)) != len(labels):
            raise ValueError("Devices must have unique labels.")

        dims = tuple(dev.N for dev in devices)
        return cls(
            dims,
            devices,
            [] if couplings is None else couplings,
            {} if params is None else params,
        )

    def promote(self, op, device_num):
        """Embed ``op`` (acting on device ``device_num``) into the full space."""
        return promote(op, device_num, self.Ns)

    def get_H_bare(self):
        """Sum every device's own Hamiltonian, promoted to the full space.

        Returns the integer ``0`` when the system has no devices.
        """
        total = 0
        for idx, dev in enumerate(self.devices):
            term = self.promote(dev.get_H(), idx)
            total += term
        return total

    def get_H_couplings(self):
        """Sum the coupling terms; the integer ``0`` if there are none."""
        total = 0
        for term in self.couplings:
            total += term
        return total

    def get_H(self):
        """Return the bare Hamiltonian plus all coupling terms."""
        return self.get_H_bare() + self.get_H_couplings()

    def calculate_eig(self):
        """Diagonalize the full Hamiltonian via the module-level ``calculate_eig``."""
        return calculate_eig(self.Ns, self.get_H())