Coverage for jaxquantum/circuits/library/qubit.py: 0%
72 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""qubit gates."""
3from jaxquantum.core.operators import (
4 identity,
5 sigmax,
6 sigmay,
7 sigmaz,
8 basis,
9 hadamard,
10 qubit_rotation,
11)
12from jaxquantum.circuits.gates import Gate
13from jaxquantum.core.qarray import Qarray
14import jax.numpy as jnp
def X():
    """Create the parameter-free single-mode qubit gate named "X".

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        unitary generator that returns ``sigmax()``.
    """

    def _unitary(params):
        # ``params`` is ignored: the gate has no tunable parameters.
        return sigmax()

    return Gate.create(2, name="X", gen_U=_unitary, num_modes=1)
def Y():
    """Create the parameter-free single-mode qubit gate named "Y".

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        unitary generator that returns ``sigmay()``.
    """

    def _unitary(params):
        # ``params`` is ignored: the gate has no tunable parameters.
        return sigmay()

    return Gate.create(2, name="Y", gen_U=_unitary, num_modes=1)
def Z():
    """Create the parameter-free single-mode qubit gate named "Z".

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        unitary generator that returns ``sigmaz()``.
    """

    def _unitary(params):
        # ``params`` is ignored: the gate has no tunable parameters.
        return sigmaz()

    return Gate.create(2, name="Z", gen_U=_unitary, num_modes=1)
def H():
    """Create the parameter-free single-mode qubit gate named "H".

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        unitary generator that returns ``hadamard()``.
    """

    def _unitary(params):
        # ``params`` is ignored: the gate has no tunable parameters.
        return hadamard()

    return Gate.create(2, name="H", gen_U=_unitary, num_modes=1)
def Rx(theta):
    """Create the single-mode qubit gate named "Rx", parameterized by ``theta``.

    Args:
        theta: Value stored in the gate parameters under ``"theta"`` and
            passed as the first argument of ``qubit_rotation``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, whose
        unitary generator evaluates ``qubit_rotation(theta, 1, 0, 0)``.
    """
    gate_params = {"theta": theta}

    def _unitary(params):
        # NOTE(review): (1, 0, 0) presumably selects the x axis; confirm
        # against the ``qubit_rotation`` signature.
        angle = params["theta"]
        return qubit_rotation(angle, 1, 0, 0)

    return Gate.create(
        2, name="Rx", params=gate_params, gen_U=_unitary, num_modes=1
    )
def Ry(theta):
    """Create the single-mode qubit gate named "Ry", parameterized by ``theta``.

    Args:
        theta: Value stored in the gate parameters under ``"theta"`` and
            passed as the first argument of ``qubit_rotation``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, whose
        unitary generator evaluates ``qubit_rotation(theta, 0, 1, 0)``.
    """
    gate_params = {"theta": theta}

    def _unitary(params):
        # NOTE(review): (0, 1, 0) presumably selects the y axis; confirm
        # against the ``qubit_rotation`` signature.
        angle = params["theta"]
        return qubit_rotation(angle, 0, 1, 0)

    return Gate.create(
        2, name="Ry", params=gate_params, gen_U=_unitary, num_modes=1
    )
def Rz(theta):
    """Create the single-mode qubit gate named "Rz", parameterized by ``theta``.

    Args:
        theta: Value stored in the gate parameters under ``"theta"`` and
            passed as the first argument of ``qubit_rotation``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, whose
        unitary generator evaluates ``qubit_rotation(theta, 0, 0, 1)``.
    """
    gate_params = {"theta": theta}

    def _unitary(params):
        # NOTE(review): (0, 0, 1) presumably selects the z axis; confirm
        # against the ``qubit_rotation`` signature.
        angle = params["theta"]
        return qubit_rotation(angle, 0, 0, 1)

    return Gate.create(
        2, name="Rz", params=gate_params, gen_U=_unitary, num_modes=1
    )
def MZ():
    """Create the single-mode gate named "MZ" from two basis-state projectors.

    The Kraus map holds, in order, ``|0><0|`` and ``|1><1|`` built from
    ``basis(2, 0)`` and ``basis(2, 1)``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        Kraus-map generator that returns the two projectors.
    """
    basis_kets = (basis(2, 0), basis(2, 1))
    projectors = [ket @ ket.dag() for ket in basis_kets]
    kraus_ops = Qarray.from_list(projectors)

    def _kraus_map(params):
        # The map is fixed; ``params`` is ignored.
        return kraus_ops

    return Gate.create(2, name="MZ", gen_KM=_kraus_map, num_modes=1)
def MX():
    """Create the single-mode gate named "MX" from two superposition projectors.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the Kraus map holds, in
    order, the projectors onto ``(g + e).unit()`` and ``(g - e).unit()``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        Kraus-map generator that returns the two projectors.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    superpositions = (
        (ground + excited).unit(),
        (ground - excited).unit(),
    )
    projectors = [state @ state.dag() for state in superpositions]
    kraus_ops = Qarray.from_list(projectors)

    def _kraus_map(params):
        # The map is fixed; ``params`` is ignored.
        return kraus_ops

    return Gate.create(2, name="MX", gen_KM=_kraus_map, num_modes=1)
def MX_plus():
    """Create the single-mode gate named "MXplus" with one Kraus operator.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the single operator is
    twice the projector onto ``(g + e).unit()``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        Kraus-map generator that returns the one-element map.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)
    plus_state = (ground + excited).unit()

    # NOTE(review): the factor of 2 is not explained in this file; confirm
    # the intended normalization with the consumer of this gate.
    scaled_projector = 2 * (plus_state @ plus_state.dag())
    kraus_ops = Qarray.from_list([scaled_projector])

    def _kraus_map(params):
        # The map is fixed; ``params`` is ignored.
        return kraus_ops

    return Gate.create(2, name="MXplus", gen_KM=_kraus_map, num_modes=1)
def MZ_plus():
    """Create the single-mode gate named "MZplus" with one Kraus operator.

    The single operator is twice the projector onto ``basis(2, 0)``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        Kraus-map generator that returns the one-element map.
    """
    ground = basis(2, 0)

    # NOTE(review): the factor of 2 is not explained in this file; confirm
    # the intended normalization with the consumer of this gate.
    scaled_projector = 2 * (ground @ ground.dag())
    kraus_ops = Qarray.from_list([scaled_projector])

    def _kraus_map(params):
        # The map is fixed; ``params`` is ignored.
        return kraus_ops

    return Gate.create(2, name="MZplus", gen_KM=_kraus_map, num_modes=1)
def Reset():
    """Create the single-mode gate named "Reset" from two Kraus operators.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the Kraus map holds, in
    order, ``|g><g|`` and ``|g><e|``, so both operators output ``g``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with a
        Kraus-map generator that returns the two operators.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    keep_ground = ground @ ground.dag()
    excited_to_ground = ground @ excited.dag()
    kraus_ops = Qarray.from_list([keep_ground, excited_to_ground])

    def _kraus_map(params):
        # The map is fixed; ``params`` is ignored.
        return kraus_ops

    return Gate.create(2, name="Reset", gen_KM=_kraus_map, num_modes=1)
def IP_Reset(p_eg, p_ee):
    """Create the single-mode gate named "IP_Reset" from two weighted Kraus operators.

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the Kraus map holds, in
    order:

        K0 = sqrt(1 - p_eg) |g><g| + sqrt(p_eg) |e><g|
        K1 = sqrt(p_ee) |e><e| + sqrt(1 - p_ee) |g><e|

    NOTE(review): "IP" presumably stands for "imperfect"; confirm with the
    module authors.

    Args:
        p_eg: Weight of the ``|e><g|`` term in K0; ``|g><g|`` gets ``1 - p_eg``.
        p_ee: Weight of the ``|e><e|`` term in K1; ``|g><e|`` gets ``1 - p_ee``.

    Returns:
        The result of ``Gate.create`` for dimension 2 and one mode, with
        params ``{"p_eg": p_eg, "p_ee": p_ee}`` and a Kraus-map generator that
        returns ``[K0, K1]``.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ge = g @ e.dag()
    eg = e @ g.dag()
    ee = e @ e.dag()

    k_0 = jnp.sqrt(1 - p_eg) * gg + jnp.sqrt(p_eg) * eg
    k_1 = jnp.sqrt(p_ee) * ee + jnp.sqrt(1 - p_ee) * ge

    kmap = Qarray.from_list([k_0, k_1])

    return Gate.create(
        2,
        name="IP_Reset",
        # Fix: this value was previously stored under the key "p_ge", which
        # matched neither the argument name nor the |e><e| term it weights.
        params={"p_eg": p_eg, "p_ee": p_ee},
        # The Kraus map is built once above from the arguments, so the
        # generator ignores ``params``.
        gen_KM=lambda params: kmap,
        num_modes=1,
    )
def CX():
    """Create the two-mode gate named "CX".

    With ``g = basis(2, 0)`` and ``e = basis(2, 1)``, the operator is
    ``(|g><g| ^ identity(2)) + (|e><e| ^ sigmax())``.

    NOTE(review): ``^`` is presumably the tensor product on these operator
    objects, which makes the first mode the control; confirm against the
    ``Qarray`` API.

    Returns:
        The result of ``Gate.create`` for dimensions ``[2, 2]`` and two
        modes, with a unitary generator that returns the operator above.
    """
    ground = basis(2, 0)
    excited = basis(2, 1)

    proj_ground = ground @ ground.dag()
    proj_excited = excited @ excited.dag()

    idle_branch = proj_ground ^ identity(2)
    flip_branch = proj_excited ^ sigmax()
    controlled_x = idle_branch + flip_branch

    def _unitary(params):
        # The operator is fixed; ``params`` is ignored.
        return controlled_x

    return Gate.create([2, 2], name="CX", gen_U=_unitary, num_modes=2)