Coverage for jaxquantum/circuits/gates.py: 0%
41 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Gates."""
3from flax import struct
4from jax import Array, config
5from typing import List, Dict, Any, Optional, Callable, Union
6import jax.numpy as jnp
9from jaxquantum.core.qarray import Qarray
# Turn on JAX's 64-bit mode so float64/complex128 arrays are available
# (JAX otherwise defaults to 32-bit dtypes).
# NOTE: this is a global, process-wide setting applied as a side effect of
# importing this module.
config.update("jax_enable_x64", True)
@struct.dataclass
class Gate:
    """A gate described by a unitary, a Hamiltonian and/or a Kraus map.

    Instances are flax ``struct.dataclass`` pytrees: the fields declared
    with ``pytree_node=False`` (``dims``, ``_name``, ``num_modes``) are static
    metadata, and the remaining fields are pytree children.

    Build instances with :meth:`Gate.create` rather than the generated
    ``__init__``.
    """

    dims: List[int] = struct.field(pytree_node=False)  # dimension per mode
    _U: Optional[Qarray]  # Unitary (output of gen_U)
    _H: Optional[Qarray]  # Hamiltonian (output of gen_H)
    # Kraus map.
    # NOTE(review): gen_KM is typed to return List[Qarray], but its result is
    # stored here unchanged; confirm which type is intended.
    _KM: Optional[Qarray]
    _params: Dict[str, Any]
    _ts: Array
    _name: str = struct.field(pytree_node=False)
    num_modes: int = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls,
        dims: Union[int, List[int]],
        name: str = "Gate",
        params: Optional[Dict[str, Any]] = None,
        ts: Optional[Array] = None,
        gen_U: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_H: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_KM: Optional[Callable[[Dict[str, Any]], List[Qarray]]] = None,
        num_modes: int = 1,
    ):
        """Create a gate.

        Args:
            dims: Dimension of each mode the gate acts on. A bare ``int`` is
                treated as the single-mode list ``[dims]``.
            name: Name of the gate; also returned by ``str()`` and ``repr()``.
            params: Parameters of the gate. ``None`` is replaced by ``{}``
                before the generator functions are called.
            ts: Times of the gate. ``None`` is replaced by an empty array.
            gen_U: Function mapping ``params`` to the gate's unitary.
            gen_H: Function mapping ``params`` to the gate's Hamiltonian.
            gen_KM: Function mapping ``params`` to the gate's Kraus map.
            num_modes: Number of modes of the gate; must equal ``len(dims)``.

        Returns:
            A new instance of ``cls``. When ``gen_KM`` is not given, the Kraus
            map is the single-operator map ``[U]`` if a unitary was generated,
            and ``None`` otherwise.

        Raises:
            AssertionError: If the number of dims does not equal ``num_modes``.
        """
        # TODO: add params to device?

        if isinstance(dims, int):
            dims = [dims]

        assert len(dims) == num_modes, (
            "Number of dimensions must match number of modes."
        )

        # Normalise once so the generators receive the same dict that is
        # stored on the gate (matching their Dict[str, Any] signature)
        # instead of None.
        if params is None:
            params = {}

        _U = gen_U(params) if gen_U is not None else None
        _H = gen_H(params) if gen_H is not None else None

        if gen_KM is not None:
            _KM = gen_KM(params)
        elif _U is not None:
            # A unitary channel is the Kraus map with the single operator U.
            _KM = Qarray.from_list([_U])
        else:
            # No unitary and no explicit Kraus generator (e.g. a gate given
            # only a Hamiltonian): there is no Kraus map to derive. Binding
            # None here avoids an UnboundLocalError below.
            _KM = None

        # Use cls so that subclasses calling create() get their own type.
        return cls(
            dims=dims,
            _U=_U,
            _H=_H,
            _KM=_KM,
            _params=params,
            _ts=ts if ts is not None else jnp.array([]),
            _name=name,
            num_modes=num_modes,
        )

    def __str__(self):
        return self._name

    def __repr__(self):
        return self._name

    @property
    def U(self):
        """Unitary of the gate, or ``None`` if no ``gen_U`` was given."""
        return self._U

    @property
    def H(self):
        """Hamiltonian of the gate, or ``None`` if no ``gen_H`` was given."""
        return self._H

    @property
    def KM(self):
        """Kraus map of the gate.

        This is the output of ``gen_KM``, or ``[U]`` when only a unitary is
        available, or ``None`` when neither is available.
        """
        return self._KM