Coverage for jaxquantum/circuits/gates.py: 84%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""Gates.""" 

2 

3from copy import deepcopy 

4from flax import struct 

5from jax import Array, config 

6from typing import List, Dict, Any, Optional, Callable, Union 

7import jax.numpy as jnp 

8 

9 

10from jaxquantum.core.qarray import Qarray, concatenate 

11 

# Turn on 64-bit (double-precision) types in JAX. This is a process-wide
# configuration change that takes effect as soon as this module is imported.
config.update("jax_enable_x64", True)

13 

14 

@struct.dataclass
class Gate:
    """A gate acting on one or more modes.

    A gate may be described by any combination of a unitary ``U``, a
    time-dependent Hamiltonian function ``Ht``, a Kraus map ``KM`` and a set
    of collapse operators ``c_ops``. Instances are frozen flax struct
    dataclasses; ``add_Ht``, ``add_c_ops`` and ``copy`` return new instances
    rather than mutating this one.
    """

    # Dimension of each mode the gate acts on (static, not a pytree leaf).
    dims: List[int] = struct.field(pytree_node=False)
    _U: Optional[Qarray]  # Unitary
    _Ht: Optional[Callable[[float], Qarray]]  # Hamiltonian as a function of t
    _KM: Optional[Qarray]  # Kraus map
    _c_ops: Optional[Qarray]  # Collapse operators
    _params: Dict[str, Any]
    _ts: Array
    _name: str = struct.field(pytree_node=False)
    num_modes: int = struct.field(pytree_node=False)

    @classmethod
    def create(
        cls,
        dims: Union[int, List[int]],
        name: str = "Gate",
        params: Optional[Dict[str, Any]] = None,
        ts: Optional[Array] = None,
        gen_U: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_Ht: Optional[
            Callable[[Dict[str, Any]], Callable[[float], Qarray]]
        ] = None,
        gen_c_ops: Optional[Callable[[Dict[str, Any]], Qarray]] = None,
        gen_KM: Optional[Callable[[Dict[str, Any]], List[Qarray]]] = None,
        num_modes: int = 1,
    ):
        """Create a gate.

        Args:
            dims: Dimension of each mode the gate acts on. A single int is
                treated as ``[dims]``.
            name: Name of the gate.
            params: Parameters of the gate. ``None`` is treated as ``{}``;
                the same dict is passed to every ``gen_*`` function and is
                stored on the gate.
            ts: Times of the gate. Defaults to an empty array.
            gen_U: Function mapping ``params`` to the unitary of the gate.
            gen_Ht: Function mapping ``params`` to a function ``Ht(t)`` that
                takes in a time ``t`` and outputs a Hamiltonian Qarray.
            gen_c_ops: Function mapping ``params`` to the collapse operators
                of the gate. Defaults to ``Qarray.from_list([])``.
            gen_KM: Function mapping ``params`` to the Kraus map of the gate.
                If omitted, the Kraus map is ``[U]`` when ``gen_U`` is given
                and ``None`` otherwise.
            num_modes: Number of modes of the gate.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: If ``len(dims)`` does not equal ``num_modes``.
        """

        # TODO: add params to device?

        if isinstance(dims, int):
            dims = [dims]

        assert len(dims) == num_modes, (
            "Number of dimensions must match number of modes."
        )

        # Normalize once so the generators receive exactly the dict that is
        # stored on the gate (and exposed via ``gate.params``).
        if params is None:
            params = {}

        _U = gen_U(params) if gen_U is not None else None
        _Ht = gen_Ht(params) if gen_Ht is not None else None
        _c_ops = (
            gen_c_ops(params) if gen_c_ops is not None else Qarray.from_list([])
        )

        # Kraus map: an explicit generator wins; otherwise a unitary is its
        # own single-operator Kraus map. A gate with neither (e.g. one defined
        # only by a Hamiltonian) gets no Kraus map instead of hitting an
        # UnboundLocalError.
        if gen_KM is not None:
            _KM = gen_KM(params)
        elif _U is not None:
            _KM = Qarray.from_list([_U])
        else:
            _KM = None

        return cls(
            dims=dims,
            _U=_U,
            _Ht=_Ht,
            _KM=_KM,
            _c_ops=_c_ops,
            _params=params,
            _ts=ts if ts is not None else jnp.array([]),
            _name=name,
            num_modes=num_modes,
        )

    def __str__(self):
        return self._name

    def __repr__(self):
        return self._name

    @property
    def name(self):
        """Name of the gate."""
        return self._name

    @property
    def U(self):
        """Unitary of the gate, or ``None`` if it has none."""
        return self._U

    @property
    def Ht(self):
        """Hamiltonian function ``Ht(t)``, or ``None`` if it has none."""
        return self._Ht

    @property
    def KM(self):
        """Kraus map of the gate, or ``None`` if it has none."""
        return self._KM

    @property
    def c_ops(self):
        """Collapse operators of the gate."""
        return self._c_ops

    @property
    def params(self):
        """Parameters the gate was created with."""
        return self._params

    @property
    def ts(self):
        """Times of the gate."""
        return self._ts

    def add_Ht(self, Ht: Callable[[float], Qarray]):
        """Return a new gate with ``Ht`` added to this gate's Hamiltonian.

        Args:
            Ht: Function mapping a time ``t`` to a Hamiltonian Qarray.

        Returns:
            A gate whose Hamiltonian is ``t -> Ht(t) + self.Ht(t)``, or just
            ``Ht`` if this gate has no Hamiltonian. All other fields are
            shared with this gate.
        """
        old_Ht = self.Ht

        if old_Ht is None:
            new_Ht = Ht
        else:
            def new_Ht(t):
                return Ht(t) + old_Ht(t)

        return self.replace(_Ht=new_Ht)

    def add_c_ops(self, c_ops: Qarray):
        """Return a new gate with ``c_ops`` appended to its collapse operators.

        Args:
            c_ops: Collapse operators to append.

        Returns:
            A gate whose collapse operators are ``self.c_ops`` followed by
            ``c_ops`` (or just ``c_ops`` if this gate has none). All other
            fields are shared with this gate.
        """
        if self.c_ops is None:
            new_c_ops = c_ops
        else:
            new_c_ops = concatenate([self.c_ops, c_ops])
        return self.replace(_c_ops=new_c_ops)

    def copy(self):
        """Return a copy of the gate.

        ``dims``, ``Ht`` and ``params`` are passed through ``deepcopy``; the
        remaining fields (``U``, ``KM``, ``c_ops``, ``ts``) are shared with
        this gate. Note that ``deepcopy`` returns functions unchanged, so the
        copied ``Ht`` is the same function object.
        """
        return self.replace(
            dims=deepcopy(self.dims),
            _Ht=deepcopy(self.Ht),
            _params=deepcopy(self.params),
        )