Coverage for jaxquantum/circuits/gates.py: 84%
64 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""Gates."""
3from copy import deepcopy
4from flax import struct
5from jax import Array, config
6from typing import List, Dict, Any, Optional, Callable, Union
7import jax.numpy as jnp
10from jaxquantum.core.qarray import Qarray, concatenate
# Enable 64-bit floating-point/complex dtypes in JAX. This is a global JAX
# setting applied as a side effect of importing this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class Gate:
    """A gate description stored as an immutable flax struct dataclass.

    A gate bundles optional representations of the same operation: a unitary,
    a time-dependent Hamiltonian, a Kraus map and collapse operators, plus the
    parameters and time points used to build them. Instances are frozen;
    ``add_Ht``, ``add_c_ops`` and ``copy`` return new gates.
    """

    # Dimension of each mode the gate acts on (static field, not a pytree leaf).
    dims: List[int] = struct.field(pytree_node=False)
    # Unitary returned by ``gen_U``, or None when no generator was given.
    _U: Optional[Qarray]
    # Function t -> Hamiltonian Qarray returned by ``gen_Ht``, or None.
    _Ht: Optional[Callable[[float], Qarray]]
    # Kraus map: from ``gen_KM`` if given, else ``[U]``, else None.
    _KM: Optional[Qarray]
    # Collapse operators; ``create`` defaults them to an empty Qarray list.
    _c_ops: Optional[Qarray]
    # Gate parameters; ``create`` stores {} when none are supplied.
    _params: Dict[str, Any]
    # Time points of the gate; ``create`` defaults to an empty array.
    _ts: Array
    # Display name used by __str__/__repr__ (static field).
    _name: str = struct.field(pytree_node=False)
    # Number of modes; ``create`` asserts it equals len(dims) (static field).
    num_modes: int = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls,
        dims: Union[int, List[int]],
        name: str = "Gate",
        params: Optional[Dict[str, Any]] = None,
        ts: Optional[Array] = None,
        gen_U: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_Ht: Optional[
            Callable[[Dict[str, Any]], Callable[[float], Qarray]]
        ] = None,
        gen_c_ops: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_KM: Optional[Callable[[Dict[str, Any]], List[Qarray]]] = None,
        num_modes: int = 1,
    ):
        """Create a gate.

        Args:
            dims: Dimensions of the gate; a single int is wrapped in a list.
            name: Name of the gate.
            params: Parameters of the gate, passed unchanged to every
                generator function.
            ts: Times of the gate. Defaults to an empty array.
            gen_U: Function params -> unitary Qarray of the gate.
            gen_Ht: Function params -> Ht, where Ht(t) takes a time t and
                returns a Hamiltonian Qarray.
            gen_c_ops: Function params -> collapse operators. Defaults to an
                empty Qarray list.
            gen_KM: Function to generate the Kraus map of the gate. When
                omitted, the Kraus map is ``[U]`` if ``gen_U`` is given and
                None otherwise.
            num_modes: Number of modes of the gate; must equal len(dims).

        Returns:
            A new ``Gate``.

        Raises:
            AssertionError: If ``len(dims) != num_modes`` (skipped under -O).
        """
        # TODO: add params to device?
        if isinstance(dims, int):
            dims = [dims]

        assert len(dims) == num_modes, (
            "Number of dimensions must match number of modes."
        )

        _U = gen_U(params) if gen_U is not None else None
        _Ht = gen_Ht(params) if gen_Ht is not None else None
        _c_ops = (
            gen_c_ops(params) if gen_c_ops is not None else Qarray.from_list([])
        )

        if gen_KM is not None:
            # NOTE(review): gen_KM is annotated as returning List[Qarray] while
            # the field holds a Qarray (cf. from_list below) -- confirm.
            _KM = gen_KM(params)
        elif _U is not None:
            _KM = Qarray.from_list([_U])
        else:
            # Gates defined only by a Hamiltonian (or nothing) have no Kraus
            # map; binding None avoids an UnboundLocalError below.
            _KM = None

        return Gate(
            dims=dims,
            _U=_U,
            _Ht=_Ht,
            _KM=_KM,
            _c_ops=_c_ops,
            _params=params if params is not None else {},
            _ts=ts if ts is not None else jnp.array([]),
            _name=name,
            num_modes=num_modes,
        )

    def __str__(self):
        return self._name

    def __repr__(self):
        return self._name

    @property
    def name(self):
        """Display name of the gate."""
        return self._name

    @property
    def U(self):
        """Unitary of the gate, or None."""
        return self._U

    @property
    def Ht(self):
        """Function t -> Hamiltonian Qarray, or None."""
        return self._Ht

    @property
    def KM(self):
        """Kraus map of the gate, or None."""
        return self._KM

    @property
    def c_ops(self):
        """Collapse operators of the gate."""
        return self._c_ops

    @property
    def params(self):
        """Parameters the gate was built from."""
        return self._params

    @property
    def ts(self):
        """Time points of the gate."""
        return self._ts

    def add_Ht(self, Ht: Callable[[float], Qarray]):
        """Add a Hamiltonian function to the gate.

        Args:
            Ht: Function t -> Hamiltonian Qarray to add.

        Returns:
            A new gate whose Hamiltonian is ``t -> Ht(t) + self.Ht(t)``, or
            just ``Ht`` when this gate has no Hamiltonian. All other fields
            are carried over unchanged.
        """
        old_Ht = self.Ht
        if old_Ht is None:
            new_Ht = Ht
        else:
            # Capture the old function directly so the None check is not
            # repeated on every evaluation of the Hamiltonian.
            def new_Ht(t):
                return Ht(t) + old_Ht(t)

        return self.replace(_Ht=new_Ht)

    def add_c_ops(self, c_ops: Qarray):
        """Add collapse operators to the gate.

        Args:
            c_ops: Collapse operators to append after the existing ones.

        Returns:
            A new gate with the combined collapse operators; all other fields
            are carried over unchanged.
        """
        # _c_ops is Optional; a gate built directly with None has nothing to
        # concatenate onto.
        if self.c_ops is None:
            merged = c_ops
        else:
            merged = concatenate([self.c_ops, c_ops])
        return self.replace(_c_ops=merged)

    def copy(self):
        """Return a copy of the gate.

        ``dims``, ``params`` and ``Ht`` are deep-copied (``deepcopy`` returns
        plain functions unchanged); the remaining fields are shared with the
        original rather than copied.
        """
        return self.replace(
            dims=deepcopy(self.dims),
            _Ht=deepcopy(self.Ht),
            _params=deepcopy(self.params),
        )