Coverage for jaxquantum/circuits/gates.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Gates.""" 

2 

3from flax import struct 

4from jax import Array, config 

5from typing import List, Dict, Any, Optional, Callable, Union 

6import jax.numpy as jnp 

7 

8 

9from jaxquantum.core.qarray import Qarray 

10 

# Enable 64-bit dtypes in JAX ("jax_enable_x64"). This executes at import
# time and changes JAX's global configuration for the whole process, not just
# for this module.
config.update("jax_enable_x64", True)

12 

13 

@struct.dataclass
class Gate:
    """A named gate acting on one or more modes.

    Instances are normally built with ``Gate.create``, which evaluates the
    supplied generator callables once and stores their results in the
    underscore-prefixed fields. Fields declared with
    ``struct.field(pytree_node=False)`` are static metadata rather than
    pytree data.
    """

    # Dimension of each mode the gate acts on (static metadata).
    dims: List[int] = struct.field(pytree_node=False)
    # Unitary; create() stores the output of gen_U (typed as a Qarray), or None.
    _U: Optional[Qarray]
    # Hamiltonian; create() stores the output of gen_H (typed as a Qarray), or None.
    _H: Optional[Qarray]
    # Kraus map; create() stores gen_KM's output, or Qarray.from_list([U]) when
    # only gen_U is given.
    # NOTE(review): gen_KM is annotated in create() to return List[Qarray] —
    # confirm which type downstream consumers expect here.
    _KM: Optional[Qarray]
    # Parameters handed to the generators; create() stores {} when none given.
    _params: Dict[str, Any]
    # Times of the gate; create() defaults this to an empty array.
    _ts: Array
    # Human-readable name, returned by str() and repr() (static metadata).
    _name: str = struct.field(pytree_node=False)
    # Number of modes; create() checks that it equals len(dims).
    num_modes: int = struct.field(pytree_node=False)

24 

25 @classmethod 

26 def create( 

27 cls, 

28 dims: Union[int, List[int]], 

29 name: str = "Gate", 

30 params: Optional[Dict[str, Any]] = None, 

31 ts: Optional[Array] = None, 

32 gen_U: Optional[Callable[[Dict[str, Any]], Qarray]] = None, 

33 gen_H: Optional[Callable[[Dict[str, Any]], Qarray]] = None, 

34 gen_KM: Optional[Callable[[Dict[str, Any]], List[Qarray]]] = None, 

35 num_modes: int = 1, 

36 ): 

37 """Create a gate. 

38 

39 Args: 

40 dims: Dimensions of the gate. 

41 name: Name of the gate. 

42 params: Parameters of the gate. 

43 ts: Times of the gate. 

44 gen_U: Function to generate the unitary of the gate. 

45 gen_H: Function to generate the Hamiltonian of the gate. 

46 gen_KM: Function to generate the Kraus map of the gate. 

47 num_modes: Number of modes of the gate. 

48 """ 

49 

50 # TODO: add params to device? 

51 

52 if isinstance(dims, int): 

53 dims = [dims] 

54 

55 assert len(dims) == num_modes, ( 

56 "Number of dimensions must match number of modes." 

57 ) 

58 

59 # Unitary 

60 _U = gen_U(params) if gen_U is not None else None 

61 _H = gen_H(params) if gen_H is not None else None 

62 

63 if gen_KM is not None: 

64 _KM = gen_KM(params) 

65 elif _U is not None: 

66 _KM = Qarray.from_list([_U]) 

67 

68 return Gate( 

69 dims=dims, 

70 _U=_U, 

71 _H=_H, 

72 _KM=_KM, 

73 _params=params if params is not None else {}, 

74 _ts=ts if ts is not None else jnp.array([]), 

75 _name=name, 

76 num_modes=num_modes, 

77 ) 

78 

79 def __str__(self): 

80 return self._name 

81 

82 def __repr__(self): 

83 return self._name 

84 

85 @property 

86 def U(self): 

87 return self._U 

88 

89 @property 

90 def H(self): 

91 return self._H 

92 

93 @property 

94 def KM(self): 

95 return self._KM