Coverage for jaxquantum/core/qarray.py: 71%

437 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""QArray.""" 

2 

3from __future__ import annotations 

4 

5from flax import struct 

6from jax import Array, config 

7from typing import List, Union 

8 

9 

10from math import prod 

11from copy import deepcopy 

12from numpy import ndarray 

13import jax.numpy as jnp 

14import jax.scipy as jsp 

15 

16from jaxquantum.core.settings import SETTINGS 

17from jaxquantum.utils.utils import robust_isscalar 

18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims 

19 

# Enable 64-bit JAX types (float64 / complex128). This is a process-wide side
# effect of importing this module: it changes JAX's default precision for all
# code in the process, not just for Qarray.
config.update("jax_enable_x64", True)

21 

22 

def tidy_up(data, atol):
    """Zero out real and imaginary components whose magnitude is at most ``atol``.

    The real and imaginary parts are thresholded independently, so e.g.
    ``1 + 1e-20j`` becomes ``1 + 0j`` for a small ``atol``.

    Args:
        data: array-like input (real or complex).
        atol: absolute tolerance; components with ``|x| <= atol`` are zeroed.

    Returns:
        Complex array with the same shape as ``data``.
    """
    re_part, im_part = jnp.real(data), jnp.imag(data)
    # Multiplying by a boolean mask (rather than jnp.where) keeps NaNs as NaN.
    cleaned_re = re_part * (jnp.abs(re_part) > atol)
    cleaned_im = 1j * im_part * (jnp.abs(im_part) > atol)
    return cleaned_re + cleaned_im

30 

31 

@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Quantum array: an (optionally batched) ket, bra or operator with dims.

    ``_data`` is the only pytree leaf. ``_qdims`` (Hilbert-space dims) and
    ``_bdims`` (batch dims) are declared with ``pytree_node=False``, so JAX
    treats them as static metadata rather than traced values.

    Build instances with :meth:`create`, :meth:`from_list` or
    :meth:`from_array` rather than the raw dataclass constructor.
    """

    _data: Array
    _qdims: Qdims = struct.field(pytree_node=False)
    _bdims: tuple[int] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Build a Qarray from array data, normalising its shape and dims.

        Args:
            data: array-like data, converted with ``jnp.asarray``.
            dims: Hilbert-space dims ``(row_dims, col_dims)``. Defaults to a
                single subsystem sized by the last two axes of ``data``.
            bdims: batch dims. Defaults to every axis before the last two.

        Returns:
            Qarray: new instance. Dims are validated with ``check_dims`` and
            the data is passed through ``tidy_up`` using
            ``SETTINGS["auto_tidyup_atol"]``.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        # A non-empty 1-D input is treated as a column vector (ket).
        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        # Last two axes neither square nor vector-shaped: append a trailing
        # singleton axis, i.e. treat the last axis as ket entries.
        if len(data.shape) >= 2:
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        # With explicit bdims, a single non-batch axis means a batch of kets.
        if bdims is not None:
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a Qarray from a list of Qarrays.

        The list index becomes a new leading batch axis.

        Args:
            qarr_list: Qarrays that all share the same dims and bdims.

        Returns:
            Qarray: batched Qarray with bdims ``(len(qarr_list), *bdims)``.

        Raises:
            ValueError: if the Qarrays do not all share the same dims and bdims.
        """
        data = jnp.array([qarr.data for qarr in qarr_list])

        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list (or tuple) of Qarrays.

        The nesting structure, read from the first element at every level,
        becomes the batch dims of the result.

        Args:
            qarr_arr (list | tuple): nested lists/tuples of Qarrays, or a Qarray
                (returned unchanged).

        Returns:
            Qarray: Qarray object
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        def flat(lis):
            # Flatten both lists and tuples: the bdims scan above descends into
            # any sequence, so flattening must accept the same containers.
            flat_list = []
            for element in lis:
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        """Quantum type (ket / bra / operator) derived from the dims."""
        return self._qdims.qtype

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def dims(self):
        """Hilbert-space dims as ``(row_dims, col_dims)``."""
        return self._qdims.dims

    @property
    def bdims(self):
        """Batch dims (leading axes of ``data``)."""
        return self._bdims

    @property
    def qdims(self):
        return self._qdims

    @property
    def space_dims(self):
        """Dims of the non-trivial side: rows for operators/kets, columns for bras."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # TODO: not reached for some reason
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``(*bdims, *row_dims, *col_dims)``."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        return self.data.shape

    @property
    def is_batched(self):
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index into the batch axes; only valid for batched Qarrays."""
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray."""
        new_bdims = tuple(args)

        # An empty batch has no room for the Hilbert-space axes.
        if prod(new_bdims) == 0:
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Expand Hilbert-space dims into ``(row_dims, col_dims)`` for this qtype.

        Already-paired dims (first element a list/tuple) are returned as-is.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return (tuple(space_dims), tuple([1 for _ in range(len(space_dims))]))
        elif self.qtype == Qtypes.bra:
            return (tuple([1 for _ in range(len(space_dims))]), tuple(space_dims))
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray.

        Returns:
            Qarray: reshaped Qarray.
        """
        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        assert prod(new_space_dims) == prod(current_space_dims)

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the Qarray (size of the first batch axis)."""
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Python False on dims/bdims mismatch, else a JAX boolean of full equality."""
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        # Avoid Python's `not` on the JAX result of __eq__: it forces a
        # concrete bool and fails on traced values inside jax.jit.
        equal = self.__eq__(other)
        if isinstance(equal, bool):
            return not equal
        return jnp.logical_not(equal)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: no __rmatmul__: the Qarray @ Qarray case is always handled by the
    # left operand's __matmul__, so a reflected version is never reached.

    def __mul__(self, other):
        """Qarray * Qarray is matrix multiplication; otherwise scalar scaling.

        An array ``other`` is treated as per-batch scalars (two trailing
        singleton axes are appended before broadcasting).
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        # NOTE: the Qarray * Qarray case never reaches here (handled by __mul__).
        return self.__mul__(other)

    def __neg__(self):
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar."""

        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def __add__(self, other):
        """Add a Qarray, or a scalar (interpreted as scalar * identity) to a square Qarray."""
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__add__(other)

        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract a Qarray, or a scalar (as scalar * identity) from a square Qarray."""
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__sub__(other)

        return NotImplemented

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """``a ^ b`` is the tensor product."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """``a ** n`` is the integer matrix power."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Print the header of the Qarray."""
        return self._str_header()

    def __repr__(self):
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Deep copy of the Qarray."""
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Need to override this when defininig __getattr__."""

        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Fallback exposing jax.numpy / jax.scipy functions as methods.

        Only called for attributes not found normally. ``qarr.foo(*a, **kw)``
        calls the first ``foo`` found in ``jnp``, ``jnp.linalg``, ``jsp``,
        ``jsp.linalg`` as ``foo(qarr.data, *a, **kw)``. A result with the same
        shape as the data is wrapped back into a Qarray with the same dims;
        any other result is returned as-is.

        Raises:
            NotImplementedError: if no such function exists. This is not an
                AttributeError, so ``hasattr``/``getattr(..., default)``
                propagate it rather than returning False/default.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        return dag(self)

    def to_dm(self):
        return ket2dm(self)

    def is_dm(self):
        return self.qtype == Qtypes.oper

    def is_vec(self):
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        return to_ket(self)

    def transpose(self, *args):
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        return unit(self)

    def norm(self):
        return norm(self)

    def expm(self):
        return expm(self)

    def powm(self, n):
        return powm(self, n)

    def cosm(self):
        return cosm(self)

    def sinm(self):
        return sinm(self)

    def tr(self, **kwargs):
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        return tr(self, **kwargs)

    def ptrace(self, indx):
        return ptrace(self, indx)

    def eigenstates(self):
        return eigenstates(self)

    def eigenenergies(self):
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        return collapse(self, mode=mode)

    # ----

# Array-like types known to this module: JAX arrays, NumPy ndarrays and
# Qarrays. Presumably meant for isinstance checks by callers — confirm usage.
ARRAY_TYPES = (
    Array,
    ndarray,
    Qarray,
)

539 

540# Qarray operations --------------------------------------------------------------------- 

541 

542 

def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse (reduce over) all batch dimensions of a quantum array.

    Args:
        qarr (Qarray): quantum array.
        mode (str): reduction applied over every batch axis. Only ``"sum"``
            is supported.

    Returns:
        Qarray: unbatched quantum array; ``qarr`` itself when it has no
        batch dimensions.

    Raises:
        ValueError: if ``mode`` is not supported (this used to silently
            return None).
    """
    if mode != "sum":
        raise ValueError(f"Unsupported collapse mode: {mode!r}. Supported modes: 'sum'.")

    if len(qarr.bdims) == 0:
        return qarr

    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)

558 

559 

def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems (tensor factors) of a quantum array.

    Args:
        qarr (Qarray): quantum array.
        indices: new order of the subsystems. The same permutation is applied
            to the row dims and to the column dims; batch axes stay in place.

    Returns:
        Qarray: quantum array with permuted subsystems.
    """
    order = list(indices)
    n_batch = len(qarr.bdims)
    row_dims, col_dims = qarr.dims
    n_sub = len(row_dims)

    # shaped_data axes are laid out as (*bdims, *row_dims, *col_dims).
    axes = list(range(n_batch))
    axes += [n_batch + j for j in order]
    axes += [n_batch + n_sub + j for j in order]

    permuted = qarr.shaped_data.transpose(axes)
    new_dims = (
        tuple(row_dims[j] for j in order),
        tuple(col_dims[j] for j in order),
    )

    flat = permuted.reshape(*qarr.bdims, prod(row_dims), -1)
    return Qarray.create(flat, dims=new_dims)

590 

591 

def unit(qarr: Qarray) -> Qarray:
    """Normalize the quantum array.

    Each batch element is divided by its own norm (see ``norm``).

    Args:
        qarr (Qarray): quantum array (ket, bra or operator).

    Returns:
        Normalized quantum array.
    """
    scale = jnp.asarray(qarr.norm())
    # Append two singleton axes so norms of shape bdims broadcast against data
    # of shape (*bdims, rows, cols); a scalar norm becomes shape (1, 1).
    data = qarr.data / scale[..., None, None]
    return Qarray.create(data, dims=qarr.dims)


def norm(qarr: Qarray) -> Array:
    """Norm of a quantum array, computed per batch element.

    Operators use the trace norm: the sum of ``sqrt(|eigenvalues|)`` of
    ``A @ A^dagger`` (i.e. the sum of singular values). Kets and bras use the
    Euclidean norm over their last two axes.

    Args:
        qarr (Qarray): quantum array.

    Returns:
        Array of shape ``qarr.bdims`` (a scalar for unbatched input).

    Raises:
        ValueError: if the qtype is not an operator, ket or bra.
    """
    data = qarr.data

    if qarr.qtype == Qtypes.oper:
        data_dag = qarr.dag().data
        evals, _ = jnp.linalg.eigh(data @ data_dag)
        # Reduce over the eigenvalue axis only, never over batch axes.
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1)
    if qarr.qtype in [Qtypes.ket, Qtypes.bra]:
        # Frobenius norm of the (N, 1) / (1, N) trailing block == 2-norm.
        return jnp.linalg.norm(data, axis=(-2, -1))
    raise ValueError("Unsupported qtype for norm.")

616 

617 

def tensor(*args, **kwargs) -> Qarray:
    """Tensor product.

    Args:
        *args (Qarray): tensors to take the product of, in order.
        parallel (bool): if True, pair up batch elements (with NumPy-style
            broadcasting of the batch dims) instead of taking the product
            across them:
            true: [A,B] ^ [C,D] = [A^C, B^D]
            false: [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D]
        **kwargs: any other keyword arguments are forwarded to ``jnp.kron``
            (non-parallel mode only).

    Returns:
        Tensor product of given tensors
    """
    parallel = kwargs.pop("parallel", False)

    data = args[0].data
    # dims are immutable tuples; concatenation below builds new ones.
    dims_0, dims_1 = args[0].dims
    for arg in args[1:]:
        if parallel:
            a = data
            b = arg.data
            # out[..., i, k, j, l] = a[..., i, j] * b[..., k, l]; the leading
            # axes of `out` are the broadcast of both batch shapes, so use them
            # directly instead of guessing which operand's batch is "larger".
            out = jnp.einsum("...ij,...kl->...ikjl", a, b)
            data = out.reshape(*out.shape[:-4], a.shape[-2] * b.shape[-2], -1)
        else:
            data = jnp.kron(data, arg.data, **kwargs)

        dims_0 = dims_0 + arg.dims[0]
        dims_1 = dims_1 + arg.dims[1]

    return Qarray.create(data, dims=(dims_0, dims_1))

663 

664 

def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    Args:
        qarr (Qarray): quantum array.
        **kwargs: forwarded to ``jnp.trace``. ``axis1``/``axis2`` default to
            the last two axes, so a batched input yields one trace per batch
            element.

    Returns:
        Full trace.
    """
    # pop (not get): otherwise an explicit axis1/axis2 would also remain in
    # kwargs and be passed to jnp.trace twice (TypeError).
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)

677 

678 

def trace(qarr: Qarray, **kwargs) -> Array:
    """Alias of ``tr``: full trace of a quantum array.

    Args:
        qarr (Qarray): quantum array.
        **kwargs: passed through unchanged to ``tr``.

    Returns:
        The trace, exactly as computed by ``tr``.
    """
    result = tr(qarr, **kwargs)
    return result

689 

690 

def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of raw array data.

    Thin wrapper around ``jax.scipy.linalg.expm``; keyword arguments are
    forwarded unchanged.

    Args:
        data: matrix to exponentiate.

    Returns:
        matrix exponential
    """
    linalg = jsp.linalg
    return linalg.expm(data, **kwargs)

698 

699 

def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a quantum array.

    Args:
        qarr (Qarray): quantum array.
        **kwargs: forwarded to ``expm_data``.

    Returns:
        Qarray: ``exp(qarr)`` with the same dims as ``qarr``.
    """
    return Qarray.create(expm_data(qarr.data, **kwargs), dims=qarr.dims)

709 

710 

def powm(qarr: Qarray, n: Union[int, float]) -> Qarray:
    """Matrix power.

    Integer powers use ``jnp.linalg.matrix_power``. Non-integer powers use
    the eigendecomposition ``V @ diag(lambda**n) @ V^-1`` and require every
    eigenvalue to compare ``>= 0``.

    Args:
        qarr (Qarray): quantum array (square, optionally batched).
        n (int | float): power.

    Returns:
        matrix power, as a Qarray with the same dims.

    Raises:
        ValueError: for a non-integer ``n`` when some eigenvalue does not
            compare ``>= 0``.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Scale column j of V by lambda_j**n (V @ diag(lambda**n)). The
        # [..., None, :] keeps this correct for batched inputs, where evalues
        # is (*bdims, N) and evectors is (*bdims, N, N).
        scaled = evectors * jnp.pow(evalues, n)[..., None, :]
        data_res = scaled @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)

733 

734 

def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of raw array data, ``cos(A) = (exp(iA) + exp(-iA)) / 2``.

    Args:
        data: matrix.
        **kwargs: forwarded to both ``expm_data`` calls (they were previously
            accepted but silently ignored).

    Returns:
        matrix cosine
    """
    return (expm_data(1j * data, **kwargs) + expm_data(-1j * data, **kwargs)) / 2

742 

743 

def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a quantum array.

    Args:
        qarr (Qarray): quantum array.

    Returns:
        Qarray: ``cos(qarr)`` with the same dims as ``qarr``.
    """
    return Qarray.create(cosm_data(qarr.data), dims=qarr.dims)

756 

757 

def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of raw array data, ``sin(A) = (exp(iA) - exp(-iA)) / (2i)``.

    Args:
        data: matrix.
        **kwargs: forwarded to both ``expm_data`` calls (they were previously
            accepted but silently ignored).

    Returns:
        matrix sine
    """
    return (expm_data(1j * data, **kwargs) - expm_data(-1j * data, **kwargs)) / (2j)

768 

769 

def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a quantum array.

    Args:
        qarr (Qarray): quantum array.

    Returns:
        Qarray: ``sin(qarr)`` with the same dims as ``qarr``.
    """
    return Qarray.create(sinm_data(qarr.data), dims=qarr.dims)

774 

775 

def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Return a copy of an unbatched Qarray with every off-diagonal entry set to 0.

    Args:
        qarr (Qarray): unbatched quantum array.

    Returns:
        Qarray: diagonal matrix built from the diagonal of ``qarr``, same dims.

    Raises:
        ValueError: if ``qarr`` is batched.
    """
    if qarr.bdims:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    diagonal = jnp.diagonal(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)

783 

784 

def to_ket(qarr: Qarray) -> Qarray:
    """Return the ket form of a ket or bra.

    A ket is returned unchanged; a bra is conjugate-transposed.

    Raises:
        ValueError: for any other qtype (e.g. operators).
    """
    qtype = qarr.qtype
    if qtype == Qtypes.bra:
        return qarr.dag()
    if qtype != Qtypes.ket:
        raise ValueError("Can only get ket from a ket or bra.")
    return qarr

792 

793 

def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenvalues and eigenstates of a Hermitian quantum array.

    Uses ``jnp.linalg.eigh``, so ``qarr`` is assumed Hermitian.

    Args:
        qarr (Qarray): (optionally batched) Hermitian operator.

    Returns:
        ``(evals, evecs)``: ``evals`` has shape ``(*bdims, N)``, sorted in
        ascending order along the last axis. ``evecs`` is a ket Qarray with
        batch dims ``(*bdims, N)``; indexing the last batch axis with ``i``
        gives the eigenvector for ``evals[..., i]``.
    """
    evals, evecs = jnp.linalg.eigh(qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # eigh returns eigenvector i as the COLUMN evecs[..., :, i], but Qarray
    # batches over leading axes. Move eigenvectors onto rows first; otherwise
    # batch element i would be row i (the i-th component of every eigenvector).
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs

819 

820 

def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a Hermitian quantum array.

    Args:
        qarr (Qarray): quantum array, assumed Hermitian (the values come from
            ``jnp.linalg.eigvalsh``).

    Returns:
        Eigenvalues along the last axis, as returned by ``eigvalsh``.
    """
    return jnp.linalg.eigvalsh(qarr.data)

833 

834 

835# More quantum specific ----------------------------------------------------- 

836 

837 

def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial trace that keeps a single subsystem.

    Kets and bras are first converted to density matrices with ``ket2dm``.

    Args:
        qarr (Qarray): quantum array (ket, bra or operator), optionally batched.
        indx: index of the subsystem to keep; every other subsystem is traced out.

    Returns:
        Qarray: reduced density matrix of subsystem ``indx``.

    TODO: Fix weird tracing errors that arise with reshape
    """
    rho_q = ket2dm(qarr)
    n_sub = len(rho_q.dims[0])
    n_batch = len(rho_q.bdims)

    # Reorder to (*batch, keep_row, keep_col, other_row_1, other_col_1, ...);
    # shaped_data is laid out as (*batch, *row_dims, *col_dims).
    pair_order = [indx, indx + n_sub]
    for j in (k for k in range(n_sub) if k != indx):
        pair_order += [j, j + n_sub]
    axes = list(range(n_batch)) + [n_batch + k for k in pair_order]

    rho = rho_q.shaped_data.transpose(axes)
    # Each trace removes the first remaining (row, col) pair after the kept one.
    for _ in range(n_sub - 1):
        rho = jnp.trace(rho, axis1=n_batch + 2, axis2=n_batch + 3)

    return Qarray.create(rho)

874 

875 

def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose (Hermitian adjoint) of a quantum array.

    Row and column dims are swapped. The data is transformed with
    ``dag_data``, which conjugates and swaps only the last two axes.

    Args:
        qarr (Qarray): quantum array.

    Returns:
        conjugate transpose of qarr
    """
    row_dims, col_dims = qarr.dims
    return Qarray.create(dag_data(qarr.data), dims=(col_dims, row_dims))

890 

891 

def dag_data(arr: Array) -> Array:
    """Conjugate transpose of raw array data.

    Args:
        arr: array. A 1-D input is only conjugated. Otherwise the last two
            axes are swapped after conjugation, leaving leading (batch) axes
            untouched.

    Returns:
        conjugated array, with the last two axes swapped when ndim >= 2.
    """
    conjugated = jnp.conj(arr)
    # TODO: revisit the 1-D case...
    if len(arr.shape) == 1:
        return conjugated
    return jnp.swapaxes(conjugated, -2, -1)

908 

909 

def ket2dm(qarr: Qarray) -> Qarray:
    """Turn a ket (or bra) into the density matrix ``|psi><psi|``.

    Operators are returned unchanged. A bra is first conjugate-transposed
    into the corresponding ket.

    Args:
        qarr (Qarray): ket, bra or operator.

    Returns:
        Density matrix
    """
    qtype = qarr.qtype
    if qtype == Qtypes.oper:
        return qarr

    ket = qarr.dag() if qtype == Qtypes.bra else qarr
    return ket @ ket.dag()

928 

929 

930# Data level operations ---- 

931 

932 

def is_dm_data(data: Array) -> bool:
    """Check whether data is operator-shaped (square in its last two axes).

    Only the shape is inspected; Hermiticity, trace and positivity are not
    checked.

    Args:
        data: array with at least two axes.

    Returns:
        True if the last two axes have equal length.
    """
    n_rows = data.shape[-2]
    n_cols = data.shape[-1]
    return n_rows == n_cols

942 

943 

def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of raw array data.

    Thin wrapper around ``jnp.linalg.matrix_power``.

    Args:
        data: square matrix.
        n: integer power.

    Returns:
        ``data`` raised to the ``n``-th matrix power.
    """
    matrix_power = jnp.linalg.matrix_power
    result = matrix_power(data, n)
    return result