Coverage for jaxquantum / core / visualization.py: 91%

157 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-12-03 20:38 +0000

1""" 

2Visualization utils. 

3""" 

4 

5import matplotlib.pyplot as plt 

6 

7from jaxquantum.core.qp_distributions import wigner, qfunc 

8from jaxquantum.core.cfunctions import cf_wigner 

9import jax.numpy as jnp 

10import numpy as np 

11 

# String identifiers for the supported quasi-probability distributions;
# ``plot_qp`` dispatches on these via its ``qp_type`` argument.
WIGNER, HUSIMI = "wigner", "husimi"

14 

15 

def plot_qp(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot quasi-probability distribution.

    Args:
        state: state with arbitrary number of batch dimensions; the batch
            dimensions are flattened to a 2d grid to allow for plotting.
        pts_x: x points to evaluate quasi-probability distribution at
        pts_y: y points to evaluate quasi-probability distribution at
            (defaults to ``pts_x``)
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to plot on. A 2d array is indexed as
            ``axs[row, col]``; a single Axes or a flat array of axes is
            reshaped to the (rows, cols) grid of the flattened batch dims.
        contour: if True use ``contourf``, otherwise ``pcolormesh``
        qp_type: type of quasi-probability distribution: ``"wigner"`` or
            ``"husimi"`` (``"qfunc"`` is accepted as an alias for the latter)
        cbar_label: title for the colorbar; if empty (or None) a default
            label for the chosen distribution is used
        axis_scale_factor: factor multiplying the x/y points before plotting
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbar (default: 3 evenly spaced
            between the colour limits)
        subtitles: array-like of subplot titles with the same shape as the
            batch dimensions of ``state``
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: the 2d array of axes that was plotted on and the
        artist returned by the last ``contourf``/``pcolormesh`` call.

    Raises:
        ValueError: if ``qp_type`` is unknown, or if ``subtitles`` does not
            have the batch shape of ``state``.
    """
    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    # Resolve the distribution up front so an unknown qp_type fails with a
    # clear error before any figure is created.
    if qp_type == WIGNER:
        vmin, vmax = -1, 1
        scale = np.pi / 2
        cmap = "seismic"
        default_label = r"$\mathcal{W}(\alpha)$"
        qp_fn = wigner
    elif qp_type in (HUSIMI, "qfunc"):
        vmin, vmax = 0, 1
        scale = np.pi
        cmap = "jet"
        default_label = r"$\mathcal{Q}(\alpha)$"
        qp_fn = qfunc
    else:
        raise ValueError(
            f"Unknown qp_type {qp_type!r}; expected {WIGNER!r} or {HUSIMI!r}."
        )
    # Respect a user-supplied label; fall back to the distribution default.
    if not cbar_label:
        cbar_label = default_label

    if subtitles is not None:
        subtitles = np.asarray(subtitles)

    # A batch of exactly one state is plotted as a single unbatched state.
    if len(state.bdims) == 1 and state.bdims[0] == 1:
        state = state[0]
        # Keep subtitles in step with the collapsed batch shape.
        if subtitles is not None and subtitles.shape == (1,):
            subtitles = subtitles.reshape(())

    bdims = tuple(state.bdims)

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Left-pad the batch shape with 1s so it is at least (rows, cols).
    added_baxes = max(0, 2 - len(bdims))
    bdims = (1,) * added_baxes + bdims

    # Fold any trailing batch dimensions into the row dimension.
    extra_dims = bdims[2:]
    if extra_dims:
        n_rows = bdims[0] * int(np.prod(extra_dims))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = tuple(state.bdims)

    if axs is None:
        # squeeze=False always yields a 2d array of axes, even when one of
        # the grid dimensions is 1 (plt.subplots would otherwise drop it).
        _, axs = plt.subplots(
            bdims[0],
            bdims[1],
            figsize=(4 * bdims[1], 3 * bdims[0]),
            dpi=200,
            squeeze=False,
        )
    else:
        # Accept a single Axes or a flat array of axes as well as a 2d grid.
        axs = np.asarray(axs, dtype=object)
        if axs.ndim != 2:
            axs = axs.reshape(bdims[0], bdims[1])

    QP = scale * qp_fn(state, pts_x, pts_y, g=g)
    # Give the data and subtitles the same leading (rows, cols) axes as bdims.
    for _ in range(added_baxes):
        QP = QP[None]
        if subtitles is not None:
            subtitles = subtitles[None]

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 3)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            ax = axs[row, col]
            if contour:
                im = ax.contourf(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    levels=np.linspace(vmin, vmax, 101),
                )
            else:
                im = ax.pcolormesh(
                    pts_x,
                    pts_y,
                    QP[row, col],
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                )
            ax.set_xticks(x_ticks)
            ax.set_yticks(y_ticks)
            ax.axhline(0, linestyle="-", color="black", alpha=0.7)
            ax.axvline(0, linestyle="-", color="black", alpha=0.7)
            ax.grid()
            ax.set_aspect("equal", adjustable="box")

            if plot_cbar:
                cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                cbar.ax.set_title(cbar_label)
                cbar.set_ticks(z_ticks)

            ax.set_xlabel(r"Re[$\alpha$]")
            ax.set_ylabel(r"Im[$\alpha$]")
            if subtitles is not None:
                ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im

189 

190 

def plot_wigner(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner function of ``state``.

    Convenience wrapper around ``plot_qp`` with ``qp_type=WIGNER``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch
            dimensions are flattened to a 2d grid to allow for plotting.
        pts_x: x points to evaluate the Wigner function at
        pts_y: y points to evaluate the Wigner function at
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to plot on
        contour: make the plot use contouring
        cbar_label: colorbar label, forwarded to ``plot_qp``
        axis_scale_factor: factor multiplying the x/y points before plotting
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbar
        subtitles: subtitles for the subplots
        figtitle: figure title

    Returns:
        The ``(axs, im)`` tuple produced by ``plot_qp``.
    """
    forwarded = dict(
        pts_y=pts_y,
        g=g,
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(state=state, pts_x=pts_x, qp_type=WIGNER, **forwarded)

250 

251 

def plot_qfunc(
    state,
    pts_x,
    pts_y=None,
    g=2,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Husimi Q function of ``state``.

    Convenience wrapper around ``plot_qp`` with ``qp_type=HUSIMI``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch
            dimensions are flattened to a 2d grid to allow for plotting.
        pts_x: x points to evaluate the Husimi function at
        pts_y: y points to evaluate the Husimi function at
        g : float, default: 2
            Scaling factor for ``a = 0.5 * g * (x + iy)``. The value of `g` is
            related to the value of :math:`\\hbar` in the commutation relation
            :math:`[x,\\,y] = i\\hbar` via :math:`\\hbar=2/g^2`.
        axs: matplotlib axes to plot on
        contour: make the plot use contouring
        cbar_label: colorbar label, forwarded to ``plot_qp``
        axis_scale_factor: factor multiplying the x/y points before plotting
        plot_cbar: whether to plot cbar
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbar
        subtitles: subtitles for the subplots
        figtitle: figure title

    Returns:
        The ``(axs, im)`` tuple produced by ``plot_qp``.
    """
    forwarded = dict(
        pts_y=pts_y,
        g=g,
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_qp(state=state, pts_x=pts_x, qp_type=HUSIMI, **forwarded)

311 

312 

def plot_cf(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    qp_type=WIGNER,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    plot_grid=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot characteristic function.

    Each batch element gets two side-by-side subplots: the real part
    (left) and the imaginary part (right) of the characteristic function.

    Args:
        state: state with arbitrary number of batch dimensions; the batch
            dimensions are flattened to a 2d grid to allow for plotting.
        pts_x: x points to evaluate the characteristic function at
        pts_y: y points to evaluate the characteristic function at
            (defaults to ``pts_x``)
        axs: matplotlib axes to plot on. A 2d array is indexed as
            ``axs[row, 2 * col + part]``; a single Axes or a flat array of
            axes is reshaped to a (rows, 2 * cols) grid.
        contour: if True use ``contourf``, otherwise ``pcolormesh``
        qp_type: type of characteristic function; only ``"wigner"`` is
            supported
        cbar_label: pair of titles for the real and imaginary colorbars. A
            single non-empty string is used for both; an empty string (or
            None) selects the default labels.
        axis_scale_factor: factor multiplying the x/y points before plotting
        plot_cbar: whether to plot cbar
        plot_grid: whether to draw a grid on each subplot
        x_ticks: tick positions for the x-axis (default: 5 evenly spaced)
        y_ticks: tick positions for the y-axis (default: 5 evenly spaced)
        z_ticks: tick positions for the colorbars (default: 11 evenly
            spaced between -1 and 1)
        subtitles: array-like of subplot titles with the same shape as the
            batch dimensions of ``state``; each title is applied to both the
            real and imaginary subplot
        figtitle: figure title

    Returns:
        Tuple ``(axs, im)``: the 2d array of axes that was plotted on and the
        artist returned by the last ``contourf``/``pcolormesh`` call.

    Raises:
        ValueError: if ``qp_type`` is unsupported, or if ``subtitles`` does
            not have the batch shape of ``state``.
    """
    if pts_y is None:
        pts_y = pts_x
    pts_x = jnp.array(pts_x)
    pts_y = jnp.array(pts_y)

    # Validate up front so an unsupported type fails before any figure is
    # created.
    if qp_type != WIGNER:
        raise ValueError(
            f"Unsupported qp_type {qp_type!r}; only {WIGNER!r} is implemented."
        )
    vmin, vmax = -1, 1
    scale = 1
    cmap = "seismic"
    default_labels = [
        r"$\mathcal{Re}(\chi_W(\alpha))$",
        r"$\mathcal{Im}(\chi_W(\alpha))$",
    ]
    # Respect user-supplied labels; fall back to the defaults.
    if cbar_label is None or (isinstance(cbar_label, str) and not cbar_label):
        cbar_label = default_labels
    elif isinstance(cbar_label, str):
        cbar_label = [cbar_label, cbar_label]

    if subtitles is not None:
        subtitles = np.asarray(subtitles)

    bdims = tuple(state.bdims)

    if subtitles is not None and subtitles.shape != bdims:
        raise ValueError(
            f"labels must have same shape as bdims, "
            f"got shapes {subtitles.shape} and {bdims}"
        )

    # Left-pad the batch shape with 1s so it is at least (rows, cols).
    added_baxes = max(0, 2 - len(bdims))
    bdims = (1,) * added_baxes + bdims

    # Fold any trailing batch dimensions into the row dimension.
    extra_dims = bdims[2:]
    if extra_dims:
        n_rows = bdims[0] * int(np.prod(extra_dims))
        state = state.reshape_bdims(n_rows, bdims[1])
        if subtitles is not None:
            subtitles = subtitles.reshape(n_rows, bdims[1])
        bdims = tuple(state.bdims)

    if axs is None:
        # Two columns (real, imaginary) per batch element. squeeze=False
        # always yields a 2d array of axes, even when there is only one row.
        _, axs = plt.subplots(
            bdims[0],
            bdims[1] * 2,
            figsize=(4 * bdims[1] * 2, 3 * bdims[0]),
            dpi=200,
            squeeze=False,
        )
    else:
        # Accept a single Axes or a flat array of axes as well as a 2d grid.
        axs = np.asarray(axs, dtype=object)
        if axs.ndim != 2:
            axs = axs.reshape(bdims[0], bdims[1] * 2)

    QP = scale * cf_wigner(state, pts_x, pts_y)
    # Give the data and subtitles the same leading (rows, cols) axes as bdims.
    for _ in range(added_baxes):
        QP = QP[None]
        if subtitles is not None:
            subtitles = subtitles[None]

    pts_x = pts_x * axis_scale_factor
    pts_y = pts_y * axis_scale_factor

    if x_ticks is None:
        x_ticks = jnp.linspace(jnp.min(pts_x), jnp.max(pts_x), 5)
    if y_ticks is None:
        y_ticks = jnp.linspace(jnp.min(pts_y), jnp.max(pts_y), 5)
    if z_ticks is None:
        z_ticks = jnp.linspace(vmin, vmax, 11)

    for row in range(bdims[0]):
        for col in range(bdims[1]):
            for subcol, part in enumerate((jnp.real, jnp.imag)):
                ax = axs[row, 2 * col + subcol]
                values = part(QP[row, col])
                if contour:
                    im = ax.contourf(
                        pts_x,
                        pts_y,
                        values,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                        levels=np.linspace(vmin, vmax, 101),
                    )
                else:
                    im = ax.pcolormesh(
                        pts_x,
                        pts_y,
                        values,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=vmax,
                    )
                ax.set_xticks(x_ticks)
                ax.set_yticks(y_ticks)

                if plot_grid:
                    ax.grid()

                ax.set_aspect("equal", adjustable="box")

                if plot_cbar:
                    cbar = plt.colorbar(im, ax=ax, orientation="vertical")
                    cbar.ax.set_title(cbar_label[subcol])
                    cbar.set_ticks(z_ticks)

                ax.set_xlabel(r"Re[$\alpha$]")
                ax.set_ylabel(r"Im[$\alpha$]")
                if subtitles is not None:
                    ax.set_title(subtitles[row, col])

    fig = ax.get_figure()
    fig.tight_layout()
    if figtitle is not None:
        fig.suptitle(figtitle, y=1.04)
    return axs, im

482 

def plot_cf_wigner(
    state,
    pts_x,
    pts_y=None,
    axs=None,
    contour=True,
    cbar_label="",
    axis_scale_factor=1,
    plot_cbar=True,
    plot_grid=True,
    x_ticks=None,
    y_ticks=None,
    z_ticks=None,
    subtitles=None,
    figtitle=None,
):
    """Plot the Wigner characteristic function of ``state``.

    Convenience wrapper around ``plot_cf`` with ``qp_type=WIGNER``; every
    other argument is forwarded unchanged.

    Args:
        state: state with arbitrary number of batch dimensions; the batch
            dimensions are flattened to a 2d grid to allow for plotting.
        pts_x: x points to evaluate the characteristic function at
        pts_y: y points to evaluate the characteristic function at
        axs: matplotlib axes to plot on
        contour: make the plot use contouring
        cbar_label: colorbar label(s), forwarded to ``plot_cf``
        axis_scale_factor: factor multiplying the x/y points before plotting
        plot_cbar: whether to plot cbar
        plot_grid: whether to draw a grid on each subplot
        x_ticks: tick positions for the x-axis
        y_ticks: tick positions for the y-axis
        z_ticks: tick positions for the colorbars
        subtitles: subtitles for the subplots
        figtitle: figure title

    Returns:
        The ``(axs, im)`` tuple produced by ``plot_cf``.
    """
    forwarded = dict(
        pts_y=pts_y,
        axs=axs,
        contour=contour,
        cbar_label=cbar_label,
        axis_scale_factor=axis_scale_factor,
        plot_cbar=plot_cbar,
        plot_grid=plot_grid,
        x_ticks=x_ticks,
        y_ticks=y_ticks,
        z_ticks=z_ticks,
        subtitles=subtitles,
        figtitle=figtitle,
    )
    return plot_cf(state=state, pts_x=pts_x, qp_type=WIGNER, **forwarded)