Coverage for jaxquantum/core/visualization.py: 22%

40 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2Visualization utils. 

3""" 

4 

5import qutip as qt 

6import numpy as np 

7import matplotlib.pyplot as plt 

8 

9from jaxquantum.core.conversions import jqt2qt 

10 

# Names of the supported quasi-probability distributions. Each value is also
# the name of the qutip function that evaluates the distribution.
WIGNER = "wigner"
QFUNC = "qfunc"


def plot_qp(
    state,
    pts,
    ax=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
):
    """Plot a quasi-probability distribution of a state.

    TODO: decouple this from qutip.

    Args:
        state: state to plot; it is converted with ``jqt2qt`` before use.
        pts: points to evaluate the quasi-probability distribution on; the
            same points are used for both axes.
        ax: matplotlib axis to plot on; a new figure is created when None.
        contour: draw with ``contourf`` when True, ``pcolormesh`` otherwise.
        qp_type: type of quasi probability distribution ("wigner", "qfunc").
        cbar_label: title placed on the colorbar axis.
        axis_scale_factor: factor applied to ``pts`` for the plotted axes
            only; the distribution is evaluated on the unscaled points.
        plot_cbar: whether to draw a colorbar next to the plot.

    Returns:
        Tuple ``(ax, im)``: the axis that was plotted on and the object
        returned by ``contourf`` / ``pcolormesh``.

    Raises:
        ValueError: if ``qp_type`` is not one of the supported types.
    """
    # Validate first so an unsupported type fails with a clear message and
    # before any figure is created.
    if qp_type == WIGNER:
        vmin, vmax = -1, 1
        scale = np.pi / 2
        cmap = "seismic"
    elif qp_type == QFUNC:
        vmin, vmax = 0, 1
        scale = np.pi
        cmap = "jet"
    else:
        raise ValueError(
            f"Unsupported qp_type {qp_type!r}; expected {WIGNER!r} or {QFUNC!r}."
        )

    pts = np.array(pts)
    state = jqt2qt(state)
    if ax is None:
        _, ax = plt.subplots(1, figsize=(4, 3), dpi=200)
    fig = ax.get_figure()

    # NOTE(review): the scale factors look chosen so that the g=2 distribution
    # fills the [vmin, vmax] color range -- confirm against qutip's convention.
    QP = scale * getattr(qt, qp_type)(state, pts, pts, g=2)

    # Only the displayed axes are rescaled; QP was evaluated on the raw points.
    pts = pts * axis_scale_factor

    if contour:
        im = ax.contourf(
            pts,
            pts,
            QP,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            levels=np.linspace(vmin, vmax, 101),
        )
    else:
        im = ax.pcolormesh(
            pts,
            pts,
            QP,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
    ax.axhline(0, linestyle="-", color="black", alpha=0.7)
    ax.axvline(0, linestyle="-", color="black", alpha=0.7)
    ax.grid()
    ax.set_aspect("equal", adjustable="box")

    if plot_cbar:
        # Attach the colorbar to the figure that owns ``ax`` rather than to
        # pyplot's current figure, which may be a different one.
        cbar = fig.colorbar(
            im, ax=ax, orientation="vertical", ticks=np.linspace(-1, 1, 11)
        )
        cbar.ax.set_title(cbar_label)

    ax.set_xlabel(r"Re[$\alpha$]")
    ax.set_ylabel(r"Im[$\alpha$]")

    fig.tight_layout()
    return ax, im

97 

98 

def plot_wigner(state, pts, ax=None, contour=True, **kwargs):
    """Plot the Wigner function of a state.

    Thin wrapper around ``plot_qp`` with ``qp_type`` fixed to ``WIGNER`` and
    the colorbar labelled with the Wigner-function symbol.

    Args:
        state: state to plot, forwarded to ``plot_qp``.
        pts: points to evaluate the distribution on, forwarded to ``plot_qp``.
        ax: matplotlib axis to plot on, forwarded to ``plot_qp``.
        contour: whether to use contouring, forwarded to ``plot_qp``.
        **kwargs: extra keyword arguments forwarded to ``plot_qp``. Passing
            ``qp_type`` or ``cbar_label`` here raises ``TypeError`` because
            both are already set by this wrapper.

    Returns:
        Whatever ``plot_qp`` returns.
    """
    return plot_qp(
        state,
        pts,
        ax=ax,
        contour=contour,
        qp_type=WIGNER,
        cbar_label=r"$\mathcal{W}(\alpha)$",
        **kwargs,
    )

108 

def plot_qfunc(state, pts, ax=None, contour=True, **kwargs):
    """Plot the Q function of a state.

    Thin wrapper around ``plot_qp`` with ``qp_type`` fixed to ``QFUNC`` and
    the colorbar labelled with the Q-function symbol.

    Args:
        state: state to plot, forwarded to ``plot_qp``.
        pts: points to evaluate the distribution on, forwarded to ``plot_qp``.
        ax: matplotlib axis to plot on, forwarded to ``plot_qp``.
        contour: whether to use contouring, forwarded to ``plot_qp``.
        **kwargs: extra keyword arguments forwarded to ``plot_qp``. Passing
            ``qp_type`` or ``cbar_label`` here raises ``TypeError`` because
            both are already set by this wrapper.

    Returns:
        Whatever ``plot_qp`` returns.
    """
    return plot_qp(
        state,
        pts,
        ax=ax,
        contour=contour,
        qp_type=QFUNC,
        cbar_label=r"$\mathcal{Q}(\alpha)$",
        **kwargs,
    )