Coverage for jaxquantum / core / operators.py: 87%
62 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-11 21:51 +0000
"""Common quantum operators and states."""
3from typing import List
4from jax import config
5from math import prod
7import jax.numpy as jnp
8from jax.nn import one_hot
10from jaxquantum.core.qarray import Qarray, tensor
# Turn on 64-bit (double precision / int64) support in JAX at import time.
# This is a global JAX configuration flag, so importing this module changes
# the default precision for every JAX computation in the process, not just
# the operators defined here.
config.update("jax_enable_x64", True)
def sigmax() -> Qarray:
    """Pauli X operator.

    Returns:
        Qarray holding the 2x2 matrix σx = [[0, 1], [1, 0]].
    """
    matrix = jnp.array([[0.0, 1.0],
                        [1.0, 0.0]])
    return Qarray.create(matrix)
def sigmay() -> Qarray:
    """Pauli Y operator.

    Returns:
        Qarray holding the complex 2x2 matrix σy = [[0, -i], [i, 0]].
    """
    entries = [[0.0, -1.0j],
               [1.0j, 0.0]]
    return Qarray.create(jnp.array(entries))
def sigmaz() -> Qarray:
    """Pauli Z operator.

    Returns:
        Qarray holding the 2x2 matrix σz = [[1, 0], [0, -1]].
    """
    entries = [[1.0, 0.0],
               [0.0, -1.0]]
    return Qarray.create(jnp.array(entries))
def hadamard() -> Qarray:
    """Hadamard gate.

    Returns:
        Qarray holding H = (1/√2) · [[1, 1], [1, -1]].
    """
    unnormalized = jnp.array([[1, 1],
                              [1, -1]])
    return Qarray.create(unnormalized / jnp.sqrt(2))
def sigmam() -> Qarray:
    """σ- (sigma-minus) operator.

    Returns:
        Qarray holding the 2x2 matrix [[0, 0], [1, 0]].
    """
    entries = [[0.0, 0.0],
               [1.0, 0.0]]
    return Qarray.create(jnp.array(entries))
def sigmap() -> Qarray:
    """σ+ (sigma-plus) operator.

    Returns:
        Qarray holding the 2x2 matrix [[0, 1], [0, 0]].
    """
    entries = [[0.0, 1.0],
               [0.0, 0.0]]
    return Qarray.create(jnp.array(entries))
def qubit_rotation(theta: float, nx, ny, nz) -> Qarray:
    """Single-qubit rotation by ``theta`` about the axis (nx, ny, nz).

    Evaluates cos(θ/2)·I − i·sin(θ/2)·(nx·σx + ny·σy + nz·σz).
    The axis components are used exactly as given; they are not normalized
    here, so pass a unit vector to obtain the standard rotation exp(−iθ n·σ/2).

    Args:
        theta: rotation angle.
        nx: x component of the rotation axis.
        ny: y component of the rotation axis.
        nz: z component of the rotation axis.

    Returns:
        2x2 rotation operator as a Qarray.
    """
    half_angle = theta / 2
    generator = nx * sigmax() + ny * sigmay() + nz * sigmaz()
    return jnp.cos(half_angle) * identity(2) - 1j * jnp.sin(half_angle) * generator
def destroy(N) -> Qarray:
    """Annihilation operator.

    Args:
        N: Hilbert space size (number of Fock levels kept).

    Returns:
        N x N annihilation operator whose first superdiagonal holds
        sqrt(1), sqrt(2), ..., sqrt(N-1).
    """
    superdiagonal = jnp.sqrt(jnp.arange(1, N))
    return Qarray.create(jnp.diag(superdiagonal, k=1))
def create(N) -> Qarray:
    """Creation operator.

    Args:
        N: Hilbert space size (number of Fock levels kept).

    Returns:
        N x N creation operator whose first subdiagonal holds
        sqrt(1), sqrt(2), ..., sqrt(N-1).
    """
    subdiagonal = jnp.sqrt(jnp.arange(1, N))
    return Qarray.create(jnp.diag(subdiagonal, k=-1))
def num(N) -> Qarray:
    """Number operator.

    Args:
        N: Hilbert space size (number of Fock levels kept).

    Returns:
        N x N diagonal operator diag(0, 1, ..., N-1).
    """
    levels = jnp.arange(N)
    return Qarray.create(jnp.diag(levels))
def identity(*args, **kwargs) -> Qarray:
    """Identity matrix.

    All positional and keyword arguments are forwarded unchanged to
    ``jax.numpy.eye`` (e.g. the size, an optional column count, ``dtype``).

    Returns:
        The resulting identity matrix wrapped in a Qarray.
    """
    eye = jnp.eye(*args, **kwargs)
    return Qarray.create(eye)
def identity_like(A) -> Qarray:
    """Identity operator on the same Hilbert space as A.

    The size is the product of ``A.space_dims``, and the result carries
    ``dims=[space_dims, space_dims]`` so its subsystem structure matches A.

    Args:
        A: Qarray (state or operator) whose ``space_dims`` define the space.

    Returns:
        Square identity operator with the same space dims as A.
    """
    dims = A.space_dims
    size = prod(dims)
    return Qarray.create(jnp.eye(size, size), dims=[dims, dims])
def displace(N, α) -> Qarray:
    """Displacement operator D(α) = exp(α a† − α* a).

    The exponential is taken of the N-dimensional truncated generator.

    Args:
        N: Hilbert space size.
        α: phase-space displacement.

    Returns:
        Displacement operator D(α).
    """
    annihilation = destroy(N)
    generator = α * annihilation.dag() - jnp.conj(α) * annihilation
    return generator.expm()
def squeeze(N, z) -> Qarray:
    """Single-mode squeezing operator S(z) = exp(½(z* a² − z a†²)).

    The exponential is taken of the N-dimensional truncated generator.

    Args:
        N: Hilbert space size.
        z: squeezing parameter.

    Returns:
        Squeezing operator S(z).
    """
    lowering = destroy(N)
    raising = lowering.dag()
    exponent = (
        0.5 * jnp.conj(z) * (lowering @ lowering)
        - 0.5 * z * (raising @ raising)
    )
    return exponent.expm()
def squeezing_linear_to_dB(z):
    """Convert a squeezing parameter to decibels.

    Mathematically this is 20·log10(e^|z|). It is evaluated in the
    equivalent closed form 20·|z| / ln(10) so that large |z| cannot overflow
    ``exp`` to ``inf`` (which happens near |z| ≈ 89 in float32 and
    |z| ≈ 709 in float64).

    Args:
        z: squeezing parameter (real or complex); only its magnitude is used.

    Returns:
        Squeezing level in dB.
    """
    return 20 * jnp.abs(z) / jnp.log(10.0)
def squeezing_dB_to_linear(z_dB):
    """Convert a squeezing level in decibels to the squeezing parameter.

    Inverse of ``squeezing_linear_to_dB`` (up to sign). Mathematically
    ln(10^(z_dB/20)); it is evaluated as z_dB·ln(10)/20 so that large inputs
    neither raise ``OverflowError`` (Python floats) nor overflow to ``inf``
    (JAX arrays) in the intermediate power.

    Args:
        z_dB: squeezing level in dB.

    Returns:
        Squeezing parameter magnitude.
    """
    return z_dB * jnp.log(10.0) / 20
181# States ---------------------------------------------------------------------
def basis(N: int, k: int) -> Qarray:
    """Fock-state ket |k> in an N-dimensional Hilbert space.

    NOTE: the vector comes from ``jax.nn.one_hot``, which yields an all-zero
    row for an index outside [0, N). An out-of-range ``k`` therefore
    produces a zero ket rather than raising.

    Args:
        N: Hilbert space dimension.
        k: Fock number.

    Returns:
        Column vector (shape (N, 1)) with a 1 in row k.
    """
    column = one_hot(k, N).reshape(N, 1)
    return Qarray.create(column)
def multi_mode_basis_set(Ns: List[int]) -> Qarray:
    """All computational basis kets of a multi-mode space, as one batch.

    The data is the prod(Ns) x prod(Ns) identity with ket dims
    ``(tuple(Ns), (1, ..., 1))`` and batch dims ``(prod(Ns),)``.
    Presumably Qarray.create treats the leading axis as the batch axis, so
    each batch element is one basis ket (TODO: confirm against Qarray.create).

    Args:
        Ns: Hilbert space dimension of each mode.

    Returns:
        Batched Qarray of basis kets.
    """
    total = prod(Ns)
    ket_dims = tuple(Ns)
    unit_dims = tuple(1 for _ in ket_dims)
    return Qarray.create(
        jnp.eye(total), dims=(ket_dims, unit_dims), bdims=(total,)
    )
def coherent(N: int, α: complex) -> Qarray:
    """Coherent state |α⟩ in an N-dimensional truncated space.

    Built as D(α)|0⟩ using this module's ``displace`` (which exponentiates
    the truncated generator) applied to the vacuum ``basis(N, 0)``.

    Args:
        N: Hilbert space size.
        α: coherent-state amplitude.

    Returns:
        Coherent-state ket |α⟩.
    """
    vacuum = basis(N, 0)
    return displace(N, α) @ vacuum
def thermal_dm(N: int, n: float) -> Qarray:
    """Thermal density matrix with mean photon number ``n``.

    Populations are proportional to exp(−β k) for k = 0 .. N−1 with
    β = ln(1 + 1/n). For n == 0, β is +inf and the vacuum projector |0⟩⟨0|
    is returned instead. The result is passed through ``.unit()``
    (presumably normalizing to unit trace; confirm in Qarray).

    Args:
        N: Hilbert space size.
        n: average photon number (n >= 0).

    Returns:
        Thermal-state density matrix.
    """
    # Convert to a JAX array first. With a plain Python 0 / 0.0, ``1 / n``
    # would raise ZeroDivisionError, and the n == 0 (vacuum) branch below
    # would be unreachable. JAX division instead yields inf, so beta == +inf.
    n = jnp.asarray(n)
    beta = jnp.log(1 + 1 / n)

    # Both branches of jnp.where are evaluated. For beta == inf, the first
    # population is exp(nan), but that branch is not selected.
    populations = jnp.exp(-beta * jnp.linspace(0, N - 1, N))
    return Qarray.create(
        jnp.where(
            jnp.isposinf(beta),
            basis(N, 0).to_dm().data,
            jnp.diag(populations),
        )
    ).unit()
def basis_like(A: Qarray, ks: List[int]) -> Qarray:
    """Product Fock-state ket |k1, k2, ...> with the same space dims as A.

    Args:
        A: state or operator whose ``space_dims`` give one dimension per
            subsystem.
        ks: Fock number for each subsystem, in the same order as
            ``A.space_dims``.

    Returns:
        Tensor product of ``basis(dim_j, k_j)`` over all subsystems.

    Raises:
        AssertionError: if ``len(ks)`` differs from the number of subsystems.
    """
    space_dims = A.space_dims
    if len(space_dims) != len(ks):
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``. The exception type and message are
        # kept unchanged for backward compatibility with existing callers.
        raise AssertionError("len(ks) must be equal to len(space_dims)")
    kets = [basis(dim, k) for dim, k in zip(space_dims, ks)]
    return tensor(*kets)