Skip to content

operators

States.

basis(N, k)

Creates a |k> (i.e. Fock state) ket in a specified Hilbert space.

Parameters:

Name Type Description Default
N int

Hilbert space dimension

required
k int

fock number

required

Returns:

Type Description

Fock State |k>

Source code in jaxquantum/core/operators.py (lines 162-172)
def basis(N: int, k: int):
    """Build the Fock ket |k> in an N-dimensional Hilbert space.

    The ket is a column vector of shape (N, 1) produced by ``one_hot(k, N)``.

    Args:
        N: Hilbert space dimension.
        k: Fock number (index of the populated basis element).

    Returns:
        Fock state |k> as a Qarray.
    """
    column = one_hot(k, N).reshape(N, 1)
    return Qarray.create(column)

basis_like(A, ks)

Creates a |k> (i.e. Fock state) ket with the same space dims as A.

Parameters:

Name Type Description Default
A Qarray

state or operator.

required
ks

Fock numbers, one per subsystem of A.

required

Returns:

Type Description
Qarray

Fock State |k> with the same space dims as A.

Source code in jaxquantum/core/operators.py (lines 208-224)
def basis_like(A: Qarray, ks: List[int]) -> Qarray:
    """Creates a product Fock ket with the same space dims as A.

    Each subsystem j of ``A`` (as given by ``A.space_dims``) gets the Fock
    ket |ks[j]>, and the kets are combined with ``tensor``.

    Args:
        A: state or operator whose ``space_dims`` define the subsystems.
        ks: Fock number for each subsystem, in the same order as
            ``A.space_dims``.

    Returns:
        Product Fock state |ks[0]> ⊗ |ks[1]> ⊗ ... with the same space
        dims as A.

    Raises:
        AssertionError: if ``len(ks) != len(A.space_dims)`` (the check is
            an ``assert``, so it is skipped when Python runs with ``-O``).
    """
    space_dims = A.space_dims
    assert len(space_dims) == len(ks), "len(ks) must be equal to len(space_dims)"

    # Index into space_dims (rather than zip) so a length mismatch still
    # surfaces as an IndexError even when asserts are disabled.
    kets = [basis(space_dims[j], k) for j, k in enumerate(ks)]
    return tensor(*kets)

coherent(N, α)

Coherent state.

Parameters:

Name Type Description Default
N int

Hilbert Space Size.

required
α complex

coherent state amplitude.

required
Returns:

Coherent state |α⟩.

Source code in jaxquantum/core/operators.py (lines 175-185)
def coherent(N: int, α: complex) -> Qarray:
    """Coherent state |α⟩ in an N-dimensional Fock space.

    Obtained by applying the displacement operator D(α) to the vacuum |0⟩.

    Args:
        N: Hilbert Space Size.
        α: coherent state amplitude.

    Return:
        Coherent state |α⟩.
    """
    D = displace(N, α)
    vacuum = basis(N, 0)
    return D @ vacuum

create(N)

creation operator

Parameters:

Name Type Description Default
N

Hilbert space size

required

Returns:

Type Description
Qarray

creation operator in Hilbert space of size N

Source code in jaxquantum/core/operators.py (lines 98-107)
def create(N) -> Qarray:
    """Creation operator a† in a Hilbert space of size N.

    Args:
        N: Hilbert space size.

    Returns:
        creation operator in Hilbert space of size N, with sqrt(n) for
        n = 1..N-1 on the first sub-diagonal.
    """
    # k=-1 places the amplitudes just below the main diagonal.
    amplitudes = jnp.sqrt(jnp.arange(1, N))
    matrix = jnp.diag(amplitudes, k=-1)
    return Qarray.create(matrix)

destroy(N)

annihilation operator

Parameters:

Name Type Description Default
N

Hilbert space size

required

Returns:

Type Description
Qarray

annihilation operator in Hilbert space of size N

Source code in jaxquantum/core/operators.py (lines 86-95)
def destroy(N) -> Qarray:
    """Annihilation operator a in a Hilbert space of size N.

    Args:
        N: Hilbert space size.

    Returns:
        annihilation operator in Hilbert space of size N, with sqrt(n) for
        n = 1..N-1 on the first super-diagonal.
    """
    # k=1 places the amplitudes just above the main diagonal.
    amplitudes = jnp.sqrt(jnp.arange(1, N))
    matrix = jnp.diag(amplitudes, k=1)
    return Qarray.create(matrix)

displace(N, α)

Displacement operator

Parameters:

Name Type Description Default
N

Hilbert Space Size

required
α

Phase space displacement

required

Returns:

Type Description
Qarray

Displacement operator D(α)

Source code in jaxquantum/core/operators.py (lines 145-156)
def displace(N, α) -> Qarray:
    """Displacement operator D(α) = exp(α a† − α* a).

    The exponential is taken of the N x N truncated generator.

    Args:
        N: Hilbert Space Size.
        α: Phase space displacement.

    Returns:
        Displacement operator D(α).
    """
    lowering = destroy(N)
    generator = α * lowering.dag() - jnp.conj(α) * lowering
    return generator.expm()

hadamard()

H

Returns:

Name Type Description
H Qarray

Hadamard gate

Source code in jaxquantum/core/operators.py (lines 42-48)
def hadamard() -> Qarray:
    """Hadamard gate H = [[1, 1], [1, -1]] / sqrt(2).

    Returns:
        H: Hadamard gate
    """
    entries = jnp.array([[1, 1], [1, -1]])
    norm = jnp.sqrt(2)
    return Qarray.create(entries / norm)

identity(*args, **kwargs)

Identity matrix.

Returns:

Type Description
Qarray

Identity matrix.

Source code in jaxquantum/core/operators.py (lines 122-128)
def identity(*args, **kwargs) -> Qarray:
    """Identity matrix.

    All positional and keyword arguments are forwarded unchanged to
    ``jnp.eye`` (e.g. ``identity(N)`` yields an N x N identity).

    Returns:
        Identity matrix.
    """
    eye = jnp.eye(*args, **kwargs)
    return Qarray.create(eye)

identity_like(A)

Identity matrix with the same shape as A.

Parameters:

Name Type Description Default
A

Matrix.

required

Returns:

Type Description
Qarray

Identity matrix with the same shape as A.

Source code in jaxquantum/core/operators.py (lines 131-142)
def identity_like(A) -> Qarray:
    """Identity matrix with the same shape as A.

    The matrix size is the product of ``A.space_dims``, and the result
    carries ``dims=[A.space_dims, A.space_dims]``.

    Args:
        A: Matrix.

    Returns:
        Identity matrix with the same shape as A.
    """
    dims = A.space_dims
    size = prod(dims)
    eye = jnp.eye(size, size)
    return Qarray.create(eye, dims=[dims, dims])

num(N)

Number operator

Parameters:

Name Type Description Default
N

Hilbert Space size

required

Returns:

Type Description
Qarray

number operator in Hilbert space of size N

Source code in jaxquantum/core/operators.py (lines 110-119)
def num(N) -> Qarray:
    """Number operator diag(0, 1, ..., N-1).

    Args:
        N: Hilbert Space size.

    Returns:
        number operator in Hilbert space of size N
    """
    # NOTE(review): jnp.arange(N) with an int N yields an integer array;
    # confirm Qarray.create promotes the dtype as intended.
    levels = jnp.arange(N)
    return Qarray.create(jnp.diag(levels))

qubit_rotation(theta, nx, ny, nz)

Single qubit rotation.

Parameters:

Name Type Description Default
theta float

rotation angle.

required
nx

rotation axis x component.

required
ny

rotation axis y component.

required
nz

rotation axis z component.

required

Returns:

Type Description
Qarray

Single qubit rotation operator.

Source code in jaxquantum/core/operators.py (lines 69-83)
def qubit_rotation(theta: float, nx, ny, nz) -> Qarray:
    """Single qubit rotation cos(θ/2) I − i sin(θ/2) (n·σ).

    The axis (nx, ny, nz) is not normalized here; the result equals
    exp(-i θ n·σ / 2) only when the axis has unit length.

    Args:
        theta: rotation angle.
        nx: rotation axis x component.
        ny: rotation axis y component.
        nz: rotation axis z component.

    Returns:
        Single qubit rotation operator.
    """
    half_angle = theta / 2
    axis_dot_sigma = nx * sigmax() + ny * sigmay() + nz * sigmaz()
    return jnp.cos(half_angle) * identity(2) - 1j * jnp.sin(half_angle) * axis_dot_sigma

sigmam()

σ-

Returns:

Type Description
Qarray

σ- Pauli Operator

Source code in jaxquantum/core/operators.py (lines 51-57)
def sigmam() -> Qarray:
    """σ- operator, matrix [[0, 0], [1, 0]].

    Returns:
        σ- Pauli Operator
    """
    entries = [
        [0.0, 0.0],
        [1.0, 0.0],
    ]
    return Qarray.create(jnp.array(entries))

sigmap()

σ+

Returns:

Type Description
Qarray

σ+ Pauli Operator

Source code in jaxquantum/core/operators.py (lines 60-66)
def sigmap() -> Qarray:
    """σ+ operator, matrix [[0, 1], [0, 0]].

    Returns:
        σ+ Pauli Operator
    """
    entries = [
        [0.0, 1.0],
        [0.0, 0.0],
    ]
    return Qarray.create(jnp.array(entries))

sigmax()

σx

Returns:

Type Description
Qarray

σx Pauli Operator

Source code in jaxquantum/core/operators.py (lines 15-21)
def sigmax() -> Qarray:
    """Pauli X operator, matrix [[0, 1], [1, 0]].

    Returns:
        σx Pauli Operator
    """
    entries = [
        [0.0, 1.0],
        [1.0, 0.0],
    ]
    return Qarray.create(jnp.array(entries))

sigmay()

σy

Returns:

Type Description
Qarray

σy Pauli Operator

Source code in jaxquantum/core/operators.py (lines 24-30)
def sigmay() -> Qarray:
    """Pauli Y operator, matrix [[0, -i], [i, 0]].

    Returns:
        σy Pauli Operator
    """
    entries = [
        [0.0, -1.0j],
        [1.0j, 0.0],
    ]
    return Qarray.create(jnp.array(entries))

sigmaz()

σz

Returns:

Type Description
Qarray

σz Pauli Operator

Source code in jaxquantum/core/operators.py (lines 33-39)
def sigmaz() -> Qarray:
    """Pauli Z operator, matrix [[1, 0], [0, -1]].

    Returns:
        σz Pauli Operator
    """
    entries = [
        [1.0, 0.0],
        [0.0, -1.0],
    ]
    return Qarray.create(jnp.array(entries))

thermal(N, beta)

Thermal state.

Parameters:

Name Type Description Default
N int

Hilbert Space Size.

required
beta float

thermal state inverse temperature.

required
Returns:

Thermal state.

Source code in jaxquantum/core/operators.py (lines 188-205)
def thermal(N: int, beta: float) -> Qarray:
    """Thermal state.

    Builds the diagonal density matrix with Boltzmann weights
    exp(-beta * n) for n = 0, ..., N-1 and normalizes it with ``.unit()``.
    When ``beta`` is +inf, the vacuum projector |0><0| is used instead.

    Args:
        N: Hilbert Space Size.
        beta: thermal state inverse temperature.

    Return:
        Thermal state.
    """
    is_zero_temperature = jnp.isposinf(beta)

    # "Double where": feed a finite beta into the Boltzmann branch when it
    # will be discarded. Without this, beta = +inf gives -inf * 0 = nan in
    # the unused branch. That can trigger jax_debug_nans when run eagerly,
    # and it poisons gradients w.r.t. beta (0 * nan = nan in the VJP), even
    # though the forward value is correct.
    safe_beta = jnp.where(is_zero_temperature, 0.0, beta)
    levels = jnp.linspace(0, N - 1, N)
    boltzmann = jnp.diag(jnp.exp(-safe_beta * levels))

    return Qarray.create(
        jnp.where(
            is_zero_temperature,
            basis(N, 0).to_dm().data,
            boltzmann,
        )
    ).unit()