Coverage for jaxquantum/circuits/library/qubit.py: 22%
87 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""qubit gates."""
3from jaxquantum.core.operators import (
4 identity,
5 sigmax,
6 sigmay,
7 sigmaz,
8 basis,
9 hadamard,
10 qubit_rotation,
11)
12from jaxquantum.circuits.gates import Gate
13from jaxquantum.core.qarray import Qarray
14import jax.numpy as jnp
def X():
    """Build the Pauli-X gate on a single two-level mode."""

    def _unitary(params):
        # Pauli-X is parameter-free; ``params`` is accepted only to match the
        # generator calling convention used by ``Gate.create``.
        return sigmax()

    return Gate.create(2, name="X", gen_U=_unitary, num_modes=1)
def Y():
    """Build the Pauli-Y gate on a single two-level mode."""

    def _unitary(params):
        # Pauli-Y is parameter-free; ``params`` is accepted only to match the
        # generator calling convention used by ``Gate.create``.
        return sigmay()

    return Gate.create(2, name="Y", gen_U=_unitary, num_modes=1)
def Z():
    """Build the Pauli-Z gate on a single two-level mode."""

    def _unitary(params):
        # Pauli-Z is parameter-free; ``params`` is accepted only to match the
        # generator calling convention used by ``Gate.create``.
        return sigmaz()

    return Gate.create(2, name="Z", gen_U=_unitary, num_modes=1)
def H():
    """Build the Hadamard gate on a single two-level mode."""

    def _unitary(params):
        # The Hadamard matrix is parameter-free; ``params`` is accepted only to
        # match the generator calling convention used by ``Gate.create``.
        return hadamard()

    return Gate.create(2, name="H", gen_U=_unitary, num_modes=1)
def Rx(theta, ts=None):
    """Build a single-qubit rotation by ``theta`` about the x axis.

    Args:
        theta: Rotation angle, stored in the gate as ``params["theta"]``.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``params["theta"] / (2 * (ts[-1] - ts[0])) * sigmax()``
            is attached. Its integral over ``[ts[0], ts[-1]]`` equals
            ``params["theta"] / 2 * sigmax()``. ``ts[-1]`` must differ from
            ``ts[0]``.

    Returns:
        A one-mode ``Gate`` acting on a two-level system.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as ``gen_U`` does) instead of capturing
        # the construction-time value. The unitary and Hamiltonian generators
        # then cannot disagree when called with a different ``params`` dict.
        # The Hamiltonian is time-independent, so ``t`` is unused.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmax()
        )

    return Gate.create(
        2,
        name="Rx",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 1, 0, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def Ry(theta, ts=None):
    """Build a single-qubit rotation by ``theta`` about the y axis.

    Args:
        theta: Rotation angle, stored in the gate as ``params["theta"]``.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``params["theta"] / (2 * (ts[-1] - ts[0])) * sigmay()``
            is attached. Its integral over ``[ts[0], ts[-1]]`` equals
            ``params["theta"] / 2 * sigmay()``. ``ts[-1]`` must differ from
            ``ts[0]``.

    Returns:
        A one-mode ``Gate`` acting on a two-level system.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as ``gen_U`` does) instead of capturing
        # the construction-time value. The unitary and Hamiltonian generators
        # then cannot disagree when called with a different ``params`` dict.
        # The Hamiltonian is time-independent, so ``t`` is unused.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmay()
        )

    return Gate.create(
        2,
        name="Ry",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 1, 0),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def Rz(theta, ts=None):
    """Build a single-qubit rotation by ``theta`` about the z axis.

    Args:
        theta: Rotation angle, stored in the gate as ``params["theta"]``.
        ts: Optional sequence of time points. When given, a constant
            Hamiltonian ``params["theta"] / (2 * (ts[-1] - ts[0])) * sigmaz()``
            is attached. Its integral over ``[ts[0], ts[-1]]`` equals
            ``params["theta"] / 2 * sigmaz()``. ``ts[-1]`` must differ from
            ``ts[0]``.

    Returns:
        A one-mode ``Gate`` acting on a two-level system.
    """
    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        # Read theta from ``params`` (as ``gen_U`` does) instead of capturing
        # the construction-time value. The unitary and Hamiltonian generators
        # then cannot disagree when called with a different ``params`` dict.
        # The Hamiltonian is time-independent, so ``t`` is unused.
        gen_Ht = lambda params: (
            lambda t: params["theta"] / delta_t / 2 * sigmaz()
        )

    return Gate.create(
        2,
        name="Rz",
        params={"theta": theta},
        gen_U=lambda params: qubit_rotation(params["theta"], 0, 0, 1),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=1,
    )
def MZ():
    """Build a two-outcome Z-basis measurement as a Kraus-map gate.

    The Kraus list holds the projectors onto ``basis(2, 0)`` and
    ``basis(2, 1)``, in that order.
    """
    kets = [basis(2, 0), basis(2, 1)]
    kraus_ops = Qarray.from_list([ket @ ket.dag() for ket in kets])

    return Gate.create(2, name="MZ", gen_KM=lambda params: kraus_ops, num_modes=1)
def MX():
    """Build a two-outcome X-basis measurement as a Kraus-map gate.

    The Kraus list holds the projectors onto the normalized states
    ``(|g> + |e>)`` and ``(|g> - |e>)``, in that order, where ``|g>`` and
    ``|e>`` are ``basis(2, 0)`` and ``basis(2, 1)``.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)
    x_kets = [(ket_g + ket_e).unit(), (ket_g - ket_e).unit()]
    kraus_ops = Qarray.from_list([ket @ ket.dag() for ket in x_kets])

    return Gate.create(2, name="MX", gen_KM=lambda params: kraus_ops, num_modes=1)
def MX_plus():
    """Build a Kraus-map gate that keeps only the ``|+>`` projector.

    ``|+>`` is the normalized ``basis(2, 0) + basis(2, 1)``. The map holds that
    single projector, so on its own it does not sum to the identity.
    """
    plus_ket = (basis(2, 0) + basis(2, 1)).unit()
    kraus_ops = Qarray.from_list([plus_ket @ plus_ket.dag()])

    return Gate.create(2, name="MXplus", gen_KM=lambda params: kraus_ops, num_modes=1)
def MZ_plus():
    """Build a Kraus-map gate that keeps only the ``basis(2, 0)`` projector.

    The map holds that single projector, so on its own it does not sum to the
    identity.
    """
    ground_ket = basis(2, 0)
    kraus_ops = Qarray.from_list([ground_ket @ ground_ket.dag()])

    return Gate.create(2, name="MZplus", gen_KM=lambda params: kraus_ops, num_modes=1)
def Reset():
    """Build a reset-to-``basis(2, 0)`` Kraus-map gate.

    Kraus operators, in order: ``|g><g|`` (leave ``|g>`` in place) and
    ``|g><e|`` (move ``|e>`` to ``|g>``), with ``|g> = basis(2, 0)`` and
    ``|e> = basis(2, 1)``.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)
    bra_g, bra_e = ket_g.dag(), ket_e.dag()
    kraus_ops = Qarray.from_list([ket_g @ bra_g, ket_g @ bra_e])
    return Gate.create(2, name="Reset", gen_KM=lambda params: kraus_ops, num_modes=1)
def IP_Reset(p_eg, p_ee):
    """Build an imperfect qubit reset as a two-operator Kraus-map gate.

    With ``|g> = basis(2, 0)`` and ``|e> = basis(2, 1)``:

    - ``k_0 = sqrt(1 - p_eg) |g><g| + sqrt(p_eg) |e><g|`` acts only on ``|g>``.
    - ``k_1 = sqrt(p_ee) |e><e| + sqrt(1 - p_ee) |g><e|`` acts only on ``|e>``.

    ``k_0^dag k_0 + k_1^dag k_1`` equals the identity.

    Args:
        p_eg: Probability that an input ``|g>`` is found in ``|e>`` afterwards.
        p_ee: Probability that an input ``|e>`` is found in ``|e>`` afterwards.

    Returns:
        A one-mode ``Gate`` whose Kraus map is built from
        ``params["p_eg"]`` and ``params["p_ee"]``.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ge = g @ e.dag()
    eg = e @ g.dag()
    ee = e @ e.dag()

    def gen_KM(params):
        # Build from ``params`` so the Kraus map always matches the stored
        # parameters instead of a construction-time snapshot.
        q_eg = params["p_eg"]
        q_ee = params["p_ee"]
        # NOTE(review): each operator maps its input to a coherent
        # superposition of |g> and |e>, not a classical mixture. Confirm this
        # is the intended reset model.
        k_0 = jnp.sqrt(1 - q_eg) * gg + jnp.sqrt(q_eg) * eg
        k_1 = jnp.sqrt(q_ee) * ee + jnp.sqrt(1 - q_ee) * ge
        return Qarray.from_list([k_0, k_1])

    return Gate.create(
        2,
        name="IP_Reset",
        # Key fixed: this value is p_ee (stay in |e>). It was stored as "p_ge",
        # which would denote the complementary probability 1 - p_ee.
        params={"p_eg": p_eg, "p_ee": p_ee},
        gen_KM=gen_KM,
        num_modes=1,
    )
def CX():
    """Build a controlled-X gate on two two-level modes.

    The unitary is ``|g><g| ^ I + |e><e| ^ X``, with ``|g> = basis(2, 0)`` and
    ``|e> = basis(2, 1)``. ``^`` is assumed to be the ``Qarray`` tensor product
    with the left factor on the first mode, making the first mode the control.
    Confirm against ``Qarray.__xor__``.
    """
    ket_g, ket_e = basis(2, 0), basis(2, 1)
    proj_g = ket_g @ ket_g.dag()
    proj_e = ket_e @ ket_e.dag()

    # Identity on the target when the control is |g>, X when it is |e>.
    unitary = (proj_g ^ identity(2)) + (proj_e ^ sigmax())

    return Gate.create([2, 2], name="CX", gen_U=lambda params: unitary, num_modes=2)