Coverage for jaxquantum/core/visualization.py: 91%

158 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1""" 

2Visualization utils. 

3""" 

4 

5import matplotlib.pyplot as plt 

6 

7from jaxquantum.core.qp_distributions import wigner, qfunc 

8from jaxquantum.core.cfunctions import cf_wigner 

9import jax.numpy as jnp 

10import numpy as np 

11 

# String identifiers for the supported quasi-probability distributions; these
# are the values accepted by the ``qp_type`` argument of the plotting helpers.
WIGNER, HUSIMI = "wigner", "husimi"

14 

15 

def plot_qp(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot quasi-probability distribution.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting. A batch with a
            single 1-D batch dimension of size one is plotted as an
            unbatched state.
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at;
            defaults to ``pts_x``
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to plot on, either a single Axes or an array of
            Axes with one entry per plotted state; a new figure is created
            when ``None``
        contour: if True draw with ``contourf``, otherwise ``pcolormesh``
        qp_type: type of quasi-probability distribution, ``"wigner"`` or
            ``"husimi"`` (``"qfunc"`` is accepted as an alias of ``"husimi"``)
        cbar_label: colorbar title; an empty string selects the default
            label for ``qp_type``
        axis_scale_factor: factor applied to ``pts_x``/``pts_y`` for the
            plotted coordinates and the default ticks; the distribution
            itself is evaluated at the unscaled points
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbar (default: 3 evenly spaced
            between the colour limits)
        subtitles: array-like of subplot titles whose shape matches the
            batch dimensions of ``state`` (after the size-one squeeze
            described for ``state``)
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: a 2-D object array of the Axes drawn on, and the
        artist returned by the last ``contourf``/``pcolormesh`` call.

    Raises:
        ValueError: if ``qp_type`` is not supported, or if ``subtitles`` does
            not match the batch shape of ``state``.
    """
    # Resolve the distribution first so an unsupported qp_type fails before
    # any figure is created.
    if qp_type == WIGNER:
        qp_func, scale, vmin, vmax, cmap = wigner, np.pi / 2, -1, 1, "seismic"
        default_label = r"$\mathcal{W}(\alpha)$"
    elif qp_type in (HUSIMI, "qfunc"):
        qp_func, scale, vmin, vmax, cmap = qfunc, np.pi, 0, 1, "jet"
        default_label = r"$\mathcal{Q}(\alpha)$"
    else:
        raise ValueError(
            f"qp_type must be {WIGNER!r} or {HUSIMI!r}, got {qp_type!r}"
        )
    if not cbar_label:
        cbar_label = default_label

    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    # A batch containing exactly one state is plotted like an unbatched state.
    if len(state.bdims) == 1 and state.bdims[0] == 1:
        state = state[0]

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None:
        subtitles = np.asarray(subtitles)
        if subtitles.shape != bdims:
            raise ValueError(
                f"labels must have same shape as bdims, "
                f"got shapes {subtitles.shape} and {bdims}"
            )

    # Pad the batch shape to (rows, cols) of the subplot grid, counting how
    # many leading axes had to be added.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Collapse any batch dimensions beyond the second into a 2-D grid with
    # bdims[0] * prod(bdims[2:]) rows (plain row-major reshape).
    extra_dims = bdims[2:]
    if extra_dims:
        n_rows = bdims[0] * int(np.prod(extra_dims))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = state.bdims

    if axs is None:
        _, axs = plt.subplots(
            bdims[0],
            bdims[1],
            figsize=(4 * bdims[1], 3 * bdims[0]),
            dpi=200,
        )

    QP = scale * qp_func(state, pts_x, pts_y, g=g)

    # Re-insert the padded batch axes so QP is indexed as QP[row, col].
    for _ in range(added_baxes):
        QP = jnp.array([QP])
    # plt.subplots returns a bare Axes or a 1-D array for degenerate grids
    # (and callers may pass either); normalise to the (rows, cols) grid.
    axs = np.asarray(axs).reshape(bdims[0], bdims[1])
    if subtitles is not None:
        subtitles = subtitles.reshape(bdims[0], bdims[1])

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 3)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            ax = axs[row, col]
            if contour:
                im = ax.contourf(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    levels=np.linspace(vmin, vmax, 101),
                )
            else:
                im = ax.pcolormesh(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                )
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.axhline(0, linestyle="-", color="black", alpha=0.7)
            ax.axvline(0, linestyle="-", color="black", alpha=0.7)
            ax.grid()
            ax.set_aspect("equal", adjustable="box")

            if plot_cbar:
                # Colorbar ticks are set explicitly from z_ticks.
                cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                cbar.ax.set_title(cbar_label)
                cbar.set_ticks(z_ticks)

            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")
            if subtitles is not None:
                ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im

189 

190 

def plot_wigner(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner function of ``state``.

    This convenience wrapper forwards every argument unchanged to
    :func:`plot_qp`, with ``qp_type`` fixed to ``WIGNER``.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2-D grid of subplots
        pts_x: x points at which the Wigner function is evaluated
        pts_y: y points at which the Wigner function is evaluated
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to draw on (``None`` creates a new figure)
        contour: draw filled contours instead of a colour mesh
        cbar_label: colorbar label, forwarded to :func:`plot_qp`
        axis_scale_factor: scale factor applied to the plotted coordinates
        plot_cbar: whether to draw a colorbar for each subplot
        x_ticks: x-axis tick positions
        y_ticks: y-axis tick positions
        z_ticks: colorbar tick positions
        subtitles: per-subplot titles
        figtitle: figure title

    Returns:
        The value returned by :func:`plot_qp`, i.e. ``(axs, im)``.
    """
    presentation = dict(
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(
        state=state,
        pts_x=pts_x,
        pts_y=pts_y,
        g=g,
        axs=axs,
        qp_type=WIGNER,
        **presentation,
    )

250 

251 

def plot_qfunc(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Husimi Q function of ``state``.

    This convenience wrapper forwards every argument unchanged to
    :func:`plot_qp`, with ``qp_type`` fixed to ``HUSIMI``.

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2-D grid of subplots
        pts_x: x points at which the Q function is evaluated
        pts_y: y points at which the Q function is evaluated
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to draw on (``None`` creates a new figure)
        contour: draw filled contours instead of a colour mesh
        cbar_label: colorbar label, forwarded to :func:`plot_qp`
        axis_scale_factor: scale factor applied to the plotted coordinates
        plot_cbar: whether to draw a colorbar for each subplot
        x_ticks: x-axis tick positions
        y_ticks: y-axis tick positions
        z_ticks: colorbar tick positions
        subtitles: per-subplot titles
        figtitle: figure title

    Returns:
        The value returned by :func:`plot_qp`, i.e. ``(axs, im)``.
    """
    presentation = dict(
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(
        state=state,
        pts_x=pts_x,
        pts_y=pts_y,
        g=g,
        axs=axs,
        qp_type=HUSIMI,
        **presentation,
    )

311 

312 

def plot_cf(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot characteristic function as side-by-side real/imaginary panels.

    Each state in the batch gets two adjacent subplots: the real part of the
    characteristic function on the left and the imaginary part on the right.

    Args:
        state: state with arbitrary number of batch dimensions, result will
            be flattened to a 2d grid to allow for plotting
        pts_x: x points to evaluate the characteristic function at
        pts_y: y points to evaluate the characteristic function at;
            defaults to ``pts_x``
        axs: matplotlib axes to plot on, two per plotted state (real, imag);
            a new figure is created when ``None``
        contour: if True draw with ``contourf``, otherwise ``pcolormesh``
        qp_type: type of characteristic function; only ``"wigner"`` is
            supported
        cbar_label: colorbar titles. ``""`` or ``None`` selects the default
            real/imaginary labels, a non-empty string is used for both
            panels, otherwise a pair ``(real_label, imag_label)``
        axis_scale_factor: factor applied to ``pts_x``/``pts_y`` for the
            plotted coordinates and the default ticks; the function itself is
            evaluated at the unscaled points
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbars (default: 11 evenly spaced
            between the colour limits)
        subtitles: array-like of titles with the same shape as
            ``state.bdims``; each title is used for both panels of its state
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: a 2-D object array of Axes with shape
        ``(rows, 2 * cols)``, and the artist of the last panel drawn.

    Raises:
        ValueError: if ``qp_type`` is not supported, ``cbar_label`` is not a
            string or a pair, or ``subtitles`` does not match the batch shape.
    """
    # Resolve the function first so an unsupported qp_type fails before any
    # figure is created.
    if qp_type == WIGNER:
        cf_func, scale, vmin, vmax, cmap = cf_wigner, 1, -1, 1, "seismic"
        default_labels = (
            r"$\mathcal{Re}(\chi_W(\alpha))$",
            r"$\mathcal{Im}(\chi_W(\alpha))$",
        )
    else:
        raise ValueError(f"qp_type must be {WIGNER!r}, got {qp_type!r}")

    if isinstance(cbar_label, str):
        cbar_labels = (cbar_label, cbar_label) if cbar_label else default_labels
    elif cbar_label is None:
        cbar_labels = default_labels
    else:
        cbar_labels = tuple(cbar_label)
        if len(cbar_labels) != 2:
            raise ValueError(
                "cbar_label must be a string or a pair of strings "
                f"(real, imag), got {cbar_label!r}"
            )

    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    bdims = state.bdims
    added_baxes = 0

    if subtitles is not None:
        subtitles = np.asarray(subtitles)
        if subtitles.shape != bdims:
            raise ValueError(
                f"labels must have same shape as bdims, "
                f"got shapes {subtitles.shape} and {bdims}"
            )

    # Pad the batch shape to (rows, cols), counting the added leading axes.
    if len(bdims) == 0:
        bdims = (1,)
        added_baxes += 1
    if len(bdims) == 1:
        bdims = (1, bdims[0])
        added_baxes += 1

    # Collapse any batch dimensions beyond the second into a 2-D grid with
    # bdims[0] * prod(bdims[2:]) rows (plain row-major reshape).
    extra_dims = bdims[2:]
    if extra_dims:
        n_rows = bdims[0] * int(np.prod(extra_dims))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = state.bdims

    if axs is None:
        _, axs = plt.subplots(
            bdims[0],
            2 * bdims[1],
            figsize=(4 * 2 * bdims[1], 3 * bdims[0]),
            dpi=200,
        )

    QP = scale * cf_func(state, pts_x, pts_y)

    # Re-insert the padded batch axes so QP is indexed as QP[row, col].
    for _ in range(added_baxes):
        QP = jnp.array([QP])
    # Two panels per state: normalise whatever axes container we have to a
    # (rows, 2 * cols) grid.
    axs = np.asarray(axs).reshape(bdims[0], 2 * bdims[1])
    if subtitles is not None:
        subtitles = subtitles.reshape(bdims[0], bdims[1])

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 11)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            cf_vals = QP[row, col]
            panels = (jnp.real(cf_vals), jnp.imag(cf_vals))
            for subcol, values in enumerate(panels):
                ax = axs[row, 2 * col + subcol]
                if contour:
                    im = ax.contourf(
                        pts_x,
                        pts_y,
                        values,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                        levels=np.linspace(vmin, vmax, 101),
                    )
                else:
                    im = ax.pcolormesh(
                        pts_x,
                        pts_y,
                        values,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)
                ax.axhline(0, linestyle="-", color="black", alpha=0.7)
                ax.axvline(0, linestyle="-", color="black", alpha=0.7)
                ax.grid()
                ax.set_aspect("equal", adjustable="box")

                if plot_cbar:
                    # Colorbar ticks are set explicitly from z_ticks.
                    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                    cbar.ax.set_title(cbar_labels[subcol])
                    cbar.set_ticks(z_ticks)

                ax.set_xlabel(r"Re[$\alpha$]")
                ax.set_ylabel(r"Im[$\alpha$]")
                if subtitles is not None:
                    ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im

478 

def plot_cf_wigner(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner characteristic function of ``state``.

    This convenience wrapper forwards every argument unchanged to
    :func:`plot_cf`, with ``qp_type`` fixed to ``WIGNER``. Each state is
    drawn as a pair of subplots (real part, imaginary part).

    Args:
        state: state with arbitrary number of batch dimensions; the batch is
            flattened to a 2-D grid of subplot pairs
        pts_x: x points at which the characteristic function is evaluated
        pts_y: y points at which the characteristic function is evaluated
        axs: matplotlib axes to draw on (``None`` creates a new figure)
        contour: draw filled contours instead of a colour mesh
        cbar_label: colorbar label(s), forwarded to :func:`plot_cf`
        axis_scale_factor: scale factor applied to the plotted coordinates
        plot_cbar: whether to draw a colorbar for each subplot
        x_ticks: x-axis tick positions
        y_ticks: y-axis tick positions
        z_ticks: colorbar tick positions
        subtitles: per-state subplot titles
        figtitle: figure title

    Returns:
        The value returned by :func:`plot_cf`, i.e. ``(axs, im)``.
    """
    presentation = dict(
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_cf(
        state=state,
        pts_x=pts_x,
        pts_y=pts_y,
        axs=axs,
        qp_type=WIGNER,
        **presentation,
    )