Coverage for jaxquantum/core/qarray.py: 75%

452 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-19 17:34 +0000

1"""QArray.""" 

2 

3from __future__ import annotations 

4 

5from flax import struct 

6from jax import Array, config 

7from typing import List, Union 

8 

9 

10from math import prod 

11from copy import deepcopy 

12from numpy import ndarray 

13import jax.numpy as jnp 

14import jax.scipy as jsp 

15 

16from jaxquantum.core.settings import SETTINGS 

17from jaxquantum.utils.utils import robust_isscalar 

18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims 

19 

# Switch JAX to 64-bit mode when this module is imported (config flag
# "jax_enable_x64"); note that this is a global JAX setting, not a
# module-local one.
config.update("jax_enable_x64", True)

21 

22 

def tidy_up(data, atol):
    """Zero out real and imaginary parts whose magnitude does not exceed ``atol``.

    The real and imaginary components are thresholded independently, so an
    entry can keep one component while losing the other.

    Args:
        data: array to clean up.
        atol: absolute tolerance; a component is kept only if its absolute
            value is strictly greater than this.

    Returns:
        Array with the small components replaced by zero.  The result is
        complex-valued because the imaginary part is always re-attached.
    """
    real_part, imag_part = jnp.real(data), jnp.imag(data)
    keep_real = jnp.abs(real_part) > atol
    keep_imag = jnp.abs(imag_part) > atol
    # Multiplying by the boolean masks zeroes the rejected components.
    return real_part * keep_real + 1j * imag_part * keep_imag

30 

31 

@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Array of quantum objects (kets, bras or operators) with optional batch axes.

    Instances are normally built with :meth:`create`, :meth:`from_list` or
    :meth:`from_array`, which normalise the data shape and the dimension
    metadata.  With data prepared by :meth:`create`, the last two axes of
    the data are the matrix axes of each quantum object and any leading
    axes are batch axes whose sizes are stored in ``bdims``.
    """

    # Raw numerical data; the only field registered as a pytree node.
    _data: Array
    # Hilbert-space dimension metadata; excluded from the pytree (static).
    _qdims: Qdims = struct.field(pytree_node=False)
    # Sizes of the leading batch axes; excluded from the pytree (static).
    _bdims: tuple[int, ...] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Build a Qarray from raw data, inferring shapes and dimensions.

        Args:
            data: array-like data; converted with ``jnp.asarray``.
            dims: either a pair ``(row_dims, col_dims)`` of sequences, or a
                flat sequence of Hilbert-space dimensions that is expanded
                according to the data shape.  Inferred from the last two
                data axes when omitted.
            bdims: batch dimensions.  Inferred as ``data.shape[:-2]`` when
                omitted.

        Returns:
            Qarray: the new quantum array, after ``check_dims`` validation
            and an automatic ``tidy_up`` of the data.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        # A non-empty 1D array is interpreted as a single column vector.
        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        if len(data.shape) >= 2:
            # Trailing axes that are neither square nor vector-shaped: treat
            # the last axis as the vector axis and append a singleton axis.
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        if bdims is not None:
            # Exactly one non-batch axis left: append a singleton column axis.
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        if not isinstance(dims[0], (list, tuple)):
            # This handles the case where only the hilbert space dimensions are sent in.
            if data.shape[-1] == 1:
                dims = (tuple(dims), tuple([1 for _ in dims]))
            elif data.shape[-2] == 1:
                dims = (tuple([1 for _ in dims]), tuple(dims))
            else:
                dims = (tuple(dims), tuple(dims))
        else:
            dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a Qarray from a list of Qarrays.

        The list becomes a new leading batch axis.

        Args:
            qarr_list (List[Qarray]): Qarrays sharing the same dims and bdims.

        Returns:
            Qarray: batched Qarray with ``bdims == (len(qarr_list),) + bdims``.

        Raises:
            ValueError: if the Qarrays do not all share the same dims and bdims.
        """

        data = jnp.array([qarr.data for qarr in qarr_list])

        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list (or tuple) of Qarrays.

        Args:
            qarr_arr (list): nested list/tuple of Qarrays; a Qarray is
                returned unchanged.

        Returns:
            Qarray: Qarray object whose batch dimensions follow the nesting
            of ``qarr_arr``.
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        # Read the batch shape by walking down the first element of each level.
        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        def flat(lis):
            """Flatten arbitrarily nested lists/tuples into one list."""
            flat_list = []
            for element in lis:
                # Tuples are flattened too, matching the batch-shape walk
                # above, which accepts any sequence.
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        """Quantum type reported by the stored Qdims."""
        return self._qdims.qtype

    @property
    def dtype(self):
        """dtype of the underlying data array."""
        return self._data.dtype

    @property
    def dims(self):
        """Dimension pair reported by the stored Qdims."""
        return self._qdims.dims

    @property
    def bdims(self):
        """Batch dimensions."""
        return self._bdims

    @property
    def qdims(self):
        """The stored Qdims object."""
        return self._qdims

    @property
    def space_dims(self):
        """Hilbert-space dimensions: ``dims[0]`` for operators and kets, ``dims[1]`` for bras."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # TODO: not reached for some reason
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        """The underlying data array."""
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``bdims + dims[0] + dims[1]`` (one axis per subsystem)."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        """Shape of the underlying data array."""
        return self.data.shape

    @property
    def is_batched(self):
        """True when there is at least one batch dimension."""
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index into the data of a batched Qarray, keeping the same dims.

        Raises:
            ValueError: if the Qarray has no batch dimensions.
        """
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray.

        Args:
            *args: the new batch dimensions.

        Returns:
            Qarray: Qarray with the same dims and the new batch dimensions.
        """
        new_bdims = tuple(args)

        if prod(new_bdims) == 0:
            # Empty batch: the data is reshaped to the batch shape alone.
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Expand flat Hilbert-space dimensions into a ``(row_dims, col_dims)`` pair.

        An input that is already a pair of sequences is returned unchanged.
        For operators and kets the given dims become the first entry and
        the second entry is all ones; for bras it is the other way round.

        Raises:
            ValueError: if the qtype is not oper, ket or bra.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return (tuple(space_dims), tuple([1 for _ in range(len(space_dims))]))
        elif self.qtype == Qtypes.bra:
            return (tuple([1 for _ in range(len(space_dims))]), tuple(space_dims))
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray.

        Returns:
            Qarray: reshaped Qarray.
        """

        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        # The total Hilbert-space size must be preserved.
        assert prod(new_space_dims) == prod(current_space_dims)

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape.

        The data is passed through ``jnp.resize`` and wrapped with the
        current dims.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the Qarray.

        Raises:
            ValueError: if the Qarray has no batch dimensions.
        """
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Compare dims, bdims and all data entries with another Qarray.

        Returns ``False`` when dims or bdims differ, otherwise the result
        of ``jnp.all`` over the elementwise data comparison.

        Raises:
            ValueError: if ``other`` is not a Qarray.
        """
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        """Logical negation of :meth:`__eq__`."""
        return not self.__eq__(other)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        """Matrix product with another Qarray; dims are combined via ``Qdims @ Qdims``."""
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: not possible to reach this.
    # def __rmatmul__(self, other):
    #     if not isinstance(other, Qarray):
    #         return NotImplemented

    #     _qdims_new = other._qdims @ self._qdims
    #     return Qarray.create(
    #         other.data @ self.data,
    #         dims=_qdims_new.dims,
    #     )

    def __mul__(self, other):
        """Multiply by a Qarray (matrix product) or scale by a scalar/array.

        A non-scalar ``other`` gets two trailing singleton axes so that it
        broadcasts against the matrix axes of the data.
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        """Right multiplication; same as :meth:`__mul__`."""
        # NOTE: not possible to reach this.
        # if isinstance(other, Qarray):
        #     return self.__rmatmul__(other)

        return self.__mul__(other)

    def __neg__(self):
        """Negation, implemented as multiplication by -1."""
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar.

        Raises:
            ValueError: if ``other`` is a Qarray.
        """

        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def __add__(self, other):
        """Add another Qarray, or a scalar/array times the identity.

        A scalar equal to 0 returns a copy.  For square data, any other
        non-Qarray ``other`` is multiplied by the identity matrix before
        being added.  Otherwise ``NotImplemented`` is returned.

        Raises:
            ValueError: if ``other`` is a Qarray with different dims.
        """
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__add__(other)

        return NotImplemented

    def __radd__(self, other):
        """Right addition; same as :meth:`__add__`."""
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract another Qarray, or a scalar/array times the identity.

        Mirrors :meth:`__add__`: a scalar equal to 0 returns a copy, and
        for square data any other non-Qarray ``other`` is multiplied by
        the identity matrix first.

        Raises:
            ValueError: if ``other`` is a Qarray with different dims.
        """
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__sub__(other)

        return NotImplemented

    def __rsub__(self, other):
        """Right subtraction: ``other - self`` computed as ``(-self) + other``."""
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """Tensor product ``self ^ other`` with another Qarray."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        """Tensor product ``other ^ self`` with another Qarray."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """Integer matrix power; non-integer exponents return ``NotImplemented``."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        """One-line summary with dims, bdims, data shape and qtype."""
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        """Header line followed by the data."""
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Return the header string of the Qarray."""
        return self._str_header()

    def __repr__(self):
        """Same as :meth:`__str__`."""
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Deep copy of this Qarray (delegates to :meth:`__deepcopy__`)."""
        # return Qarray.create(deepcopy(self.data), dims=self.dims)
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Need to override this when defining __getattr__."""

        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Fallback for unknown attributes: look the name up in jax modules.

        Names starting with ``__`` yield a callable that returns
        ``NotImplemented``.  Other names are searched for in ``jnp``,
        ``jnp.linalg``, ``jsp`` and ``jsp.linalg`` (in that order); the
        function found is wrapped so that it is called on this Qarray's
        data.  A result with the same shape as the data is wrapped back
        into a Qarray, anything else is returned as is.

        Raises:
            NotImplementedError: if no module provides ``method_name``.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        # NOTE(review): this raises NotImplementedError rather than
        # AttributeError, so hasattr() on a missing name propagates the error
        # instead of returning False.
        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        """Conjugate transpose; see the module-level ``dag``."""
        return dag(self)

    def to_dm(self):
        """Convert to a density matrix; see the module-level ``ket2dm``."""
        return ket2dm(self)

    def is_dm(self):
        """True when the qtype is ``Qtypes.oper``."""
        return self.qtype == Qtypes.oper

    def is_vec(self):
        """True when the qtype is ``Qtypes.ket`` or ``Qtypes.bra``."""
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        """Convert to a ket; see the module-level ``to_ket``."""
        return to_ket(self)

    def transpose(self, *args):
        """Permute subsystems; arguments are forwarded to the module-level ``transpose``."""
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        """Drop off-diagonal entries; see the module-level ``keep_only_diag_elements``."""
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        """Normalised copy; see the module-level ``unit``."""
        return unit(self)

    def norm(self):
        """Norm; see the module-level ``norm``."""
        return norm(self)

    def expm(self):
        """Matrix exponential; see the module-level ``expm``."""
        return expm(self)

    def powm(self, n):
        """Matrix power; see the module-level ``powm``."""
        return powm(self, n)

    def cosm(self):
        """Matrix cosine; see the module-level ``cosm``."""
        return cosm(self)

    def sinm(self):
        """Matrix sine; see the module-level ``sinm``."""
        return sinm(self)

    def tr(self, **kwargs):
        """Trace; see the module-level ``tr``."""
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        """Alias of :meth:`tr`."""
        return tr(self, **kwargs)

    def ptrace(self, indx):
        """Partial trace keeping subsystem ``indx``; see the module-level ``ptrace``."""
        return ptrace(self, indx)

    def eigenstates(self):
        """Eigenvalues and eigenstates; see the module-level ``eigenstates``."""
        return eigenstates(self)

    def eigenenergies(self):
        """Eigenvalues; see the module-level ``eigenenergies``."""
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        """Collapse the batch dimensions; see the module-level ``collapse``."""
        return collapse(self, mode=mode)

    # ----

545 

546 

# Tuple of the array-like types known to this module: JAX arrays, NumPy
# arrays and Qarrays.  Presumably intended for isinstance checks; no use is
# visible in this file.
ARRAY_TYPES = (Array, ndarray, Qarray)

548 

549# Qarray operations --------------------------------------------------------------------- 

550 

def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join several Qarrays along one axis of their underlying data.

    Qarrays whose data has length zero are dropped before joining.  The
    dims of the result are taken from the first remaining Qarray.

    Args:
        qarr_list (List[Qarray]): Qarrays to join.
        axis (int): data axis along which to join. Default is 0.

    Returns:
        Qarray: the joined Qarray, or ``Qarray.from_list([])`` when no
        non-empty Qarray is left.
    """
    kept = []
    for candidate in qarr_list:
        if len(candidate.data) != 0:
            kept.append(candidate)

    if len(kept) == 0:
        return Qarray.from_list([])

    joined = jnp.concatenate([item.data for item in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)

573 

574 

def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse the batch dimensions of a Qarray.

    Args:
        qarr (Qarray): quantum array
        mode (str): how to combine the batch entries.  Only ``"sum"`` is
            implemented: the data is summed over every batch axis.

    Returns:
        Collapsed quantum array.  A Qarray without batch dimensions is
        returned unchanged.

    Raises:
        ValueError: if ``mode`` is not supported (previously this case
            silently returned ``None``).
    """
    if mode != "sum":
        raise ValueError(
            f"Unsupported collapse mode {mode!r}; only 'sum' is implemented."
        )

    if len(qarr.bdims) == 0:
        return qarr

    # The batch axes are the leading axes of the data.
    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)

590 

591 

def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems of a quantum array.

    Args:
        qarr (Qarray): quantum array
        indices (List[int]): new order of the subsystems; the same
            permutation is applied to the row and the column subsystem axes.

    Returns:
        transposed Qarray, with its dims reordered accordingly
    """
    order = list(indices)
    dims = qarr.dims
    n_batch = len(qarr.bdims)
    n_row_axes = len(dims[0])

    # shaped_data has axes (batch..., row subsystems..., column subsystems...);
    # batch axes stay in place, both subsystem groups get the same permutation.
    axes = list(range(n_batch))
    axes += [n_batch + k for k in order]
    axes += [n_batch + n_row_axes + k for k in order]

    permuted = qarr.shaped_data.transpose(axes)
    new_dims = (
        tuple(dims[0][k] for k in order),
        tuple(dims[1][k] for k in order),
    )

    # Flatten the subsystem axes back into the two matrix axes.
    flattened = permuted.reshape(*qarr.bdims, prod(dims[0]), -1)
    return Qarray.create(flattened, dims=new_dims)

622 

623 

def unit(qarr: Qarray) -> Qarray:
    """Return the quantum array divided by its norm.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Normalized quantum array with the same dims
    """
    scale = qarr.norm()
    return Qarray.create(qarr.data / scale, dims=qarr.dims)

636 

637 

def norm(qarr: Qarray) -> float:
    """Norm of a quantum array.

    For operators this is the sum of the square roots of the absolute
    eigenvalues of ``data @ data†``; for kets and bras it is
    ``jnp.linalg.norm`` of the data.  Other qtypes yield ``None``.

    NOTE(review): both branches reduce over the whole data array, so a
    batched Qarray appears to get a single number rather than one norm per
    batch entry -- confirm this is intended.

    Args:
        qarr (Qarray): quantum array

    Returns:
        the norm
    """
    kind = qarr.qtype

    if kind == Qtypes.oper:
        product = qarr.data @ qarr.dag().data
        eigenvalues, _ = jnp.linalg.eigh(product)
        return jnp.sum(jnp.sqrt(jnp.abs(eigenvalues)))

    if kind in [Qtypes.ket, Qtypes.bra]:
        return jnp.linalg.norm(qarr.data)

    return None

648 

649 

def tensor(*args, **kwargs) -> Qarray:
    """Tensor product.

    Args:
        *args (Qarray): tensors to take the product of
        parallel (bool): if True, use parallel einsum for tensor product
            true: [A,B] ^ [C,D] = [A^C, B^D]
            false (default): [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D]
        **kwargs: remaining keyword arguments are forwarded to ``jnp.kron``
            (non-parallel mode only).

    Returns:
        Tensor product of given tensors

    """
    parallel = kwargs.pop("parallel", False)

    head = args[0]
    data = head.data
    head_dims = deepcopy(head.dims)
    row_dims = head_dims[0]
    col_dims = head_dims[1]

    for factor in args[1:]:
        if not parallel:
            data = jnp.kron(data, factor.data, **kwargs)
        else:
            left, right = data, factor.data
            left_batch, right_batch = left.shape[:-2], right.shape[:-2]

            # Pick the batch shape of the operand with more axes; on a tie,
            # the one with the larger batch size (the right one if equal).
            if len(left.shape) != len(right.shape):
                if len(left.shape) > len(right.shape):
                    batch_dim = left_batch
                else:
                    batch_dim = right_batch
            elif prod(left_batch) > prod(right_batch):
                batch_dim = left_batch
            else:
                batch_dim = right_batch

            paired = jnp.einsum("...ij,...kl->...ikjl", left, right)
            data = paired.reshape(*batch_dim, left.shape[-2] * right.shape[-2], -1)

        row_dims = row_dims + factor.dims[0]
        col_dims = col_dims + factor.dims[1]

    return Qarray.create(data, dims=(row_dims, col_dims))

695 

696 

def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    By default the trace is taken over the last two axes of the data, so
    leading (batch) axes are preserved.

    Args:
        qarr (Qarray): quantum array
        **kwargs: optional ``axis1`` (default -2) and ``axis2`` (default -1)
            selecting the traced axes; anything else is forwarded to
            ``jnp.trace``.

    Returns:
        Full trace.
    """
    # pop (not get) so axis1/axis2 are not passed a second time via **kwargs,
    # which would raise "got multiple values for keyword argument".
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)

709 

710 

def trace(qarr: Qarray, **kwargs) -> Array:
    """Full trace; alias of ``tr``.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded unchanged to ``tr``.

    Returns:
        Full trace.
    """
    total = tr(qarr, **kwargs)
    return total

721 

722 

def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of a raw array.

    Args:
        data: matrix
        **kwargs: forwarded to ``jax.scipy.linalg.expm``.

    Returns:
        matrix exponential
    """
    exponential = jsp.linalg.expm(data, **kwargs)
    return exponential

730 

731 

def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a quantum array.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``expm_data``.

    Returns:
        matrix exponential, as a Qarray with the same dims
    """
    exponentiated = expm_data(qarr.data, **kwargs)
    return Qarray.create(exponentiated, dims=qarr.dims)

741 

742 

def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power.

    Integer powers use ``jnp.linalg.matrix_power``.  Any other power is
    computed from the eigendecomposition ``V diag(w**n) V^-1``.

    Args:
        qarr (Qarray): quantum array
        n (int | float): power
        clip_eigvals (bool): clip eigenvalues to always be able to compute
            non-integer powers

    Returns:
        matrix power

    Raises:
        ValueError: for a non-integer power when ``clip_eigvals`` is False
            and some eigenvalue is negative.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        elif not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Eigenvector k is column k, so the powered eigenvalues must scale
        # the columns.  The explicit singleton row axis keeps this aligned
        # when there are leading batch axes as well.
        scaled_evectors = evectors * jnp.pow(evalues, n)[..., None, :]
        data_res = scaled_evectors @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)

770 

771 

def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of a raw array, computed as ``(e^{iA} + e^{-iA}) / 2``.

    Args:
        data: matrix
        **kwargs: accepted but not used.

    Returns:
        matrix cosine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward + backward) / 2

779 

780 

def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix cosine, as a Qarray with the same dims
    """
    cosine = cosm_data(qarr.data)
    return Qarray.create(cosine, dims=qarr.dims)

793 

794 

def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of a raw array, computed as ``(e^{iA} - e^{-iA}) / (2i)``.

    Args:
        data: matrix
        **kwargs: accepted but not used.

    Returns:
        matrix sine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward - backward) / (2j)

805 

806 

def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix sine, as a Qarray with the same dims
    """
    sine = sinm_data(qarr.data)
    return Qarray.create(sine, dims=qarr.dims)

811 

812 

def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Return a Qarray keeping only the diagonal of the data.

    Args:
        qarr (Qarray): non-batched quantum array

    Returns:
        Qarray with the same dims whose data is ``diag(diag(data))``

    Raises:
        ValueError: if the Qarray has batch dimensions.
    """
    if len(qarr.bdims) > 0:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    # First diag extracts the diagonal, the second rebuilds a matrix from it.
    diagonal = jnp.diag(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)

820 

821 

def to_ket(qarr: Qarray) -> Qarray:
    """Return the ket form of a ket or bra.

    Args:
        qarr (Qarray): ket or bra

    Returns:
        the input itself for a ket, its conjugate transpose for a bra

    Raises:
        ValueError: if the input is neither a ket nor a bra.
    """
    kind = qarr.qtype
    if kind == Qtypes.ket:
        return qarr
    if kind == Qtypes.bra:
        return qarr.dag()
    raise ValueError("Can only get ket from a ket or bra.")

829 

830 

def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenstates of a quantum array.

    Uses ``jnp.linalg.eigh`` on the data and orders the eigenpairs by
    ascending eigenvalue.

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues and eigenstates; the eigenstates are a batched Qarray
        of kets whose last batch axis enumerates the eigenstates, so that
        ``evecs[..., k]`` pairs with ``evals[..., k]``.
    """

    evals, evecs = jnp.linalg.eigh(qarr.data)
    order = jnp.argsort(evals, axis=-1)

    ket_dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, order, axis=-1)
    evecs = jnp.take_along_axis(evecs, order[..., None, :], axis=-1)

    # eigh stores eigenvector k in column k (evecs[..., :, k]).  The batch
    # axis built below comes from the second-to-last axis, so the
    # eigenvectors are moved into rows first; otherwise indexing the result
    # would return rows of the eigenvector matrix instead of eigenvectors.
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=ket_dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs

856 

857 

def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a quantum array, via ``jnp.linalg.eigvalsh``.

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues
    """
    return jnp.linalg.eigvalsh(qarr.data)

870 

871 

872# More quantum specific ----------------------------------------------------- 

873 

874 

def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial Trace.

    Kets and bras are first converted to density matrices with ``ket2dm``.

    Args:
        qarr: density matrix (or ket/bra)
        indx: index of quantum object to keep, rest will be partial traced out

    Returns:
        partial traced out density matrix

    TODO: Fix weird tracing errors that arise with reshape
    """
    dm = ket2dm(qarr)
    rho = dm.shaped_data

    n_sub = len(dm.dims[0])
    n_batch = len(dm.bdims)

    # Subsystem axis order: the (row, column) pair of the kept subsystem
    # first, then the (row, column) pair of every other subsystem.
    order = [indx, indx + n_sub]
    for k in range(n_sub):
        if k != indx:
            order.extend((k, k + n_sub))

    # Batch axes stay in front; subsystem axes are offset past them.
    axes = list(range(n_batch)) + [n_batch + k for k in order]
    rho = rho.transpose(axes)

    # Each trace removes the first remaining pair after the kept subsystem.
    for _ in range(n_sub - 1):
        rho = jnp.trace(rho, axis1=n_batch + 2, axis2=n_batch + 3)

    return Qarray.create(rho)

911 

912 

def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose.

    Args:
        qarr (Qarray): quantum array

    Returns:
        conjugate transpose of qarr, with the two entries of dims swapped
    """
    swapped_dims = qarr.dims[::-1]
    return Qarray.create(dag_data(qarr.data), dims=swapped_dims)

927 

928 

def dag_data(arr: Array) -> Array:
    """Conjugate transpose of a raw array.

    Args:
        arr: operator

    Returns:
        conjugate of arr with its last two axes exchanged; a 1D array is
        only conjugated
    """
    conjugated = jnp.conj(arr)

    # TODO: revisit this case...
    if len(arr.shape) == 1:
        return conjugated

    # Only the last two axes are exchanged, so leading batch axes are kept.
    return jnp.moveaxis(conjugated, -1, -2)

945 

946 

def ket2dm(qarr: Qarray) -> Qarray:
    """Turns ket into density matrix.
    Does nothing if already operator.

    A bra is first converted to a ket through its conjugate transpose.

    Args:
        qarr (Qarray): qarr

    Returns:
        Density matrix
    """
    kind = qarr.qtype

    if kind == Qtypes.oper:
        return qarr

    ket = qarr.dag() if kind == Qtypes.bra else qarr
    return ket @ ket.dag()

965 

966 

967# Data level operations ---- 

968 

969 

def is_dm_data(data: Array) -> bool:
    """Check if data is a density matrix.

    Only the shape is inspected: the last two axes must have equal size.

    Args:
        data: matrix
    Returns:
        True if the last two axes of data have the same size
    """
    rows, cols = data.shape[-2], data.shape[-1]
    return rows == cols

979 

980 

def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of a raw array, via ``jnp.linalg.matrix_power``.

    Args:
        data: matrix
        n: power

    Returns:
        matrix power
    """
    powered = jnp.linalg.matrix_power(data, n)
    return powered