Coverage for jaxquantum/core/qarray.py: 76%

461 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 19:55 +0000

1"""QArray.""" 

2 

3from __future__ import annotations 

4 

5from flax import struct 

6from jax import Array, config, vmap 

7from typing import List, Union 

8 

9 

10from math import prod 

11from copy import deepcopy 

12from numpy import ndarray 

13import jax.numpy as jnp 

14import jax.scipy as jsp 

15 

16from jaxquantum.core.settings import SETTINGS 

17from jaxquantum.utils.utils import robust_isscalar 

18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims 

19 

# Enable 64-bit (double-precision) float/complex dtypes in JAX, which uses
# 32-bit types by default. This runs on import and is a process-wide setting
# that affects all JAX code, not only this module.
config.update("jax_enable_x64", True)

21 

22 

def tidy_up(data, atol):
    """Zero out numerically negligible real and imaginary parts.

    The real and imaginary parts are filtered independently: each component
    whose absolute value is not strictly greater than ``atol`` is set to 0.
    The result is always complex (``real + 1j * imag``).

    Args:
        data: array-like input (real or complex).
        atol: absolute tolerance below which a component is dropped.

    Returns:
        Complex array with the same shape as ``data``.
    """
    real_part = jnp.real(data)
    imag_part = jnp.imag(data)

    keep_real = jnp.abs(real_part) > atol
    keep_imag = jnp.abs(imag_part) > atol

    cleaned_real = real_part * keep_real
    cleaned_imag = imag_part * keep_imag
    return cleaned_real + 1j * cleaned_imag

30 

31 

@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Quantum array: a (possibly batched) ket, bra, or operator.

    ``_data`` holds the matrix representation with shape
    ``bdims + (rows, cols)``; ``_qdims`` holds the tensor-product structure of
    the row and column spaces; ``_bdims`` holds the batch dimensions. Only
    ``_data`` is a pytree leaf -- ``_qdims`` and ``_bdims`` are static
    metadata (``pytree_node=False``).

    Build instances with ``Qarray.create`` (or ``from_list`` / ``from_array``),
    which normalises the data shape and validates the dims.
    """

    _data: Array
    _qdims: Qdims = struct.field(pytree_node=False)
    _bdims: tuple[int] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Create a Qarray from array-like data.

        Args:
            data: array-like data. A non-empty 1D array is treated as a ket
                (column vector). If the last two axes are non-square and
                neither has length 1, the last axis is treated as a ket and
                all preceding axes as batch axes.
            dims: either full dims ``(row_dims, col_dims)`` or only the
                Hilbert-space dims, in which case the ket/bra/operator layout
                is inferred from the data shape. Defaults to a single
                subsystem matching the last two data axes.
            bdims: batch dimensions. Defaults to ``data.shape[:-2]``. If the
                data has exactly one axis beyond ``bdims``, it is treated as a
                ket and a trailing column axis is added.

        Returns:
            Qarray: new quantum array. Entries are passed through ``tidy_up``
            with ``SETTINGS["auto_tidyup_atol"]``.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        if len(data.shape) >= 2:
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        if bdims is not None:
            # Normalise to a tuple: bdims is concatenated with tuples (e.g. in
            # shaped_data and from_list) and stored as static pytree metadata,
            # which should be hashable -- a list would break both.
            bdims = tuple(bdims)
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        if not isinstance(dims[0], (list, tuple)):
            # This handles the case where only the hilbert space dimensions are sent in.
            if data.shape[-1] == 1:
                dims = (tuple(dims), tuple([1 for _ in dims]))
            elif data.shape[-2] == 1:
                dims = (tuple([1 for _ in dims]), tuple(dims))
            else:
                dims = (tuple(dims), tuple(dims))
        else:
            dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a Qarray by stacking a list of Qarrays along a new batch axis.

        Args:
            qarr_list: Qarrays that all share the same dims and bdims.

        Returns:
            Qarray: stacked array with
            ``bdims == (len(qarr_list),) + qarr_list[0].bdims``. An empty list
            gives an empty Qarray with dims ``((), ())``.

        Raises:
            ValueError: if the Qarrays do not all share the same dims and bdims.
        """

        data = jnp.array([qarr.data for qarr in qarr_list])

        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list (or tuple) of Qarrays.

        The nesting shape (the length of each level, following the first
        element at every level) becomes the batch dimensions of the result.
        NOTE(review): the inner Qarrays are presumably expected to be
        unbatched, since only the nesting shape is passed to ``reshape_bdims``.

        Args:
            qarr_arr (list): nested list/tuple of Qarrays. A Qarray is
                returned unchanged.

        Returns:
            Qarray: Qarray object
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        # Walk down the first element of each level to collect the nesting shape.
        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        def flat(lis):
            # Flatten nested lists/tuples in row-major order. Tuples are
            # flattened too, matching the sequence walk above.
            flat_list = []
            for element in lis:
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        """Quantum type (e.g. ket, bra, oper) as reported by the Qdims."""
        return self._qdims.qtype

    @property
    def dtype(self):
        """dtype of the underlying data array."""
        return self._data.dtype

    @property
    def dims(self):
        """Hilbert-space dims ``(row_dims, col_dims)``."""
        return self._qdims.dims

    @property
    def bdims(self):
        """Batch dimensions."""
        return self._bdims

    @property
    def qdims(self):
        """The Qdims object describing the Hilbert-space structure."""
        return self._qdims

    @property
    def space_dims(self):
        """Hilbert-space dims: row dims for kets/operators, column dims for bras."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # TODO: not reached for some reason
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        """Underlying data array of shape ``bdims + (rows, cols)``."""
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``bdims + row_dims + col_dims`` (one axis per subsystem)."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        """Shape of the underlying data array."""
        return self.data.shape

    @property
    def is_batched(self):
        """True if the Qarray has at least one batch dimension."""
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index into the batch dimensions; dims are kept, bdims re-inferred.

        Raises:
            ValueError: if the Qarray is not batched.
        """
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray to ``args``."""
        new_bdims = tuple(args)

        if prod(new_bdims) == 0:
            # Empty batch: there is no data to infer the trailing axes from.
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Expand Hilbert-space dims into full ``(row_dims, col_dims)``.

        Already-full dims (whose first entry is a list/tuple) are returned as
        is. Kets and operators get ``(space_dims, (1, ...))``; bras get the
        reverse.

        Raises:
            ValueError: for an unsupported qtype.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return (tuple(space_dims), tuple([1 for _ in range(len(space_dims))]))
        elif self.qtype == Qtypes.bra:
            return (tuple([1 for _ in range(len(space_dims))]), tuple(space_dims))
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray. Their product must
                equal the product of the current space dims.

        Returns:
            Qarray: reshaped Qarray.
        """

        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        assert prod(new_space_dims) == prod(current_space_dims), (
            "New space dims must have the same total size as the current ones."
        )

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape (via ``jnp.resize``), keeping dims.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the leading batch dimension.

        Raises:
            ValueError: if the Qarray is not batched.
        """
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Equality: same dims, same bdims, and all data elements equal.

        Returns Python ``False`` on a dims/bdims mismatch, otherwise the
        result of ``jnp.all`` on the elementwise comparison.

        Raises:
            ValueError: if ``other`` is not a Qarray.
        """
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        return not self.__eq__(other)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        """Matrix product of two Qarrays (dims combined via ``Qdims.__matmul__``)."""
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: __rmatmul__ is intentionally not defined; when both operands are
    # Qarrays, __matmul__ on the left operand always handles the product.

    def __mul__(self, other):
        """Qarray * Qarray is a matrix product; otherwise scale by a scalar.

        A non-scalar ``other`` is treated as a batch of scalars: two trailing
        singleton axes are appended so it broadcasts over the matrix axes.
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        # Only reached for non-Qarray left operands, so this is scalar scaling.
        return self.__mul__(other)

    def __neg__(self):
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar."""

        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def _assert_same_dims(self, other):
        """Raise ValueError if Qarray ``other`` has different dims from ``self``."""
        if self.dims != other.dims:
            msg = (
                "Dimensions are incompatible: "
                + repr(self.dims)
                + " and "
                + repr(other.dims)
            )
            raise ValueError(msg)

    def _scalar_as_identity(self, other):
        """Turn a scalar (or batch of scalars) into ``other * I`` with self's dims."""
        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))
        return Qarray.create(
            other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
            dims=self.dims,
        )

    def __add__(self, other):
        """Add a Qarray with equal dims, or a scalar times the identity.

        Adding scalar 0 returns a copy. Adding a nonzero scalar is only
        supported for square data; otherwise NotImplemented is returned.
        """
        if isinstance(other, Qarray):
            self._assert_same_dims(other)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            return self.__add__(self._scalar_as_identity(other))

        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract a Qarray with equal dims, or a scalar times the identity.

        Subtracting scalar 0 returns a copy. Subtracting a nonzero scalar is
        only supported for square data; otherwise NotImplemented is returned.
        """
        if isinstance(other, Qarray):
            self._assert_same_dims(other)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            return self.__sub__(self._scalar_as_identity(other))

        return NotImplemented

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """``a ^ b`` is the tensor product ``tensor(a, b)``."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """``a ** n`` is the integer matrix power ``powm(a, n)``."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Header string (dims, bdims, shape, type) of the Qarray."""
        return self._str_header()

    def __repr__(self):
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Return a deep copy of the Qarray."""
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Deep copy of data, qdims and bdims.

        Needs to be defined explicitly because of ``__getattr__``: without it,
        ``copy.deepcopy`` would look up ``__deepcopy__`` on the instance, hit
        the catch-all dunder branch of ``__getattr__`` and call a lambda that
        returns NotImplemented.
        """

        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Fall back to jax.numpy / jax.scipy functions applied to ``self.data``.

        Only called when normal attribute lookup fails. ``q.foo(*a, **k)``
        resolves ``foo`` in ``jnp``, ``jnp.linalg``, ``jsp``, ``jsp.linalg``
        (first match wins) and calls ``foo(q.data, *a, **k)``. A result with
        the same shape as ``q.data`` is wrapped back into a Qarray with the
        same dims; anything else is returned as is.

        Raises:
            NotImplementedError: if no module provides ``method_name``.
                NOTE(review): since this is not an AttributeError, ``hasattr``
                propagates it instead of returning False.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        """Conjugate transpose."""
        return dag(self)

    def to_dm(self):
        """Density matrix (``ket2dm``); operators are returned unchanged."""
        return ket2dm(self)

    def is_dm(self):
        """True if the qtype is an operator (no hermiticity/trace check)."""
        return self.qtype == Qtypes.oper

    def is_vec(self):
        """True if the qtype is a ket or a bra."""
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        return to_ket(self)

    def transpose(self, *args):
        """Permute subsystems; see the module-level ``transpose``."""
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        return unit(self)

    def norm(self):
        return norm(self)

    def expm(self):
        return expm(self)

    def powm(self, n):
        return powm(self, n)

    def cosm(self):
        return cosm(self)

    def sinm(self):
        return sinm(self)

    def tr(self, **kwargs):
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        return tr(self, **kwargs)

    def ptrace(self, indx):
        return ptrace(self, indx)

    def eigenstates(self):
        return eigenstates(self)

    def eigenenergies(self):
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        return collapse(self, mode=mode)

    # ----

545 

546 

# Tuple of array-like types (JAX Array, NumPy ndarray, Qarray), usable as the
# second argument of isinstance(). NOTE(review): not referenced in this file;
# presumably consumed by other jaxquantum modules -- confirm.
ARRAY_TYPES = (Array, ndarray, Qarray)

548 

549# Qarray operations --------------------------------------------------------------------- 

550 

def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join Qarrays along an existing axis of their data arrays.

    Inputs whose ``data`` has length 0 along its first axis are skipped. If no
    input remains, an empty Qarray (``Qarray.from_list([])``) is returned. The
    result takes its dims from the first remaining input.

    Args:
        qarr_list (List[Qarray]): Qarrays to join.
        axis (int): axis of the underlying data arrays to join along. Default is 0.

    Returns:
        Qarray: Concatenated Qarray.
    """
    kept = [qarr for qarr in qarr_list if len(qarr.data) > 0]

    if not kept:
        return Qarray.from_list([])

    joined = jnp.concatenate([qarr.data for qarr in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)

573 

574 

def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse (reduce over) all batch dimensions of a Qarray.

    Args:
        qarr (Qarray): quantum array
        mode (str): reduction applied over every batch axis. Only ``"sum"``
            is supported.

    Returns:
        Qarray: unbatched quantum array with the same dims. An unbatched
        input is returned unchanged.

    Raises:
        ValueError: if ``mode`` is not supported (previously an unsupported
            mode silently returned ``None``).
    """
    if mode != "sum":
        raise ValueError(f"Unsupported collapse mode: {mode!r}. Supported modes: 'sum'.")

    if len(qarr.bdims) == 0:
        return qarr

    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)

590 

591 

def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems (tensor factors) of a quantum array.

    Args:
        qarr (Qarray): quantum array
        indices (List[int]): new subsystem order; position ``k`` of the result
            holds original subsystem ``indices[k]``. The same permutation is
            applied to the row and the column factors. Batch axes stay in place.

    Returns:
        Qarray: transposed Qarray with permuted dims.
    """
    order = list(indices)
    dims = qarr.dims
    n_batch = len(qarr.bdims)
    n_row_factors = len(dims[0])

    # shaped_data axis layout: (batch axes..., row factors..., column factors...)
    row_axes = [n_batch + k for k in order]
    col_axes = [n_batch + n_row_factors + k for k in order]
    permutation = list(range(n_batch)) + row_axes + col_axes

    permuted = qarr.shaped_data.transpose(permutation)

    new_dims = (
        tuple([dims[0][k] for k in order]),
        tuple([dims[1][k] for k in order]),
    )

    # Merge the factor axes back into a (rows, cols) matrix per batch element.
    flat = permuted.reshape(*qarr.bdims, prod(dims[0]), -1)
    return Qarray.create(flat, dims=new_dims)

622 

623 

def unit(qarr: Qarray) -> Qarray:
    """Rescale a quantum array to unit norm.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: ``qarr`` divided by ``qarr.norm()`` (one factor per batch
        element for batched input).
    """
    magnitude = qarr.norm()
    return qarr / magnitude

634 

635 

def norm(qarr: Qarray) -> float:
    """Norm of a quantum array.

    For operators this is the trace norm ``Tr sqrt(A A^dagger)``, computed as
    the sum of square roots of the eigenvalues of ``A A^dagger``. For kets and
    bras it is the Euclidean (2-) norm of the vector. Batched inputs give one
    norm per batch element, with shape ``qarr.bdims``.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Norm: a scalar array for unbatched input, shape ``bdims`` otherwise.

    Raises:
        ValueError: if the qtype is not an operator, ket, or bra (previously
            this fell through and returned ``None``).
    """
    qdata = qarr.data

    if qarr.qtype == Qtypes.oper:
        qdata_dag = qarr.dag().data
        # eigvalsh batches over leading axes, so no vmap/reshape is needed.
        evals = jnp.linalg.eigvalsh(qdata @ qdata_dag)
        # abs() guards against tiny negative eigenvalues from round-off.
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1)

    if qarr.qtype in [Qtypes.ket, Qtypes.bra]:
        # Frobenius norm of each trailing (n, 1) / (1, n) matrix equals the
        # vector 2-norm; leading batch axes are preserved.
        return jnp.linalg.norm(qdata, axis=(-2, -1))

    raise ValueError(f"Unsupported qtype for norm: {qarr.qtype}")

662 

663 

def tensor(*args, **kwargs) -> Qarray:
    """Tensor product.

    Args:
        *args (Qarray): tensors to take the product of, left to right
        parallel (bool): if True, use parallel einsum for tensor product
            true: [A,B] ^ [C,D] = [A^C, B^D]
            false (default): [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D]
            In parallel mode the batch axes are combined by standard
            broadcasting, e.g. batches (3, 1) and (1, 4) give (3, 4).
        **kwargs: remaining keyword arguments are forwarded to ``jnp.kron``
            (non-parallel mode only).

    Returns:
        Tensor product of given tensors

    """

    parallel = kwargs.pop("parallel", False)

    data = args[0].data
    # dims are combined with "+", which builds new sequences, so no copy is needed.
    dims_0 = args[0].dims[0]
    dims_1 = args[0].dims[1]
    for arg in args[1:]:
        if parallel:
            a = data
            b = arg.data

            # The einsum below broadcasts the "..." batch axes, so the output
            # batch shape is the broadcast of both batch shapes. (Picking the
            # "larger" of the two failed for shapes like (3, 1) and (1, 4).)
            batch_dim = jnp.broadcast_shapes(a.shape[:-2], b.shape[:-2])

            # "...ij,...kl->...ikjl" then merging (i, k) and (j, l) is a
            # batched Kronecker product.
            data = jnp.einsum("...ij,...kl->...ikjl", a, b).reshape(
                *batch_dim, a.shape[-2] * b.shape[-2], -1
            )
        else:
            data = jnp.kron(data, arg.data, **kwargs)

        dims_0 = dims_0 + arg.dims[0]
        dims_1 = dims_1 + arg.dims[1]

    return Qarray.create(data, dims=(dims_0, dims_1))

709 

710 

def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``jnp.trace``. ``axis1`` / ``axis2`` default to
            the last two axes, so batched inputs give one trace per batch
            element.

    Returns:
        Full trace.
    """
    # pop (not get): otherwise a caller-supplied axis1/axis2 would also stay in
    # kwargs and be passed to jnp.trace twice, raising a TypeError.
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)

723 

724 

def trace(qarr: Qarray, **kwargs) -> Array:
    """Alias of ``tr``: full trace of a quantum array.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded unchanged to ``tr``.

    Returns:
        Full trace.
    """
    result = tr(qarr, **kwargs)
    return result

735 

736 

def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of raw array data.

    Thin wrapper over ``jax.scipy.linalg.expm``; keyword arguments are passed
    straight through.

    Returns:
        matrix exponential
    """
    exponential = jsp.linalg.expm(data, **kwargs)
    return exponential

744 

745 

def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a quantum array.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``expm_data``.

    Returns:
        Qarray: matrix exponential with the same dims as ``qarr``.
    """
    exponential = expm_data(qarr.data, **kwargs)
    return Qarray.create(exponential, dims=qarr.dims)

755 

756 

def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power.

    Integer powers use ``jnp.linalg.matrix_power``. Non-integer powers use
    the eigendecomposition ``V diag(lambda**n) V^-1``, which requires
    non-negative eigenvalues (or ``clip_eigvals=True``). Batched input is
    supported in both paths.

    Args:
        qarr (Qarray): quantum array
        n (int | float): power
        clip_eigvals (bool): clip eigenvalues to always be able to compute
            non-integer powers

    Returns:
        Qarray: matrix power with the same dims as ``qarr``.

    Raises:
        ValueError: for non-integer ``n`` when an eigenvalue is negative and
            ``clip_eigvals`` is False.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        elif not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Scale column j of V by lambda_j**n. The [..., None, :] keeps the
        # broadcast correct for batched input ((..., N) vs (..., N, N));
        # for unbatched input it is identical to plain broadcasting.
        scaled = evectors * jnp.pow(evalues, n)[..., None, :]
        data_res = scaled @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)

784 

785 

def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of raw array data, via cos(A) = (e^{iA} + e^{-iA}) / 2.

    ``kwargs`` is accepted for signature compatibility but is not used.

    Returns:
        matrix cosine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward + backward) / 2

793 

794 

def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: matrix cosine with the same dims as ``qarr``.
    """
    cosine = cosm_data(qarr.data)
    return Qarray.create(cosine, dims=qarr.dims)

807 

808 

def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of raw array data, via sin(A) = (e^{iA} - e^{-iA}) / (2i).

    ``kwargs`` is accepted for signature compatibility but is not used.

    Args:
        data: matrix

    Returns:
        matrix sine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward - backward) / (2j)

819 

820 

def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: matrix sine with the same dims as ``qarr``.
    """
    sine = sinm_data(qarr.data)
    return Qarray.create(sine, dims=qarr.dims)

825 

826 

def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Zero every off-diagonal entry of an unbatched quantum array.

    Args:
        qarr (Qarray): unbatched quantum array

    Returns:
        Qarray: diagonal matrix built from the diagonal of ``qarr.data``,
        with the same dims.

    Raises:
        ValueError: if ``qarr`` is batched.
    """
    if qarr.bdims:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    diagonal = jnp.diag(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)

834 

835 

def to_ket(qarr: Qarray) -> Qarray:
    """Return the ket form of a ket or bra.

    Args:
        qarr (Qarray): a ket (returned unchanged) or a bra (returned daggered).

    Returns:
        Qarray: ket.

    Raises:
        ValueError: if ``qarr`` is neither a ket nor a bra.
    """
    if qarr.qtype == Qtypes.bra:
        return qarr.dag()
    if qarr.qtype != Qtypes.ket:
        raise ValueError("Can only get ket from a ket or bra.")
    return qarr

843 

844 

def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenvalues and eigenvectors of a Hermitian quantum array.

    Args:
        qarr (Qarray): Hermitian operator (may be batched)

    Returns:
        tuple[Array, Qarray]: eigenvalues sorted ascending along the last
        axis, and the matching eigenvectors as a Qarray of kets whose batch
        dims are ``qarr.bdims + (N,)``. ``evecs[i]`` (``evecs[b, i]`` for
        batched input) is the eigenvector for ``evals[i]`` (``evals[b, i]``).
    """

    evals, evecs = jnp.linalg.eigh(qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    # eigh returns eigenvectors as COLUMNS (evecs[..., :, i]); reorder columns.
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # Qarray.create with bdims = shape[:-1] treats the second-to-last axis as
    # a batch axis and the last axis as the ket, i.e. it splits the matrix by
    # ROWS. Swap the last two axes so that each ket is an eigenvector column.
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs

870 

871 

def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a Hermitian quantum array.

    Uses ``jnp.linalg.eigvalsh``, which returns eigenvalues in ascending
    order along the last axis (batched over any leading axes).

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues
    """
    return jnp.linalg.eigvalsh(qarr.data)

884 

885 

886# More quantum specific ----------------------------------------------------- 

887 

888 

def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial trace that keeps a single subsystem.

    Kets and bras are first turned into density matrices with ``ket2dm``.

    Args:
        qarr (Qarray): quantum array (ket, bra, or operator; may be batched)
        indx: index of the subsystem to keep; all others are traced out

    Returns:
        Qarray: reduced density matrix of subsystem ``indx``. Its dims are
        inferred from the resulting data shape.

    TODO: Fix weird tracing errors that arise with reshape
    """
    rho_q = ket2dm(qarr)
    dims = rho_q.dims
    n_sub = len(dims[0])
    n_batch = len(rho_q.bdims)

    # Put the kept subsystem's (row, col) axes first, then the (row, col)
    # pair of every other subsystem, after the untouched batch axes.
    order = [indx, indx + n_sub]
    for other in range(n_sub):
        if other != indx:
            order.extend([other, other + n_sub])
    axes = list(range(n_batch)) + [n_batch + a for a in order]

    reduced = rho_q.shaped_data.transpose(axes)

    # Each trace contracts the first (row, col) pair after the kept one.
    for _ in range(n_sub - 1):
        reduced = jnp.trace(reduced, axis1=n_batch + 2, axis2=n_batch + 3)

    return Qarray.create(reduced)

925 

926 

def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose (Hermitian adjoint) of a quantum array.

    The row and column dims are swapped, so kets become bras and vice versa;
    the data is conjugate-transposed with ``dag_data``.

    Args:
        qarr (Qarray): quantum array

    Returns:
        conjugate transpose of qarr
    """
    adjoint = dag_data(qarr.data)
    return Qarray.create(adjoint, dims=qarr.dims[::-1])

941 

942 

def dag_data(arr: Array) -> Array:
    """Conjugate transpose of raw array data.

    The array is complex-conjugated and, if it has at least two axes, its last
    two axes are swapped; any leading (batch) axes stay in place. A 1D array
    is only conjugated.

    Args:
        arr: operator

    Returns:
        conjugate of arr with its last two axes transposed
    """
    conjugated = jnp.conj(arr)
    # TODO: revisit this case...
    if arr.ndim == 1:
        return conjugated
    return jnp.swapaxes(conjugated, -2, -1)

959 

960 

def ket2dm(qarr: Qarray) -> Qarray:
    """Turn a ket (or bra) into a density matrix.

    Operators are returned unchanged; a bra is first daggered into a ket.

    Args:
        qarr (Qarray): qarr

    Returns:
        Density matrix ``|psi><psi|``
    """
    if qarr.qtype == Qtypes.oper:
        return qarr

    ket = qarr.dag() if qarr.qtype == Qtypes.bra else qarr
    return ket @ ket.dag()

979 

980 

981# Data level operations ---- 

982 

983 

def is_dm_data(data: Array) -> bool:
    """Check whether data could be a density matrix.

    Only the shape is inspected: returns True when the last two axes have
    equal length. Hermiticity, positivity and unit trace are not checked.

    Args:
        data: matrix
    Returns:
        True if the last two axes of data are square
    """
    n_rows, n_cols = data.shape[-2], data.shape[-1]
    return n_rows == n_cols

993 

994 

def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of raw array data.

    Thin wrapper over ``jnp.linalg.matrix_power`` (which also batches over
    leading axes).

    Args:
        data: matrix
        n: power

    Returns:
        matrix power
    """
    powered = jnp.linalg.matrix_power(data, n)
    return powered