Coverage for jaxquantum/core/operators.py: 100%
49 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
"""Common quantum operators and states.

Operators: Pauli matrices, Hadamard, single-qubit rotations, ladder/number
operators, identities and the displacement operator. States: Fock basis kets,
coherent states and thermal states. All results are built as ``Qarray``s.
"""
3from typing import List
4from jax import config
5from math import prod
7import jax.numpy as jnp
8from jax.nn import one_hot
10from jaxquantum.core.qarray import Qarray, tensor
# Enable 64-bit (double-precision) float/complex dtypes in JAX. This runs at
# import time and changes JAX's global configuration for the whole process.
config.update("jax_enable_x64", True)
def sigmax() -> Qarray:
    """Pauli X operator σx.

    Returns:
        The 2x2 matrix [[0, 1], [1, 0]] as a Qarray.
    """
    pauli_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
    return Qarray.create(pauli_x)
def sigmay() -> Qarray:
    """Pauli Y operator σy.

    Returns:
        The 2x2 complex matrix [[0, -i], [i, 0]] as a Qarray.
    """
    pauli_y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])
    return Qarray.create(pauli_y)
def sigmaz() -> Qarray:
    """Pauli Z operator σz.

    Returns:
        The 2x2 matrix [[1, 0], [0, -1]] as a Qarray.
    """
    pauli_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    return Qarray.create(pauli_z)
def hadamard() -> Qarray:
    """Hadamard gate H.

    Returns:
        The 2x2 matrix [[1, 1], [1, -1]] / sqrt(2) as a Qarray.
    """
    unnormalized = jnp.array([[1, 1], [1, -1]])
    return Qarray.create(unnormalized / jnp.sqrt(2))
def sigmam() -> Qarray:
    """σ- operator.

    Returns:
        The 2x2 matrix [[0, 0], [1, 0]] as a Qarray.
    """
    sigma_minus = jnp.array([[0.0, 0.0], [1.0, 0.0]])
    return Qarray.create(sigma_minus)
def sigmap() -> Qarray:
    """σ+ operator.

    Returns:
        The 2x2 matrix [[0, 1], [0, 0]] as a Qarray.
    """
    sigma_plus = jnp.array([[0.0, 1.0], [0.0, 0.0]])
    return Qarray.create(sigma_plus)
def qubit_rotation(theta: float, nx, ny, nz) -> Qarray:
    """Single-qubit rotation by ``theta`` about the axis (nx, ny, nz).

    Computes cos(theta/2) * I - i * sin(theta/2) * (nx σx + ny σy + nz σz).

    Args:
        theta: rotation angle.
        nx: x component of the rotation axis.
        ny: y component of the rotation axis.
        nz: z component of the rotation axis.

    Returns:
        The 2x2 rotation operator.

    Note:
        The axis is not normalized here; presumably callers pass a unit
        vector (for a non-unit axis the result is not unitary).
    """
    half_angle = theta / 2
    axis_generator = nx * sigmax() + ny * sigmay() + nz * sigmaz()
    identity_part = jnp.cos(half_angle) * identity(2)
    return identity_part - 1j * jnp.sin(half_angle) * axis_generator
def destroy(N) -> Qarray:
    """Annihilation (lowering) operator a.

    Args:
        N: Hilbert space size (Fock-space truncation).

    Returns:
        N x N annihilation operator in a Hilbert space of size N, with
        sqrt(1), ..., sqrt(N - 1) on the first superdiagonal.
    """
    ladder_amplitudes = jnp.sqrt(jnp.arange(1, N))
    return Qarray.create(jnp.diag(ladder_amplitudes, k=1))
def create(N) -> Qarray:
    """Creation (raising) operator a†.

    Args:
        N: Hilbert space size (Fock-space truncation).

    Returns:
        N x N creation operator in a Hilbert space of size N, with
        sqrt(1), ..., sqrt(N - 1) on the first subdiagonal.
    """
    ladder_amplitudes = jnp.sqrt(jnp.arange(1, N))
    return Qarray.create(jnp.diag(ladder_amplitudes, k=-1))
def num(N) -> Qarray:
    """Number operator.

    Args:
        N: Hilbert space size.

    Returns:
        N x N diagonal operator diag(0, 1, ..., N - 1) in a Hilbert space
        of size N.
    """
    occupations = jnp.arange(N)
    return Qarray.create(jnp.diag(occupations))
def identity(*args, **kwargs) -> Qarray:
    """Identity matrix.

    All positional and keyword arguments are forwarded unchanged to
    ``jnp.eye`` (e.g. ``identity(N)`` builds the N x N identity).

    Returns:
        The resulting identity matrix wrapped in a Qarray.
    """
    eye_matrix = jnp.eye(*args, **kwargs)
    return Qarray.create(eye_matrix)
def identity_like(A) -> Qarray:
    """Identity operator on the same (possibly composite) space as A.

    Args:
        A: Qarray whose ``space_dims`` define the target space.

    Returns:
        Square identity of size prod(A.space_dims), tagged with
        dims=[A.space_dims, A.space_dims].
    """
    dims = A.space_dims
    size = prod(dims)
    eye_matrix = jnp.eye(size, size)
    return Qarray.create(eye_matrix, dims=[dims, dims])
def displace(N, α) -> Qarray:
    """Displacement operator D(α) = exp(α a† - α* a).

    Args:
        N: Hilbert space size.
        α: phase-space displacement.

    Returns:
        The N x N displacement operator D(α).
    """
    lowering = destroy(N)
    generator = α * lowering.dag() - jnp.conj(α) * lowering
    return generator.expm()
# States ---------------------------------------------------------------------


def basis(N: int, k: int):
    """Fock-state ket |k> in a Hilbert space of dimension N.

    Args:
        N: Hilbert space dimension.
        k: Fock number (index of the nonzero entry).

    Returns:
        N x 1 column ket with a 1 at row k. Per ``jax.nn.one_hot``, an
        out-of-range k yields an all-zero column rather than an error.
    """
    column = one_hot(k, N)
    return Qarray.create(column.reshape(N, 1))
def coherent(N: int, α: complex) -> Qarray:
    """Coherent state |α⟩ = D(α)|0⟩.

    Args:
        N: Hilbert space size.
        α: coherent-state amplitude.

    Returns:
        The displaced vacuum ket |α⟩.
    """
    vacuum = basis(N, 0)
    return displace(N, α) @ vacuum
def thermal(N: int, beta: float) -> Qarray:
    """Thermal (Boltzmann) density matrix.

    The diagonal populations are proportional to exp(-beta * n) for
    n = 0, ..., N - 1; the matrix is then normalized via ``Qarray.unit()``.
    For ``beta == +inf`` the ground-state projector |0><0| is selected
    instead, avoiding the 0 * inf term at n = 0.

    Args:
        N: Hilbert space size.
        beta: inverse temperature.

    Returns:
        Thermal state.
    """
    ground_projector = basis(N, 0).to_dm().data
    levels = jnp.linspace(0, N - 1, N)
    boltzmann = jnp.diag(jnp.exp(-beta * levels))
    rho = jnp.where(jnp.isposinf(beta), ground_projector, boltzmann)
    return Qarray.create(rho).unit()
def basis_like(A: Qarray, ks: List[int]) -> Qarray:
    """Creates a product Fock ket |k_1, k_2, ...> with the same space dims as A.

    Args:
        A: state or operator whose ``space_dims`` define the subsystems.
        ks: Fock number for each subsystem, one entry per element of
            ``A.space_dims`` (in the same order).

    Returns:
        Tensor product of ``basis(space_dims[j], ks[j])`` over all subsystems.

    Raises:
        AssertionError: if ``len(ks) != len(A.space_dims)``.
    """
    space_dims = A.space_dims
    # Raised explicitly instead of via ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if len(space_dims) != len(ks):
        raise AssertionError("len(ks) must be equal to len(space_dims)")

    kets = [basis(dim, k) for dim, k in zip(space_dims, ks)]
    return tensor(*kets)