Coverage for jaxquantum/core/visualization.py: 91%
158 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
"""
Visualization utilities for quasi-probability distributions.

Plots Wigner and Husimi Q functions (``plot_qp``, ``plot_wigner``,
``plot_qfunc``) and the Wigner characteristic function (``plot_cf``,
``plot_cf_wigner``) of (optionally batched) states on matplotlib axes.
"""
5import matplotlib.pyplot as plt
7from jaxquantum.core.qp_distributions import wigner, qfunc
8from jaxquantum.core.cfunctions import cf_wigner
9import jax.numpy as jnp
10import numpy as np
# Identifiers accepted by the ``qp_type`` argument of the plotting functions
# to select which distribution is drawn.
WIGNER, HUSIMI = "wigner", "husimi"
def plot_qp(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot quasi-probability distribution.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting. A state whose
            only batch dimension has size 1 is plotted as a single state.
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at;
            defaults to ``pts_x``
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to plot on (one per batch entry); if None, a new
            figure with a ``(rows, cols)`` grid of axes is created. Axes that
            are not already a 2-d array are reshaped to that grid.
        contour: if True draw with ``contourf``, otherwise ``pcolormesh``
        qp_type: type of quasi probability distribution: ``"wigner"`` or
            ``"husimi"`` (``"qfunc"`` is accepted as an alias for
            ``"husimi"``)
        cbar_label: label for the cbar; if empty, a default label matching
            ``qp_type`` is used
        axis_scale_factor: factor applied to ``pts_x``/``pts_y`` for the
            plotted axis coordinates and the default ticks
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbar (default: 11 evenly spaced
            over the color range)
        subtitles: numpy array of subplot titles with the same shape as the
            state's batch dimensions (after the size-1 squeeze above)
        figtitle: figure title

    Returns:
        tuple ``(axs, im)``: the 2-d array of axes that were plotted on, and
        the artist returned by the last ``contourf``/``pcolormesh`` call.

    Raises:
        ValueError: if ``qp_type`` is unknown, if ``subtitles`` does not match
            the batch shape, or if ``axs`` cannot be arranged into the grid.
    """
    # Resolve the distribution up front so an invalid qp_type fails before
    # any figure is created.
    if qp_type == WIGNER:
        qp_func, scale = wigner, np.pi / 2
        vmin, vmax = -1, 1
        cmap, default_label = "seismic", r"$\mathcal{W}(\alpha)$"
    elif qp_type in (HUSIMI, "qfunc"):
        qp_func, scale = qfunc, np.pi
        vmin, vmax = 0, 1
        cmap, default_label = "jet", r"$\mathcal{Q}(\alpha)$"
    else:
        raise ValueError(
            f"qp_type must be {WIGNER!r} or {HUSIMI!r}, got {qp_type!r}"
        )
    if not cbar_label:
        cbar_label = default_label

    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    # A batch consisting of a single state is plotted as an unbatched state.
    if len(state.bdims) == 1 and state.bdims[0] == 1:
        state = state[0]

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Promote 0-d and 1-d batches to a (rows, cols) grid; remember how many
    # leading axes were added so the results can be padded to match.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Fold every batch dimension beyond the second into the row count.
    if len(bdims) > 2:
        n_rows = bdims[0] * int(np.prod(bdims[2:]))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = state.bdims

    QP = scale * qp_func(state, pts_x, pts_y, g=g)

    for _ in range(added_baxes):
        QP = jnp.array([QP])
        if subtitles is not None:
            subtitles = np.array([subtitles])

    if axs is None:
        # squeeze=False keeps a 2-d axes array even for 1xN / Nx1 grids.
        _, axs = plt.subplots(
            bdims[0],
            bdims[1],
            figsize=(4 * bdims[1], 3 * bdims[0]),
            dpi=200,
            squeeze=False,
        )
    else:
        axs = np.asarray(axs, dtype=object)
        if axs.ndim != 2:
            axs = axs.reshape(bdims[0], bdims[1])

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 11)
    levels = np.linspace(vmin, vmax, 101)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            ax = axs[row, col]
            if contour:
                im = ax.contourf(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    levels=levels,
                )
            else:
                im = ax.pcolormesh(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                )
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.axhline(0, linestyle="-", color="black", alpha=0.7)
            ax.axvline(0, linestyle="-", color="black", alpha=0.7)
            ax.grid()
            ax.set_aspect("equal", adjustable="box")

            if plot_cbar:
                # Attach to the axis' own figure, which need not be the
                # current pyplot figure when the caller supplied ``axs``.
                cbar = ax.get_figure().colorbar(
                    im, ax=ax, orientation="vertical"
                )
                cbar.ax.set_title(cbar_label)
                cbar.set_ticks(z_ticks)

            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")
            if subtitles is not None:
                ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im
def plot_wigner(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot the Wigner function of the state.

    Thin wrapper around :func:`plot_qp` with ``qp_type=WIGNER``.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to plot on
        contour: make the plot use contouring
        cbar_label: label for the cbar (forwarded to :func:`plot_qp`)
        axis_scale_factor: factor applied to the plotted axis coordinates
        plot_cbar: whether to plot cbar
        x_ticks: tick position for the x-axis
        y_ticks: tick position for the y-axis
        z_ticks: tick position for the colorbar
        subtitles: subtitles for the subplots
        figtitle: figure title

    Returns:
        The tuple ``(axs, im)`` returned by :func:`plot_qp`: the axes that
        were plotted on and the artist from the last plotting call.
    """
    return plot_qp(
        state=state,
        pts_x=pts_x,
        pts_y=pts_y,
        g=g,
        axs=axs,
        contour=contour,
        qp_type=WIGNER,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
def plot_qfunc(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot the Husimi Q function of the state.

    Thin wrapper around :func:`plot_qp` with ``qp_type=HUSIMI``.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to plot on
        contour: make the plot use contouring
        cbar_label: label for the cbar (forwarded to :func:`plot_qp`)
        axis_scale_factor: factor applied to the plotted axis coordinates
        plot_cbar: whether to plot cbar
        x_ticks: tick position for the x-axis
        y_ticks: tick position for the y-axis
        z_ticks: tick position for the colorbar
        subtitles: subtitles for the subplots
        figtitle: figure title

    Returns:
        The tuple ``(axs, im)`` returned by :func:`plot_qp`: the axes that
        were plotted on and the artist from the last plotting call.
    """
    return plot_qp(
        state=state,
        pts_x=pts_x,
        pts_y=pts_y,
        g=g,
        axs=axs,
        contour=contour,
        qp_type=HUSIMI,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
def plot_cf(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot characteristic function.

    Each batch entry is drawn as two adjacent panels: the real part of the
    characteristic function (left) and its imaginary part (right).

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting
        pts_x: x points to evaluate the characteristic function at
        pts_y: y points to evaluate the characteristic function at;
            defaults to ``pts_x``
        axs: matplotlib axes to plot on, two per batch entry; if None, a new
            figure with a ``(rows, 2 * cols)`` grid of axes is created. Axes
            that are not already a 2-d array are reshaped to that grid.
        contour: if True draw with ``contourf``, otherwise ``pcolormesh``
        qp_type: type of characteristic function; only ``"wigner"`` is
            supported
        cbar_label: labels for the real and imaginary cbar, given as a pair;
            a single non-empty string is used for both; if empty or None,
            default labels are used
        axis_scale_factor: factor applied to ``pts_x``/``pts_y`` for the
            plotted axis coordinates and the default ticks
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbars (default: 11 evenly spaced
            over [-1, 1])
        subtitles: numpy array of subplot titles with the same shape as the
            state's batch dimensions; each title is shown on both panels
        figtitle: figure title

    Returns:
        tuple ``(axs, im)``: the 2-d array of axes that were plotted on, and
        the artist returned by the last ``contourf``/``pcolormesh`` call.

    Raises:
        ValueError: if ``qp_type`` is unsupported, if ``cbar_label`` is not a
            pair, if ``subtitles`` does not match the batch shape, or if
            ``axs`` cannot be arranged into the grid.
    """
    # Validate before any figure is created.
    if qp_type != WIGNER:
        raise ValueError(f"qp_type must be {WIGNER!r}, got {qp_type!r}")

    default_labels = (
        r"$\mathcal{Re}(\chi_W(\alpha))$",
        r"$\mathcal{Im}(\chi_W(\alpha))$",
    )
    if isinstance(cbar_label, str):
        cbar_labels = (cbar_label, cbar_label) if cbar_label else default_labels
    elif cbar_label is None:
        cbar_labels = default_labels
    else:
        cbar_labels = tuple(cbar_label)
        if len(cbar_labels) != 2:
            raise ValueError(
                f"cbar_label must hold 2 labels (real, imag), "
                f"got {len(cbar_labels)}"
            )

    vmin, vmax = -1, 1
    cmap = "seismic"

    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Promote 0-d and 1-d batches to a (rows, cols) grid; remember how many
    # leading axes were added so the results can be padded to match.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Fold every batch dimension beyond the second into the row count.
    if len(bdims) > 2:
        n_rows = bdims[0] * int(np.prod(bdims[2:]))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = state.bdims

    QP = cf_wigner(state, pts_x, pts_y)

    for _ in range(added_baxes):
        QP = jnp.array([QP])
        if subtitles is not None:
            subtitles = np.array([subtitles])

    n_cols = 2 * bdims[1]  # real and imaginary panel per batch entry
    if axs is None:
        # squeeze=False keeps a 2-d axes array even for single-row grids.
        _, axs = plt.subplots(
            bdims[0],
            n_cols,
            figsize=(4 * n_cols, 3 * bdims[0]),
            dpi=200,
            squeeze=False,
        )
    else:
        axs = np.asarray(axs, dtype=object)
        if axs.ndim != 2:
            axs = axs.reshape(bdims[0], n_cols)

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 11)
    levels = np.linspace(vmin, vmax, 101)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            cf = QP[row, col]
            for subcol, part in enumerate((jnp.real(cf), jnp.imag(cf))):
                ax = axs[row, 2 * col + subcol]
                if contour:
                    im = ax.contourf(
                        pts_x,
                        pts_y,
                        part,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                        levels=levels,
                    )
                else:
                    im = ax.pcolormesh(
                        pts_x,
                        pts_y,
                        part,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)
                ax.axhline(0, linestyle="-", color="black", alpha=0.7)
                ax.axvline(0, linestyle="-", color="black", alpha=0.7)
                ax.grid()
                ax.set_aspect("equal", adjustable="box")

                if plot_cbar:
                    # Attach to the axis' own figure, which need not be the
                    # current pyplot figure when the caller supplied ``axs``.
                    cbar = ax.get_figure().colorbar(
                        im, ax=ax, orientation="vertical"
                    )
                    cbar.ax.set_title(cbar_labels[subcol])
                    cbar.set_ticks(z_ticks)

                ax.set_xlabel(r"Re[$\alpha$]")
                ax.set_ylabel(r"Im[$\alpha$]")
                if subtitles is not None:
                    ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im
def plot_cf_wigner(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner characteristic function of the state.

    Convenience wrapper that forwards every argument unchanged to
    :func:`plot_cf`, fixing ``qp_type`` to ``WIGNER``.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            arranged on a 2d grid for plotting
        pts_x: x points at which the characteristic function is evaluated
        pts_y: y points at which the characteristic function is evaluated
        axs: matplotlib axes to draw on
        contour: whether to draw filled contours
        cbar_label: colorbar label(s), forwarded to :func:`plot_cf`
        axis_scale_factor: factor applied to the plotted axis coordinates
        plot_cbar: whether to draw colorbars
        x_ticks: x-axis tick positions
        y_ticks: y-axis tick positions
        z_ticks: colorbar tick positions
        subtitles: per-subplot titles
        figtitle: title of the whole figure

    Returns:
        The tuple ``(axs, im)`` produced by :func:`plot_cf`: the axes that
        were drawn on and the artist from the last plotting call.
    """
    display_options = {
        "contour": contour,
        "cbar_label": cbar_label,
        "axis_scale_factor": axis_scale_factor,
        "plot_cbar": plot_cbar,
        "x_ticks": x_ticks,
        "y_ticks": y_ticks,
        "z_ticks": z_ticks,
        "subtitles": subtitles,
        "figtitle": figtitle,
    }
    return plot_cf(
        state,
        pts_x,
        pts_y=pts_y,
        axs=axs,
        qp_type=WIGNER,
        **display_options,
    )