Coverage for jaxquantum / core / operators.py: 92%
102 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
"""Operators and states.

Constructors for common quantum operators (Pauli matrices, Hadamard, ladder
and number operators, identity, displacement and squeezing) and states (Fock,
coherent, thermal), all returned as ``Qarray`` objects.
"""

from math import prod
from typing import List

from jax import config
import jax.numpy as jnp
from jax.nn import one_hot

from jaxquantum.core.qarray import Qarray, tensor, QarrayImplType

# Several constructors below request jnp.float64 explicitly, which requires
# 64-bit mode. NOTE: this is a process-wide, import-time side effect that
# affects all JAX code, not just this module.
config.update("jax_enable_x64", True)
def _make_sparsedia(offsets: tuple, diags: "jnp.ndarray", dims=None) -> Qarray:
    """Wrap pre-padded diagonals in a ``Qarray`` backed by ``SparseDiaImpl``.

    No dense matrix is materialized: the diagonals are handed straight to
    ``SparseDiaImpl.from_diags``. The caller must pad ``diags`` according to
    Convention A, under which the diagonal at offset ``k`` carries leading
    zeros at ``[0:k]`` when ``k >= 0`` and trailing zeros at ``[n+k:]`` when
    ``k < 0``.

    Args:
        offsets: Sorted tuple of integer diagonal offsets.
        diags: JAX array of shape ``(n_diags, n)`` holding the padded
            diagonal values.
        dims: Optional quantum dims forwarded to ``Qarray.create``.

    Returns:
        A ``Qarray`` created with ``QarrayImplType.SPARSE_DIA``.
    """
    # Function-scope import; presumably avoids an import cycle with the
    # sparse_dia module (TODO confirm).
    from jaxquantum.core.sparse_dia import SparseDiaImpl

    sparse_data = SparseDiaImpl.from_diags(offsets=offsets, diags=diags).get_data()
    return Qarray.create(sparse_data, dims=dims, implementation=QarrayImplType.SPARSE_DIA)
def sigmax(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli X operator, σx = [[0, 1], [1, 0]].

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonals directly; any other value goes through ``Qarray.create``.

    Returns:
        σx Pauli operator.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[0.0, 1.0], [1.0, 0.0]]), implementation=implementation)
    # Convention A padding for n = 2:
    lower = [1.0, 0.0]  # offset -1: A[1, 0] = 1, then one trailing zero
    upper = [0.0, 1.0]  # offset +1: one leading zero, then A[0, 1] = 1
    return _make_sparsedia(offsets=(-1, 1), diags=jnp.array([lower, upper]))
def sigmay(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli Y operator, σy = [[0, -i], [i, 0]].

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonals directly; any other value goes through ``Qarray.create``.

    Returns:
        σy Pauli operator.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[0.0, -1.0j], [1.0j, 0.0]]), implementation=implementation)
    # Convention A padding for n = 2:
    lower = [1.0j, 0.0]   # offset -1: A[1, 0] = i, then one trailing zero
    upper = [0.0, -1.0j]  # offset +1: one leading zero, then A[0, 1] = -i
    return _make_sparsedia(offsets=(-1, 1), diags=jnp.array([lower, upper]))
def sigmaz(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Pauli Z operator, σz = [[1, 0], [0, -1]].

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the main
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        σz Pauli operator.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[1.0, 0.0], [0.0, -1.0]]), implementation=implementation)
    main = [1.0, -1.0]  # offset 0 needs no padding
    return _make_sparsedia(offsets=(0,), diags=jnp.array([main]))
def hadamard(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Hadamard gate, H = [[1, 1], [1, -1]] / √2.

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonals directly; any other value goes through ``Qarray.create``.

    Returns:
        H: Hadamard gate.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[1, 1], [1, -1]]) / jnp.sqrt(2), implementation=implementation)
    s = 1.0 / jnp.sqrt(2.0)
    # Convention A padding for n = 2:
    lower = [s, 0.0]   # offset -1: A[1, 0] = s, then one trailing zero
    main = [s, -s]     # offset  0: A[0, 0] = s, A[1, 1] = -s
    upper = [0.0, s]   # offset +1: one leading zero, then A[0, 1] = s
    return _make_sparsedia(offsets=(-1, 0, 1), diags=jnp.array([lower, main, upper]))
def sigmam(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Lowering operator σ- = [[0, 0], [1, 0]].

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        σ- Pauli operator.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[0.0, 0.0], [1.0, 0.0]]), implementation=implementation)
    lower = [1.0, 0.0]  # offset -1: A[1, 0] = 1, then one trailing zero
    return _make_sparsedia(offsets=(-1,), diags=jnp.array([lower]))
def sigmap(implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Raising operator σ+ = [[0, 1], [0, 0]].

    Args:
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        σ+ Pauli operator.
    """
    if QarrayImplType(implementation) != QarrayImplType.SPARSE_DIA:
        return Qarray.create(jnp.array([[0.0, 1.0], [0.0, 0.0]]), implementation=implementation)
    upper = [0.0, 1.0]  # offset +1: one leading zero, then A[0, 1] = 1
    return _make_sparsedia(offsets=(1,), diags=jnp.array([upper]))
def qubit_rotation(theta: float, nx, ny, nz) -> Qarray:
    """Single qubit rotation R_n(θ) = cos(θ/2) I - i sin(θ/2) (n · σ).

    The axis components are used as given; no normalization of
    (nx, ny, nz) is performed.

    Args:
        theta: rotation angle.
        nx: rotation axis x component.
        ny: rotation axis y component.
        nz: rotation axis z component.

    Returns:
        Single qubit rotation operator (dense).
    """
    half_angle = theta / 2
    n_dot_sigma = nx * sigmax() + ny * sigmay() + nz * sigmaz()
    return jnp.cos(half_angle) * identity(2) - 1j * jnp.sin(half_angle) * n_dot_sigma
def destroy(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Annihilation (lowering) operator a.

    The result is an N×N matrix whose only nonzero entries lie on the first
    superdiagonal: a[n-1, n] = √n for n = 1, ..., N-1.

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        Annihilation operator in a Hilbert space of size N.
    """
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Single superdiagonal at offset +1; Convention A: 1 leading zero,
        # then diag[n] = A[n-1, n] = √n.
        diags = jnp.zeros((1, N), dtype=jnp.float64)
        diags = diags.at[0, 1:].set(jnp.sqrt(jnp.arange(1, N, dtype=jnp.float64)))
        return _make_sparsedia(offsets=(1,), diags=diags)
    return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=1), implementation=implementation)
def create(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Creation (raising) operator a†.

    The result is an N×N matrix whose only nonzero entries lie on the first
    subdiagonal: a†[n+1, n] = √(n+1) for n = 0, ..., N-2.

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the padded
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        Creation operator in a Hilbert space of size N.
    """
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Single subdiagonal at offset -1; Convention A: 1 trailing zero,
        # with diag[n] = A[n+1, n] = √(n+1).
        diags = jnp.zeros((1, N), dtype=jnp.float64)
        diags = diags.at[0, :N - 1].set(jnp.sqrt(jnp.arange(1, N, dtype=jnp.float64)))
        return _make_sparsedia(offsets=(-1,), diags=diags)
    return Qarray.create(jnp.diag(jnp.sqrt(jnp.arange(1, N)), k=-1), implementation=implementation)
def num(N, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Number operator a†a.

    The result is diagonal with entries 0, 1, ..., N-1.

    Args:
        N: Hilbert space size.
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the main
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        Number operator in a Hilbert space of size N.
    """
    # Explicit float64 in both paths. Previously the dense path used
    # jnp.arange(N) (integer dtype), which was inconsistent with the float64
    # SPARSE_DIA path and with the other ladder operators.
    levels = jnp.arange(N, dtype=jnp.float64)
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Main diagonal only; no leading/trailing zeros needed (offset 0).
        return _make_sparsedia(offsets=(0,), diags=levels.reshape(1, N))
    return Qarray.create(jnp.diag(levels), implementation=implementation)
def identity(*args, implementation: QarrayImplType = QarrayImplType.DENSE, **kwargs) -> Qarray:
    """Identity matrix.

    Thin wrapper around ``jnp.eye``: ``args`` and ``kwargs`` are forwarded to
    it unchanged on the general path.

    Args:
        *args: Positional arguments for ``jnp.eye`` (typically ``N`` or
            ``N, N``).
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). For ``SPARSE_DIA`` with a plain
            square request (``N``, ``N, N`` or ``N=N``), the main diagonal is
            built directly without a dense intermediate.
        **kwargs: Keyword arguments for ``jnp.eye``.

    Returns:
        Identity matrix.
    """
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Recognize the plain square-identity call forms. Anything else
        # (non-square M, offset k, custom dtype, ...) takes the dense path.
        n = None
        if not kwargs and len(args) == 1:
            n = args[0]                                  # eye(N)
        elif not kwargs and len(args) == 2 and args[0] == args[1]:
            n = args[0]                                  # eye(N, N)
        elif not args and set(kwargs) == {"N"}:
            n = kwargs["N"]                              # eye(N=N)
        if n is not None:
            diags = jnp.ones((1, int(n)), dtype=jnp.float64)
            return _make_sparsedia(offsets=(0,), diags=diags)
    return Qarray.create(jnp.eye(*args, **kwargs), implementation=implementation)


qeye = identity
def identity_like(A, implementation: QarrayImplType = QarrayImplType.DENSE) -> Qarray:
    """Identity operator on the same Hilbert space as ``A``.

    Only ``A.space_dims`` is used. The result is always a square operator with
    dims ``[space_dims, space_dims]``, so it need not match ``A``'s shape, and
    any batch dimensions of ``A`` are not reproduced.

    Args:
        A: Qarray (state or operator) whose ``space_dims`` define the space.
        implementation: Qarray implementation type (a ``QarrayImplType`` or a
            value its constructor accepts). ``SPARSE_DIA`` builds the main
            diagonal directly; any other value goes through ``Qarray.create``.

    Returns:
        Identity of size ``prod(A.space_dims)`` with dims
        ``[space_dims, space_dims]``.
    """
    space_dims = A.space_dims
    total_dim = prod(space_dims)
    dims = [space_dims, space_dims]
    if QarrayImplType(implementation) == QarrayImplType.SPARSE_DIA:
        # Same fast path as the other SPARSE_DIA constructors: a single main
        # diagonal of ones, without materializing eye(total_dim) first.
        diags = jnp.ones((1, total_dim), dtype=jnp.float64)
        return _make_sparsedia(offsets=(0,), diags=diags, dims=dims)
    return Qarray.create(jnp.eye(total_dim, total_dim), dims=dims, implementation=implementation)
def displace(N, α) -> Qarray:
    """Displacement operator D(α) = exp(α a† - α* a).

    Args:
        N: Hilbert space size.
        α: Phase space displacement.

    Returns:
        Displacement operator D(α), computed as a matrix exponential of the
        truncated generator.
    """
    a = destroy(N)
    generator = α * a.dag() - jnp.conj(α) * a
    return generator.expm()
def squeeze(N, z) -> Qarray:
    """Single-mode squeezing operator S(z) = exp[(z* a² - z a†²) / 2].

    Args:
        N: Hilbert space size.
        z: squeezing parameter (complex in general).

    Returns:
        Squeezing operator, computed as a matrix exponential of the truncated
        generator.
    """
    a = destroy(N)
    a_dag = a.dag()
    generator = (1 / 2.0) * jnp.conj(z) * (a @ a) - (1 / 2.0) * z * (a_dag @ a_dag)
    return generator.expm()
def squeezing_linear_to_dB(z):
    """Convert a squeezing parameter to decibels.

    Computes 20 * log10(exp(|z|)) in the algebraically equal form
    20 * |z| / ln(10). The original form overflowed ``exp(|z|)`` to inf for
    large |z|.

    Args:
        z: squeezing parameter (real or complex; only its magnitude is used).

    Returns:
        Squeezing in dB (non-negative).
    """
    return 20 * jnp.abs(z) / jnp.log(10.0)
def squeezing_dB_to_linear(z_dB):
    """Convert squeezing in decibels to the squeezing-parameter magnitude.

    Computes ln(10 ** (z_dB / 20)) in the algebraically equal form
    (z_dB / 20) * ln(10). The original form raised OverflowError (Python
    floats) or overflowed to inf for large dB values.

    Args:
        z_dB: squeezing in dB.

    Returns:
        Magnitude |z| of the squeezing parameter.
    """
    return z_dB / 20 * jnp.log(10.0)
258# States ---------------------------------------------------------------------
def basis(N: int, k: int, implementation: QarrayImplType = QarrayImplType.DENSE):
    """Fock-state ket |k> in an N-dimensional Hilbert space.

    Args:
        N: Hilbert space dimension.
        k: Fock number (index of the single nonzero entry).
        implementation: Qarray implementation type, forwarded to
            ``Qarray.create``.

    Returns:
        Fock state |k> as an (N, 1) column.
    """
    # NOTE(review): jax.nn.one_hot yields an all-zero vector for k outside
    # [0, N) rather than raising; no range check is done here.
    column = one_hot(k, N).reshape(N, 1)
    return Qarray.create(column, implementation=implementation)
def multi_mode_basis_set(Ns: List[int]) -> Qarray:
    """Batch of all computational basis kets of a multi-mode system.

    Args:
        Ns: List of Hilbert space dimensions for each mode.

    Returns:
        Batched Qarray with ket dims ``(tuple(Ns), (1, ..., 1))`` and batch
        dims ``(prod(Ns),)``.
    """
    total = prod(Ns)
    ket_dims = (tuple(Ns), (1,) * len(Ns))
    # Data is the identity of size prod(Ns); presumably Qarray.create treats
    # the leading axis as the batch axis (bdims), giving one basis ket per
    # row — TODO confirm.
    return Qarray.create(jnp.eye(total), dims=ket_dims, bdims=(total,))
def coherent(N: int, α: complex) -> Qarray:
    """Coherent state |α⟩ = D(α)|0⟩ in a truncated Fock space.

    Args:
        N: Hilbert space size.
        α: coherent state amplitude.

    Returns:
        Coherent state |α⟩.
    """
    vacuum = basis(N, 0)
    return displace(N, α) @ vacuum
def thermal_dm(N: int, n: float) -> Qarray:
    """Thermal state truncated to N Fock levels.

    Builds ρ ∝ Σ_k exp(-β k) |k⟩⟨k| with β = ln(1 + 1/n), which gives mean
    photon number n before truncation. The result is passed through
    ``.unit()`` (presumably normalizing to unit trace). For n == 0 the vacuum
    |0⟩⟨0| is returned.

    Args:
        N: Hilbert space size.
        n: average photon number.

    Returns:
        Thermal state.
    """
    # Promote to a JAX float so that n == 0 gives 1/n = inf (hence β = inf and
    # the vacuum branch below). A Python scalar 0 would otherwise raise
    # ZeroDivisionError before the guard could apply.
    n = jnp.asarray(n, dtype=jnp.float64)
    beta = jnp.log(1 + 1 / n)

    return Qarray.create(
        jnp.where(
            jnp.isposinf(beta),
            basis(N, 0).to_dm().data,
            jnp.diag(jnp.exp(-beta * jnp.linspace(0, N - 1, N))),
        )
    ).unit()
def basis_like(A: Qarray, ks: List[int]) -> Qarray:
    """Product Fock ket |k_0, k_1, ...> with the same space dims as A.

    Args:
        A: state or operator whose ``space_dims`` define the subsystems.
        ks: Fock number for each subsystem, one entry per element of
            ``A.space_dims``.

    Returns:
        Tensor product of ``basis(space_dims[j], ks[j])`` over all subsystems.

    Raises:
        AssertionError: if ``len(ks) != len(A.space_dims)``.
    """
    space_dims = A.space_dims
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # The AssertionError type is kept for backward compatibility.
    if len(space_dims) != len(ks):
        raise AssertionError("len(ks) must be equal to len(space_dims)")

    kets = [basis(dim, k) for dim, k in zip(space_dims, ks)]
    return tensor(*kets)