Skip to content

oscillator

Oscillator gates.

Amp_Damp(N, err_prob, max_l)

Amplitude damping channel.

Parameters:

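| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `err_prob` | | Error probability. | *required* |
| `max_l` | | Maximum number of photons lost. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Amplitude damping channel. |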

Source code in jaxquantum/circuits/library/oscillator.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def Amp_Damp(N, err_prob, max_l):
    """Amplitude damping (photon loss) channel.

    The Kraus map holds one operator per number of lost photons, from 0 up
    to and including ``max_l``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons lost.

    Returns:
        Amplitude damping channel.
    """

    def kraus_map(params):
        # One Kraus operator for each loss count n_lost = 0, 1, ..., max_l.
        operators = [
            _Ph_Loss_Kraus_Op(N, err_prob, n_lost) for n_lost in range(max_l + 1)
        ]
        return Qarray.from_list(operators)

    gate_params = {"err_prob": err_prob, "max_l": max_l}
    return Gate.create(
        N,
        name="Amp_Damp",
        params=gate_params,
        gen_KM=kraus_map,
        num_modes=1,
    )

Amp_Gain(N, err_prob, max_l)

Amplitude gain channel.

Parameters:

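| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `err_prob` | | Error probability. | *required* |
| `max_l` | | Maximum number of photons gained. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Amplitude gain channel. |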

Source code in jaxquantum/circuits/library/oscillator.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def Amp_Gain(N, err_prob, max_l):
    """Amplitude gain (photon gain) channel.

    The Kraus map holds one operator per number of gained photons, from 0
    up to and including ``max_l``.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Maximum number of photons gained.

    Returns:
        Amplitude gain channel.
    """

    def kraus_map(params):
        # One Kraus operator for each gain count n_gained = 0, 1, ..., max_l.
        operators = [
            _Ph_Gain_Kraus_Op(N, err_prob, n_gained)
            for n_gained in range(max_l + 1)
        ]
        return Qarray.from_list(operators)

    gate_params = {"err_prob": err_prob, "max_l": max_l}
    return Gate.create(
        N,
        name="Amp_Gain",
        params=gate_params,
        gen_KM=kraus_map,
        num_modes=1,
    )

CD(N, beta, ts=None)

Conditional displacement gate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `beta` | | Conditional displacement amplitude. | *required* |
| `ts` | | Optional time sequence for Hamiltonian simulation. | `None` |

Returns:

| Type | Description |
| --- | --- |
| | Conditional displacement gate. |

Source code in jaxquantum/circuits/library/oscillator.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def CD(N, beta, ts=None):
    """Conditional displacement gate.

    Acts on (qubit, oscillator) with dims ``[2, N]``. The oscillator is
    displaced by ``+beta/2`` when the qubit is in ``|g>`` (``basis(2, 0)``)
    and by ``-beta/2`` when it is in ``|e>`` (``basis(2, 1)``).

    Args:
        N: Hilbert space dimension.
        beta: Conditional displacement amplitude.
        ts: Optional time sequence for Hamiltonian simulation. When given,
            a constant drive Hamiltonian over ``ts[0]`` .. ``ts[-1]`` is
            also attached to the gate.

    Returns:
        Conditional displacement gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    gen_Ht = None
    if ts is not None:
        delta_t = ts[-1] - ts[0]
        a = destroy(N)

        def _gen_Ht(params):
            # Constant drive amplitude. Assuming evolution under exp(-i H t),
            # driving for delta_t with amp*a^dag + conj(amp)*a gives
            # displace(N, beta/2); the |e> branch uses -amp.
            amp = 1j * params["beta"] / delta_t / 2
            H_g = jnp.conj(amp) * a + amp * a.dag()
            H_e = jnp.conj(-amp) * a + (-amp) * a.dag()
            # Each tensor product must be parenthesized: in Python ``^``
            # binds more loosely than ``+``, so ``gg ^ H_g + ee ^ H_e``
            # would add ee to H_g before taking any tensor product.
            return lambda t: (gg ^ H_g) + (ee ^ H_e)

        gen_Ht = _gen_Ht

    return Gate.create(
        [2, N],
        name="CD",
        params={"beta": beta},
        gen_U=lambda params: (gg ^ displace(N, params["beta"] / 2))
        + (ee ^ displace(N, -params["beta"] / 2)),
        gen_Ht=gen_Ht,
        ts=ts,
        num_modes=2,
    )

CR(N, theta)

Conditional rotation gate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `theta` | | Conditional rotation angle. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Conditional rotation gate. |

Source code in jaxquantum/circuits/library/oscillator.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
def CR(N, theta):
    """Conditional rotation gate.

    Acts on (qubit, oscillator) with dims ``[2, N]``. With
    ``n = create(N) @ destroy(N)``, the oscillator gets
    ``exp(-i * theta/2 * n)`` when the qubit is in ``|g>`` and
    ``exp(+i * theta/2 * n)`` when it is in ``|e>``.

    Args:
        N: Hilbert space dimension.
        theta: Conditional rotation angle.

    Returns:
        Conditional rotation gate.
    """
    g = basis(2, 0)
    e = basis(2, 1)

    gg = g @ g.dag()
    ee = e @ e.dag()

    # Number operator a^dag a, built once and shared by both branches.
    n_op = create(N) @ destroy(N)

    def gen_U(params):
        # Read the angle from ``params`` (like D and CD do) so the unitary
        # always matches the parameters recorded on the gate.
        half_angle = params["theta"] / 2
        return (gg ^ (-1.0j * half_angle * n_op).expm()) + (
            ee ^ (1.0j * half_angle * n_op).expm()
        )

    return Gate.create(
        [2, N],
        name="CR",
        params={"theta": theta},
        gen_U=gen_U,
        num_modes=2,
    )

D(N, alpha, ts=None, c_ops=None)

Displacement gate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `alpha` | | Displacement amplitude. | *required* |
| `ts` | | Optional time array for Hamiltonian simulation. | `None` |
| `c_ops` | | Optional collapse operators. | `None` |

Returns:

| Type | Description |
| --- | --- |
| | Displacement gate. |

Source code in jaxquantum/circuits/library/oscillator.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def D(N, alpha, ts=None, c_ops=None):
    """Displacement gate.

    The ideal unitary is ``displace(N, alpha)``. When ``ts`` is supplied, a
    constant drive Hamiltonian over ``ts[0]`` .. ``ts[-1]`` is attached as
    well, for time-domain simulation.

    Args:
        N: Hilbert space dimension.
        alpha: Displacement amplitude.
        ts: Optional time array for hamiltonian simulation.
        c_ops: Optional collapse operators. When omitted, an empty operator
            list is used.

    Returns:
        Displacement gate.
    """
    if ts is None:
        gen_Ht = None
    else:
        duration = ts[-1] - ts[0]
        drive = 1j * alpha / duration
        a = destroy(N)

        def gen_Ht(params):
            # Time-independent drive: the same operator is produced for every t.
            return lambda t: jnp.conj(drive) * a + drive * a.dag()

    def gen_U(params):
        return displace(N, params["alpha"])

    def gen_c_ops(params):
        if c_ops is None:
            return Qarray.from_list([])
        return c_ops

    return Gate.create(
        N,
        name="D",
        params={"alpha": alpha},
        gen_U=gen_U,
        gen_Ht=gen_Ht,
        ts=ts,
        gen_c_ops=gen_c_ops,
        num_modes=1,
    )

Dephasing_Ch(N, err_prob, max_l)

Dephasing channel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `err_prob` | | Error probability. | *required* |
| `max_l` | | Number of Kraus operators. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Dephasing channel. |

Source code in jaxquantum/circuits/library/oscillator.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def Dephasing_Ch(N, err_prob, max_l):
    """Dephasing channel.

    The phase average is discretized with ``max_l``-point Gauss-Hermite
    quadrature. The nodes ``xs`` are rescaled to phases
    ``sqrt(2 * err_prob) * xs`` and the weights are scaled by
    ``1 / sqrt(pi)``. This presumably corresponds to averaging over a
    Gaussian phase with variance ``err_prob`` (confirm against
    ``_Dephasing_Kraus_Op``). One Kraus operator is built per quadrature
    point.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        max_l: Number of Kraus operators (quadrature points).

    Returns:
        Dephasing channel.
    """

    xs, ws = hermgauss(max_l)
    phis = jnp.sqrt(2 * err_prob) * xs
    ws = 1 / jnp.sqrt(jnp.pi) * ws

    kmap = lambda params: Qarray.from_list(
        [_Dephasing_Kraus_Op(N, w, phi) for (w, phi) in zip(ws, phis)]
    )
    return Gate.create(
        N,
        # Previously registered as "Amp_Gain" (copy-paste error); use this
        # channel's own name so it is identifiable in circuits.
        name="Dephasing_Ch",
        params={"err_prob": err_prob, "max_l": max_l},
        gen_KM=kmap,
        num_modes=1,
    )

Dephasing_Reset(N, p, t_rst, chi, max_l)

Dephasing due to imperfect reset between a qubit and a resonator.

Parameters:

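| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `p` | | Reset error probability. | *required* |
| `t_rst` | | Reset time. | *required* |
| `chi` | | Dephasing strength. | *required* |
| `max_l` | | Maximum number of operators. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Dephasing due to reset channel. |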

Source code in jaxquantum/circuits/library/oscillator.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def Dephasing_Reset(N, p, t_rst, chi, max_l):
    """Dephasing due to imperfect reset between a qubit and a resonator.

    Acts on (qubit, oscillator) with dims ``[2, N]``. The Kraus map has
    exactly ``max_l`` operators, indexed ``0 .. max_l - 1``. Each one also
    receives ``max_l`` itself.

    Args:
        N: Hilbert space dimension.
        p: Reset error probability.
        t_rst: Reset time.
        chi: Dephasing strength.
        max_l: Maximum number of operators.

    Returns:
        Dephasing due to reset channel.
    """

    def kraus_map(params):
        operators = []
        for idx in range(max_l):
            operators.append(_Reset_Deph_Kraus_Op(N, p, t_rst, chi, idx, max_l))
        return Qarray.from_list(operators)

    gate_params = {"p": p, "t_rst": t_rst, "chi": chi, "max_l": max_l}
    return Gate.create(
        [2, N],
        name="Dephasing_Reset",
        params=gate_params,
        gen_KM=kraus_map,
        num_modes=2,
    )

Thermal_Ch(N, err_prob, n_bar, max_l)

Thermal channel.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `err_prob` | | Error probability. | *required* |
| `n_bar` | | Average photon number. | *required* |
| `max_l` | | Maximum number of photons gained/lost. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Thermal channel. |

Source code in jaxquantum/circuits/library/oscillator.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def Thermal_Ch(N, err_prob, n_bar, max_l):
    """Thermal channel.

    The Kraus map holds one operator per index pair ``(l, k)`` with both
    indices in ``0 .. max_l``, i.e. ``(max_l + 1) ** 2`` operators in total.
    The operators are ordered with ``l`` as the outer index.

    Args:
        N: Hilbert space dimension.
        err_prob: Error probability.
        n_bar: Average photon number.
        max_l: Maximum number of photons gained/lost.

    Returns:
        Thermal channel.
    """
    n_terms = max_l + 1

    def kraus_map(params):
        operators = []
        for l_idx in range(n_terms):
            for k_idx in range(n_terms):
                operators.append(_Thermal_Kraus_Op(N, err_prob, n_bar, l_idx, k_idx))
        return Qarray.from_list(operators)

    gate_params = {"err_prob": err_prob, "n_bar": n_bar, "max_l": max_l}
    return Gate.create(
        N,
        name="Thermal_Ch",
        params=gate_params,
        gen_KM=kraus_map,
        num_modes=1,
    )

selfKerr(N, K)

Self-Kerr interaction gate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `N` | | Hilbert space dimension. | *required* |
| `K` | | Kerr coefficient. | *required* |

Returns:

| Type | Description |
| --- | --- |
| | Self-Kerr gate. |

Source code in jaxquantum/circuits/library/oscillator.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
def selfKerr(N, K):
    """Self-Kerr interaction gate.

    Single-mode unitary ``exp(-i * K/2 * a^dag a^dag a a)``, where
    ``a = destroy(N)``. The coefficient is stored under the ``"Kerr"``
    parameter key.

    Args:
        N: Hilbert space dimension.
        K: Kerr coefficient.

    Returns:
        Self-Kerr gate.
    """
    a = destroy(N)
    # Kerr operator a^dag a^dag a a, built once and reused by gen_U.
    kerr_op = a.dag() @ a.dag() @ a @ a

    def gen_U(params):
        # Read the coefficient from ``params`` (like D and CD do) so the
        # unitary always matches the parameters recorded on the gate.
        return (-1.0j * params["Kerr"] / 2 * kerr_op).expm()

    return Gate.create(
        N,
        name="selfKerr",
        params={"Kerr": K},
        gen_U=gen_U,
        num_modes=1,
    )