Coverage for jaxquantum/circuits/library/qubit.py: 20%
105 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""qubit gates."""
3from jaxquantum.core.operators import (
4 identity,
5 sigmax,
6 sigmay,
7 sigmaz,
8 basis,
9 hadamard,
10 qubit_rotation,
11)
12from jaxquantum.circuits.gates import Gate
13from jaxquantum.core.qarray import Qarray
14import jax.numpy as jnp
def X():
    """Build the single-qubit X gate.

    Returns:
        Gate: a one-mode gate on a 2-level space whose unitary is ``sigmax()``.
    """

    def _unitary(params):
        # Fixed gate: the params argument is accepted but not used.
        return sigmax()

    return Gate.create(2, gen_U=_unitary, name="X", num_modes=1)
def Y():
    """Build the single-qubit Y gate.

    Returns:
        Gate: a one-mode gate on a 2-level space whose unitary is ``sigmay()``.
    """

    def _unitary(params):
        # Fixed gate: the params argument is accepted but not used.
        return sigmay()

    return Gate.create(2, gen_U=_unitary, name="Y", num_modes=1)
def Z():
    """Build the single-qubit Z gate.

    Returns:
        Gate: a one-mode gate on a 2-level space whose unitary is ``sigmaz()``.
    """

    def _unitary(params):
        # Fixed gate: the params argument is accepted but not used.
        return sigmaz()

    return Gate.create(2, gen_U=_unitary, name="Z", num_modes=1)
def H():
    """Build the single-qubit Hadamard gate.

    Returns:
        Gate: a one-mode gate on a 2-level space whose unitary is ``hadamard()``.
    """

    def _unitary(params):
        # Fixed gate: the params argument is accepted but not used.
        return hadamard()

    return Gate.create(2, gen_U=_unitary, name="H", num_modes=1)
def Rx(theta, ts=None):
    """Build a single-qubit rotation about the x axis.

    Args:
        theta: Rotation angle, stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 1, 0, 0)`` to build the unitary.
        ts: Optional sequence of time points. When given, the gate also gets a
            time-independent Hamiltonian ``amp / 2 * sigmax()`` with
            ``amp = params["theta"] / (ts[-1] - ts[0])``, so that
            ``amp * (ts[-1] - ts[0]) == theta``.

    Returns:
        Gate: a one-mode gate on a 2-level space named ``"Rx"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]

        def gen_Ht(params):
            # Read theta from params (as gen_U does) rather than capturing the
            # constructor argument, so the Hamiltonian follows any update of
            # params["theta"] and stays consistent with the unitary.
            amp = params["theta"] / delta_t
            return lambda t: amp / 2 * sigmax()

    return Gate.create(
        2,
        name="Rx",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 1, 0, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def Ry(theta, ts=None):
    """Build a single-qubit rotation about the y axis.

    Args:
        theta: Rotation angle, stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 0, 1, 0)`` to build the unitary.
        ts: Optional sequence of time points. When given, the gate also gets a
            time-independent Hamiltonian ``amp / 2 * sigmay()`` with
            ``amp = params["theta"] / (ts[-1] - ts[0])``, so that
            ``amp * (ts[-1] - ts[0]) == theta``.

    Returns:
        Gate: a one-mode gate on a 2-level space named ``"Ry"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]

        def gen_Ht(params):
            # Read theta from params (as gen_U does) rather than capturing the
            # constructor argument, so the Hamiltonian follows any update of
            # params["theta"] and stays consistent with the unitary.
            amp = params["theta"] / delta_t
            return lambda t: amp / 2 * sigmay()

    return Gate.create(
        2,
        name="Ry",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 1, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def Rz(theta, ts=None):
    """Build a single-qubit rotation about the z axis.

    Args:
        theta: Rotation angle, stored as ``params["theta"]`` and passed to
            ``qubit_rotation(theta, 0, 0, 1)`` to build the unitary.
        ts: Optional sequence of time points. When given, the gate also gets a
            time-independent Hamiltonian ``amp / 2 * sigmaz()`` with
            ``amp = params["theta"] / (ts[-1] - ts[0])``, so that
            ``amp * (ts[-1] - ts[0]) == theta``.

    Returns:
        Gate: a one-mode gate on a 2-level space named ``"Rz"``.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]

        def gen_Ht(params):
            # Read theta from params (as gen_U does) rather than capturing the
            # constructor argument, so the Hamiltonian follows any update of
            # params["theta"] and stays consistent with the unitary.
            amp = params["theta"] / delta_t
            return lambda t: amp / 2 * sigmaz()

    return Gate.create(
        2,
        name="Rz",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 0, 1),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def MZ(measure=None):
    """Build a Z-basis measurement on a single qubit.

    Args:
        measure: ``None`` keeps both outcomes, giving Kraus operators
            ``[|g><g|, |e><e|]`` and the name ``"MZ"``. ``+1`` post-selects the
            ground projector only (``"MZ_plus"``). ``-1`` post-selects the
            excited projector only (``"MZ_minus"``).

    Returns:
        Gate: a one-mode gate on a 2-level space carrying the Kraus map.

    Raises:
        ValueError: if ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    proj_ground = ground @ ground.dag()
    proj_excited = excited @ excited.dag()

    if measure is None:
        gate_name, kraus_ops = "MZ", [proj_ground, proj_excited]
    elif measure == +1:
        gate_name, kraus_ops = "MZ_plus", [proj_ground]
    elif measure == -1:
        gate_name, kraus_ops = "MZ_minus", [proj_excited]
    else:
        raise ValueError("measure should be None, +1 or -1")

    kmap = Qarray.from_list(kraus_ops)

    def _kraus_map(params):
        return kmap

    return Gate.create(2, name=gate_name, gen_KM=_kraus_map, num_modes=1)
def MX(measure=None):
    """Build an X-basis measurement on a single qubit.

    The measurement basis is ``|+> = unit(|g> + |e>)`` and
    ``|-> = unit(|g> - |e>)``.

    Args:
        measure: ``None`` keeps both outcomes, giving Kraus operators
            ``[|+><+|, |-><-|]`` and the name ``"MX"``. ``+1`` post-selects
            ``|+><+|`` only (``"MX_plus"``). ``-1`` post-selects ``|-><-|``
            only (``"MX_minus"``).

    Returns:
        Gate: a one-mode gate on a 2-level space carrying the Kraus map.

    Raises:
        ValueError: if ``measure`` is not ``None``, ``+1`` or ``-1``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    plus_state = (ground + excited).unit()
    minus_state = (ground - excited).unit()

    proj_plus = plus_state @ plus_state.dag()
    proj_minus = minus_state @ minus_state.dag()

    if measure is None:
        gate_name, kraus_ops = "MX", [proj_plus, proj_minus]
    elif measure == +1:
        gate_name, kraus_ops = "MX_plus", [proj_plus]
    elif measure == -1:
        gate_name, kraus_ops = "MX_minus", [proj_minus]
    else:
        raise ValueError("measure should be None, +1 or -1")

    kmap = Qarray.from_list(kraus_ops)

    def _kraus_map(params):
        return kmap

    return Gate.create(2, name=gate_name, gen_KM=_kraus_map, num_modes=1)
def Reset():
    """Build a perfect single-qubit reset to the ground state.

    The Kraus operators are ``|g><g|`` and ``|g><e|``, so both basis states
    are mapped onto ``|g>``.

    Returns:
        Gate: a one-mode gate on a 2-level space named ``"Reset"``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    kraus_ops = [ground @ ground.dag(), ground @ excited.dag()]
    kmap = Qarray.from_list(kraus_ops)

    def _kraus_map(params):
        return kmap

    return Gate.create(2, name="Reset", gen_KM=_kraus_map, num_modes=1)
def IP_Reset(p_eg, p_ee):
    """Build an imperfect single-qubit reset channel.

    Args:
        p_eg: Probability that a qubit starting in ``|g>`` ends in ``|e>``.
        p_ee: Probability that a qubit starting in ``|e>`` stays in ``|e>``.

    Returns:
        Gate: a one-mode gate on a 2-level space named ``"IP_Reset"`` with four
        Kraus operators. Its params dict records both probabilities under keys
        matching the argument names.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ge = g @ e.dag()
    eg = e @ g.dag()
    ee = e @ e.dag()

    # |g> stays in |g> with probability 1 - p_eg (k_0) and is excited with
    # probability p_eg (k_2). |e> stays in |e> with probability p_ee (k_1) and
    # is reset to |g> with probability 1 - p_ee (k_3).
    k_0 = jnp.sqrt(1 - p_eg) * gg
    k_1 = jnp.sqrt(p_ee) * ee
    k_2 = jnp.sqrt(p_eg) * eg
    k_3 = jnp.sqrt(1 - p_ee) * ge

    kmap = Qarray.from_list([k_0, k_1, k_2, k_3])

    return Gate.create(
        2,
        name="IP_Reset",
        # The key for p_ee must be "p_ee": under the p_{final,initial} naming
        # used here, "p_ge" would denote 1 - p_ee.
        params={"p_eg": p_eg, "p_ee": p_ee},
        gen_KM=lambda params: kmap,
        num_modes=1,
    )
def CX():
    """Build the two-qubit controlled-X gate, with the first mode as control.

    The unitary is ``(|g><g| ^ I) + (|e><e| ^ sigmax())``. Here ``^`` combines
    the two single-mode operators, presumably as a tensor product given the
    ``[2, 2]`` dimensions.

    Returns:
        Gate: a two-mode gate named ``"CX"``.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    proj_ground = ground @ ground.dag()
    proj_excited = excited @ excited.dag()

    # Control in |g>: target untouched. Control in |e>: target flipped.
    idle_branch = proj_ground ^ identity(2)
    flip_branch = proj_excited ^ sigmax()
    unitary = idle_branch + flip_branch

    def _unitary(params):
        return unitary

    return Gate.create([2, 2], name="CX", gen_U=_unitary, num_modes=2)
def _Thermal_Kraus_Ops_Qb(err_prob, n_bar):
    """Return the four Kraus operators of a qubit thermal channel.

    The operators have the generalized amplitude-damping form. Let
    ``p = n_bar / (n_bar + 1)``. Then two operators, weighted by
    ``sqrt(1 - p)``, describe no-jump evolution and decay (``|g><e|``), and two
    operators, weighted by ``sqrt(p)``, describe excitation (``|e><g|``) and
    its no-jump counterpart. Each jump has amplitude ``sqrt(err_prob)``.

    Args:
        err_prob: Jump probability of the channel.
        n_bar: Average thermal photon number.

    Returns:
        list of ``Qarray``: the 2x2 Kraus operators.
    """
    p = n_bar / (n_bar + 1)
    weight_down = jnp.sqrt(1 - p)
    weight_up = jnp.sqrt(p)
    no_jump = jnp.sqrt(1 - err_prob)
    jump = jnp.sqrt(err_prob)

    matrices = [
        weight_down * jnp.array([[1, 0], [0, no_jump]]),
        weight_down * jnp.array([[0, jump], [0, 0]]),
        weight_up * jnp.array([[0, 0], [jump, 0]]),
        weight_up * jnp.array([[no_jump, 0], [0, 1]]),
    ]
    return [Qarray.create(mat) for mat in matrices]
def Thermal_Ch_Qb(err_prob, n_bar):
    """Build a single-qubit thermal-noise channel gate.

    Args:
        err_prob: Jump probability forwarded to ``_Thermal_Kraus_Ops_Qb``.
        n_bar: Average thermal photon number forwarded to
            ``_Thermal_Kraus_Ops_Qb``.

    Returns:
        Gate: a one-mode gate named ``"Thermal_Ch_Qb"``. Its Kraus map is
        rebuilt from the constructor arguments each time ``gen_KM`` is called;
        the ``params`` argument is not consulted.
    """

    def _kraus_map(params):
        ops = _Thermal_Kraus_Ops_Qb(err_prob, n_bar)
        return Qarray.from_list(ops)

    return Gate.create(
        2,
        name="Thermal_Ch_Qb",
        params={"err_prob": err_prob, "n_bar": n_bar},
        gen_KM=_kraus_map,
        num_modes=1,
    )
def _Pure_Dephasing_Ops_Qb(err_prob):
    """Return the Kraus operators of a qubit pure-dephasing channel.

    With probability ``err_prob`` the channel applies ``sigmaz()``. Otherwise
    it leaves the state unchanged.

    Args:
        err_prob: Probability of applying the ``sigmaz()`` error.

    Returns:
        list: ``[sqrt(1 - err_prob) * I, sqrt(err_prob) * sigmaz()]`` on a
        2-level space.
    """
    return [
        jnp.sqrt(1 - err_prob) * identity(2),
        jnp.sqrt(err_prob) * sigmaz(),
    ]
def Dephasing_Ch_Qb(err_prob):
    """Build a single-qubit pure-dephasing channel gate.

    Args:
        err_prob: Error probability forwarded to ``_Pure_Dephasing_Ops_Qb``.

    Returns:
        Gate: a one-mode gate named ``"Dephasing_Ch_Qb"``. Its Kraus map is
        rebuilt from ``err_prob`` each time ``gen_KM`` is called; the
        ``params`` argument is not consulted.
    """

    def _kraus_map(params):
        ops = _Pure_Dephasing_Ops_Qb(err_prob)
        return Qarray.from_list(ops)

    return Gate.create(
        2,
        name="Dephasing_Ch_Qb",
        params={"err_prob": err_prob},
        gen_KM=_kraus_map,
        num_modes=1,
    )