Coverage for jaxquantum/core/visualization.py: 22%
40 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""
2Visualization utils.
3"""
5import qutip as qt
6import numpy as np
7import matplotlib.pyplot as plt
9from jaxquantum.core.conversions import jqt2qt
WIGNER = "wigner"  # qp_type value selecting the Wigner function (qt.wigner)
QFUNC = "qfunc"  # qp_type value selecting the Q function (qt.qfunc)


def _qp_style(qp_type):
    """Return ``(vmin, vmax, scale, cmap)`` used to draw ``qp_type``.

    Raises:
        ValueError: if ``qp_type`` is neither ``WIGNER`` nor ``QFUNC``.
    """
    if qp_type == WIGNER:
        return -1, 1, np.pi / 2, "seismic"
    if qp_type == QFUNC:
        return 0, 1, np.pi, "jet"
    raise ValueError(
        f"Unsupported qp_type {qp_type!r}; expected {WIGNER!r} or {QFUNC!r}."
    )


def plot_qp(
    state,
    pts,
    ax=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
):
    """Plot a quasi-probability distribution of ``state``.

    TODO: decouple this from qutip.

    Args:
        state: state to plot; it is converted with ``jqt2qt`` before being
            handed to qutip.
        pts: points to evaluate the distribution on; the same points are
            used for both the horizontal and the vertical axis.
        ax: matplotlib axis to plot on. When ``None`` a new figure
            (4x3 inches, 200 dpi) with a single axis is created.
        contour: if true, draw with ``contourf`` using 101 levels between
            the color limits; otherwise draw with ``pcolormesh``.
        qp_type: type of quasi-probability distribution
            (``"wigner"`` or ``"qfunc"``).
        cbar_label: title placed on the colorbar axis.
        axis_scale_factor: factor applied to ``pts`` for the plotted axis
            coordinates only; the distribution itself is evaluated on the
            unscaled points.
        plot_cbar: whether to add a colorbar next to ``ax``.

    Returns:
        Tuple ``(ax, im)``: the axis that was drawn on and the object
        returned by ``contourf`` / ``pcolormesh``.

    Raises:
        ValueError: if ``qp_type`` is not one of the supported types.
    """
    # Validate first so that a bad qp_type fails with a clear message before
    # the state is converted or a figure is created.
    vmin, vmax, scale, cmap = _qp_style(qp_type)

    pts = np.array(pts)
    state = jqt2qt(state)
    if ax is None:
        _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
    fig = ax.get_figure()

    # qp_type doubles as the name of the qutip function to call.
    quasi_prob = scale * getattr(qt, qp_type)(state, pts, pts, g=2)

    # Rescale only the axis coordinates, after the distribution is evaluated.
    pts = pts * axis_scale_factor

    if contour:
        im = ax.contourf(
            pts,
            pts,
            quasi_prob,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            levels=np.linspace(vmin, vmax, 101),
        )
    else:
        im = ax.pcolormesh(
            pts,
            pts,
            quasi_prob,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
    ax.axhline(0, linestyle="-", color="black", alpha=0.7)
    ax.axvline(0, linestyle="-", color="black", alpha=0.7)
    ax.grid()
    ax.set_aspect("equal", adjustable="box")

    if plot_cbar:
        # Attach the colorbar through the figure that owns ``ax`` rather than
        # through pyplot's "current figure", which may be a different one
        # when the caller supplied ``ax``.
        cbar = fig.colorbar(
            im, ax=ax, orientation="vertical", ticks=np.linspace(-1, 1, 11)
        )
        cbar.ax.set_title(cbar_label)

    ax.set_xlabel(r"Re[$\alpha$]")
    ax.set_ylabel(r"Im[$\alpha$]")

    fig.tight_layout()
    return ax, im
def plot_wigner(state, pts, ax=None, contour=True, **kwargs):
    """Plot the Wigner function of ``state``.

    Thin wrapper around ``plot_qp`` that fixes ``qp_type=WIGNER`` and the
    colorbar label to ``$\\mathcal{W}(\\alpha)$``.

    Args:
        state: state to plot, passed through to ``plot_qp``.
        pts: points to evaluate the distribution on, passed through to
            ``plot_qp``.
        ax: matplotlib axis to plot on, or ``None``.
        contour: whether ``plot_qp`` should use contouring.
        **kwargs: extra keyword arguments forwarded to ``plot_qp``.
            ``qp_type`` and ``cbar_label`` are already set here, so passing
            either one raises ``TypeError``.

    Returns:
        The ``(ax, im)`` tuple returned by ``plot_qp``.
    """
    return plot_qp(
        state,
        pts,
        ax=ax,
        contour=contour,
        qp_type=WIGNER,
        cbar_label=r"$\mathcal{W}(\alpha)$",
        **kwargs,
    )
def plot_qfunc(state, pts, ax=None, contour=True, **kwargs):
    """Plot the Q function of ``state``.

    Thin wrapper around ``plot_qp`` that fixes ``qp_type=QFUNC`` and the
    colorbar label to ``$\\mathcal{Q}(\\alpha)$``.

    Args:
        state: state to plot, passed through to ``plot_qp``.
        pts: points to evaluate the distribution on, passed through to
            ``plot_qp``.
        ax: matplotlib axis to plot on, or ``None``.
        contour: whether ``plot_qp`` should use contouring.
        **kwargs: extra keyword arguments forwarded to ``plot_qp``.
            ``qp_type`` and ``cbar_label`` are already set here, so passing
            either one raises ``TypeError``.

    Returns:
        The ``(ax, im)`` tuple returned by ``plot_qp``.
    """
    return plot_qp(
        state,
        pts,
        ax=ax,
        contour=contour,
        qp_type=QFUNC,
        cbar_label=r"$\mathcal{Q}(\alpha)$",
        **kwargs,
    )