Coverage for jaxquantum / core / visualization.py: 91%
157 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-12-03 20:38 +0000
1"""
2Visualization utils.
3"""
5import matplotlib.pyplot as plt
7from jaxquantum.core.qp_distributions import wigner, qfunc
8from jaxquantum.core.cfunctions import cf_wigner
9import jax.numpy as jnp
10import numpy as np
# String identifiers compared against the ``qp_type`` argument of the
# plotting helpers below.
WIGNER, HUSIMI = "wigner", "husimi"
def plot_qp(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot quasi-probability distribution.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting. A batch of
            shape ``(1,)`` is plotted like an unbatched state.
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at;
            defaults to ``pts_x``
        g: float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to plot on. Any array-like of axes (or a single
            axes) whose total size matches the flattened batch grid is
            accepted. A new figure is created when ``None``.
        contour: if True draw with ``contourf``, otherwise with
            ``pcolormesh``
        qp_type: type of quasi probability distribution, ``"wigner"`` or
            ``"husimi"`` (``"qfunc"`` is accepted as an alias of
            ``"husimi"``)
        cbar_label: label for the cbar; when empty, a default label matching
            ``qp_type`` is used
        axis_scale_factor: factor multiplied into the plotted x/y
            coordinates (and hence the axis tick values)
        plot_cbar: whether to plot cbar
        x_ticks: tick position for the x-axis (default: 5 evenly spaced)
        y_ticks: tick position for the y-axis (default: 5 evenly spaced)
        z_ticks: tick position for the colorbar (default: 3 evenly spaced)
        subtitles: array of subtitles with the same shape as the state's
            batch dimensions
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: a 2-D numpy object array of the axes that were
        drawn on, and the mappable returned for the last subplot.

    Raises:
        ValueError: if ``subtitles`` does not match the batch shape, or
            ``qp_type`` is not a recognised distribution.
    """
    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    # A single-element batch is plotted like an unbatched state.
    if len(state.bdims) == 1 and state.bdims[0] == 1:
        state = state[0]

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Promote 0-d / 1-d batches to a (rows, cols) grid and remember how many
    # leading axes were added so QP and subtitles can be padded to match.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Fold every batch dimension beyond the second into the row axis.
    extra_dims = bdims[2:]
    if extra_dims != ():
        folded_rows = bdims[0] * int(jnp.prod(jnp.array(extra_dims)))
        state = state.reshape_bdims(folded_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(folded_rows, bdims[1])
        bdims = state.bdims

    # Pick the distribution before creating any figure, so an invalid
    # qp_type does not leave an empty figure behind.
    if qp_type == WIGNER:
        vmin, vmax = -1, 1
        scale = np.pi / 2
        cmap = "seismic"
        default_cbar_label = r"$\mathcal{W}(\alpha)$"
        QP = scale * wigner(state, pts_x, pts_y, g=g)
    elif qp_type in (HUSIMI, "qfunc"):
        vmin, vmax = 0, 1
        scale = np.pi
        cmap = "jet"
        default_cbar_label = r"$\mathcal{Q}(\alpha)$"
        QP = scale * qfunc(state, pts_x, pts_y, g=g)
    else:
        raise ValueError(
            f"qp_type must be {WIGNER!r} or {HUSIMI!r} (alias 'qfunc'), "
            f"got {qp_type!r}"
        )
    if not cbar_label:
        cbar_label = default_cbar_label

    for _ in range(added_baxes):
        QP = jnp.array([QP])
        if subtitles is not None:
            subtitles = np.array([subtitles])

    n_rows, n_cols = bdims[0], bdims[1]
    if axs is None:
        # squeeze=False keeps a 2-D axes array even when one grid dimension
        # is 1; matplotlib would otherwise return a 1-D array or one Axes.
        _, axs = plt.subplots(
            n_rows,
            n_cols,
            figsize=(4 * n_cols, 3 * n_rows),
            dpi=200,
            squeeze=False,
        )
    # Normalise user-supplied or freshly created axes to a (rows, cols) grid.
    axs = np.asarray(axs, dtype=object).reshape(n_rows, n_cols)

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 3)

    for row in range(n_rows):
        for col in range(n_cols):
            ax = axs[row, col]
            if contour:
                im = ax.contourf(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    levels=np.linspace(vmin, vmax, 101),
                )
            else:
                im = ax.pcolormesh(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                )
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.axhline(0, linestyle="-", color="black", alpha=0.7)
            ax.axvline(0, linestyle="-", color="black", alpha=0.7)
            ax.grid()
            ax.set_aspect("equal", adjustable="box")

            if plot_cbar:
                cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                cbar.ax.set_title(cbar_label)
                cbar.set_ticks(z_ticks)

            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")
            if subtitles is not None:
                ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im
def plot_wigner(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot the Wigner function of ``state``.

    Convenience wrapper around ``plot_qp`` with ``qp_type=WIGNER``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2d grid of subplots
        pts_x: x points at which the Wigner function is evaluated
        pts_y: y points at which the Wigner function is evaluated
            (``plot_qp`` falls back to ``pts_x`` when omitted)
        g: float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to draw on
        contour: draw with contouring when True
        cbar_label: colorbar label, forwarded to ``plot_qp``
        axis_scale_factor: factor applied to the plotted x/y coordinates
        plot_cbar: whether to draw a colorbar
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbar
        subtitles: per-subplot titles
        figtitle: figure title

    Returns:
        Whatever ``plot_qp`` returns: the tuple ``(axs, im)``.
    """
    forwarded = dict(
        g=g,
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(state, pts_x, pts_y, qp_type=WIGNER, **forwarded)
def plot_qfunc(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    r"""Plot the Husimi Q function of ``state``.

    Convenience wrapper around ``plot_qp`` with ``qp_type=HUSIMI``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2d grid of subplots
        pts_x: x points at which the Q function is evaluated
        pts_y: y points at which the Q function is evaluated
            (``plot_qp`` falls back to ``pts_x`` when omitted)
        g: float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\hbar` in the commutation relation
            :math:`[x,\,y] = i\hbar` via :math:`\hbar=2/g^2`.
        axs: matplotlib axes to draw on
        contour: draw with contouring when True
        cbar_label: colorbar label, forwarded to ``plot_qp``
        axis_scale_factor: factor applied to the plotted x/y coordinates
        plot_cbar: whether to draw a colorbar
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbar
        subtitles: per-subplot titles
        figtitle: figure title

    Returns:
        Whatever ``plot_qp`` returns: the tuple ``(axs, im)``.
    """
    forwarded = dict(
        g=g,
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(state, pts_x, pts_y, qp_type=HUSIMI, **forwarded)
def plot_cf(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    plot_grid=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot characteristic function.

    Each batch element occupies two adjacent subplot columns: the real part
    on the left and the imaginary part on the right.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting
        pts_x: x points to evaluate the characteristic function at
        pts_y: y points to evaluate the characteristic function at;
            defaults to ``pts_x``
        axs: matplotlib axes to plot on. Any array-like of axes whose total
            size is ``rows * 2 * cols`` of the flattened batch grid is
            accepted. A new figure is created when ``None``.
        contour: if True draw with ``contourf``, otherwise with
            ``pcolormesh``
        qp_type: type of characteristic function (only ``"wigner"``)
        cbar_label: labels for the real and imaginary cbar. Pass a pair
            ``(real_label, imag_label)``, or a single string used for both.
            When empty, default labels are used.
        axis_scale_factor: factor multiplied into the plotted x/y
            coordinates (and hence the axis tick values)
        plot_cbar: whether to plot cbar
        plot_grid: whether to draw a grid on each subplot
        x_ticks: tick position for the x-axis (default: 5 evenly spaced)
        y_ticks: tick position for the y-axis (default: 5 evenly spaced)
        z_ticks: tick position for the colorbar (default: 11 evenly spaced)
        subtitles: array of subtitles with the same shape as the state's
            batch dimensions; each title is shown above both the real and
            the imaginary subplot
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: a 2-D numpy object array of the axes that were
        drawn on (shape ``(rows, 2 * cols)``), and the mappable returned for
        the last subplot.

    Raises:
        ValueError: if ``subtitles`` does not match the batch shape, or
            ``qp_type`` is not supported.
    """
    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Promote 0-d / 1-d batches to a (rows, cols) grid and remember how many
    # leading axes were added so QP and subtitles can be padded to match.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Fold every batch dimension beyond the second into the row axis.
    extra_dims = bdims[2:]
    if extra_dims != ():
        folded_rows = bdims[0] * int(jnp.prod(jnp.array(extra_dims)))
        state = state.reshape_bdims(folded_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(folded_rows, bdims[1])
        bdims = state.bdims

    if qp_type == WIGNER:
        vmin, vmax = -1, 1
        scale = 1
        cmap = "seismic"
        default_cbar_labels = (
            r"$\mathcal{Re}(\chi_W(\alpha))$",
            r"$\mathcal{Im}(\chi_W(\alpha))$",
        )
        QP = scale * cf_wigner(state, pts_x, pts_y)
    else:
        raise ValueError(f"qp_type must be {WIGNER!r}, got {qp_type!r}")

    # Resolve (real, imag) colorbar labels; a plain string is used for both
    # instead of being indexed character by character.
    if cbar_label is None or (isinstance(cbar_label, str) and not cbar_label):
        cbar_labels = default_cbar_labels
    elif isinstance(cbar_label, str):
        cbar_labels = (cbar_label, cbar_label)
    else:
        cbar_labels = tuple(cbar_label)

    for _ in range(added_baxes):
        QP = jnp.array([QP])
        if subtitles is not None:
            subtitles = np.array([subtitles])

    n_rows, n_cols = bdims[0], bdims[1]
    if axs is None:
        # squeeze=False keeps a 2-D axes array even when there is one row;
        # matplotlib would otherwise return a 1-D array.
        _, axs = plt.subplots(
            n_rows,
            2 * n_cols,
            figsize=(4 * 2 * n_cols, 3 * n_rows),
            dpi=200,
            squeeze=False,
        )
    # Normalise user-supplied or freshly created axes to (rows, 2 * cols).
    axs = np.asarray(axs, dtype=object).reshape(n_rows, 2 * n_cols)

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 11)

    for row in range(n_rows):
        for col in range(n_cols):
            # Real part goes in the left sub-column, imaginary in the right.
            parts = (jnp.real(QP[row, col]), jnp.imag(QP[row, col]))
            for subcol, data in enumerate(parts):
                ax = axs[row, 2 * col + subcol]
                if contour:
                    im = ax.contourf(
                        pts_x,
                        pts_y,
                        data,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                        levels=np.linspace(vmin, vmax, 101),
                    )
                else:
                    im = ax.pcolormesh(
                        pts_x,
                        pts_y,
                        data,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)

                if plot_grid:
                    ax.grid()

                ax.set_aspect("equal", adjustable="box")

                if plot_cbar:
                    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                    cbar.ax.set_title(cbar_labels[subcol])
                    cbar.set_ticks(z_ticks)

                ax.set_xlabel(r"Re[$\alpha$]")
                ax.set_ylabel(r"Im[$\alpha$]")
                if subtitles is not None:
                    ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im
def plot_cf_wigner(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    plot_grid=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner characteristic function of ``state``.

    Convenience wrapper around ``plot_cf`` with ``qp_type=WIGNER``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2d grid of subplots
        pts_x: x points at which the characteristic function is evaluated
        pts_y: y points at which the characteristic function is evaluated
            (``plot_cf`` falls back to ``pts_x`` when omitted)
        axs: matplotlib axes to draw on
        contour: draw with contouring when True
        cbar_label: colorbar label(s), forwarded to ``plot_cf``
        axis_scale_factor: factor applied to the plotted x/y coordinates
        plot_cbar: whether to draw colorbars
        plot_grid: whether to draw a grid on each subplot
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbar
        subtitles: per-subplot titles
        figtitle: figure title

    Returns:
        Whatever ``plot_cf`` returns: the tuple ``(axs, im)``.
    """
    forwarded = dict(
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        plot_grid=plot_grid,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_cf(state, pts_x, pts_y, qp_type=WIGNER, **forwarded)