Coverage for jaxquantum / core / qarray.py: 81%

774 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""New Qarray implementation with sparse support.""" 

2 

3from __future__ import annotations 

4 

5from abc import ABC, abstractmethod 

6from flax import struct 

7from jax import Array, config, vmap 

8from typing import TYPE_CHECKING, List, Union, TypeVar, Generic, overload, Literal 

9 

10if TYPE_CHECKING: 

11 from jaxquantum.core.sparse_bcoo import SparseBCOOImpl 

12import jax.numpy as jnp 

13import jax.scipy as jsp 

14from jax.experimental import sparse 

15from numpy import ndarray 

16from copy import deepcopy 

17from math import prod 

18from enum import Enum 

19 

20from jaxquantum.core.settings import SETTINGS 

21from jaxquantum.utils.utils import robust_isscalar 

22from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims 

23 

# Enable 64-bit precision in JAX. This runs at import time and changes the
# global JAX configuration for the whole process, not just this module.
config.update("jax_enable_x64", True)

# Type variable for implementation types. It is bound to ``QarrayImpl`` so that
# ``Qarray[ImplT]`` can be parameterised by its storage backend.
ImplT = TypeVar("ImplT", bound="QarrayImpl")

# Module-level registry mapping impl_class -> QarrayImplType member.
# It is written by ``QarrayImplType.register``. It is read by
# ``QarrayImplType.from_impl_class`` and ``QarrayImplType.get_impl_class``.
_IMPL_REGISTRY: dict[type, "QarrayImplType"] = {}

31 

32 

class QarrayImplType(Enum):
    """Enumeration of available Qarray storage backends.

    Each member maps one-to-one with a concrete ``QarrayImpl`` subclass.
    New backends should call ``QarrayImplType.register(MyImpl, QarrayImplType.MY_TYPE)``
    immediately after defining their impl class.

    Besides an exact member value, ``QarrayImplType(x)`` also accepts:
      * a member name or value in any letter case (e.g. ``"DENSE"``,
        ``"Sparse_BCOO"``);
      * a registered implementation class (e.g. ``QarrayImplType(DenseImpl)``).

    Anything else raises ``ValueError``, as with any ``Enum`` lookup.

    Members:
        DENSE: Standard JAX dense array (``jnp.ndarray``).
        SPARSE_BCOO: JAX experimental BCOO sparse array.
        SPARSE_DIA: Diagonal sparse array.
    """

    DENSE = "dense"
    SPARSE_BCOO = "sparse_bcoo"
    SPARSE_DIA = "sparse_dia"

    @classmethod
    def _missing_(cls, value):
        """Resolve ``QarrayImplType(value)`` when *value* is not an exact member value.

        ``Enum`` calls this hook after its exact-value lookup fails.

        Args:
            value: The value passed to ``QarrayImplType(...)``.

        Returns:
            The matching member, or ``None``. ``None`` makes ``Enum`` raise
            ``ValueError``.
        """
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if lowered == member.value or lowered == member.name.lower():
                    return member
            return None
        try:
            return _IMPL_REGISTRY.get(value)
        except TypeError:
            # Unhashable values cannot be registered implementation classes.
            return None

    @classmethod
    def register(cls, impl_class, member):
        """Register an implementation class with a QarrayImplType member.

        Args:
            impl_class: The concrete ``QarrayImpl`` subclass to register.
            member: The ``QarrayImplType`` member to associate with it. A
                string accepted by ``QarrayImplType(...)`` is converted to the
                member. Without that conversion, :meth:`get_impl_class` (which
                compares by identity) could never find the class.
        """
        _IMPL_REGISTRY[impl_class] = cls(member)

    @classmethod
    def has(cls, x) -> bool:
        """Return True if x corresponds to a member of QarrayImplType.

        Accepts an existing ``QarrayImplType`` member, a string equal to the
        member name or value (case-insensitive), or an implementation class
        (e.g. ``DenseImpl``, ``SparseBCOOImpl``) that has been registered.
        This is exactly the set of inputs ``QarrayImplType(x)`` accepts.

        Args:
            x: Value to test — a ``QarrayImplType``, ``str``, or impl class.

        Returns:
            True if ``x`` maps to a known ``QarrayImplType`` member.
        """
        if isinstance(x, cls):
            return True
        try:
            cls(x)
        except (ValueError, TypeError):
            return False
        return True

    @classmethod
    def from_impl_class(cls, impl_class) -> "QarrayImplType":
        """Return the ``QarrayImplType`` member associated with *impl_class*.

        Args:
            impl_class: A concrete ``QarrayImpl`` subclass that has been
                registered via :meth:`register`.

        Returns:
            The corresponding ``QarrayImplType`` member.

        Raises:
            ValueError: If *impl_class* is not in the registry.
        """
        try:
            return _IMPL_REGISTRY[impl_class]
        except KeyError:
            raise ValueError(f"Unknown implementation class: {impl_class}") from None

    def get_impl_class(self):
        """Return the implementation class registered for this member.

        Returns:
            The concrete ``QarrayImpl`` subclass associated with this member.

        Raises:
            ValueError: If no class has been registered for this member.
        """
        for impl_class, member in _IMPL_REGISTRY.items():
            if member is self:
                return impl_class
        raise ValueError(f"No impl class registered for {self}")

119 

120 

def robust_asarray(data) -> Union[Array, sparse.BCOO]:
    """Coerce *data* into a JAX array unless it is already a sparse container.

    Sparse inputs are returned as the very same object:
      * ``sparse.BCOO`` instances;
      * ``SparseDiaData`` containers, recognised by a truthy ``_is_sparse_dia``
        attribute. This duck-typing avoids a circular import.

    Everything else goes through ``jnp.asarray``.

    Args:
        data: Input data — any array-like, ``sparse.BCOO``, or ``SparseDiaData``.

    Returns:
        A ``jax.Array``, ``sparse.BCOO``, or ``SparseDiaData``.
    """
    keep_as_is = isinstance(data, sparse.BCOO) or getattr(data, "_is_sparse_dia", False)
    return data if keep_as_is else jnp.asarray(data)

136 

137 

class QarrayImpl(ABC):
    """Abstract base class defining the interface every storage backend must implement.

    A ``QarrayImpl`` wraps a raw data container and provides the mathematical
    primitives used by ``Qarray``. The container is a dense ``jnp.ndarray``, a
    sparse ``BCOO``, or a ``SparseDiaData``. Concrete subclasses must implement
    every ``@abstractmethod``.

    Attributes:
        PROMOTION_ORDER: Integer priority used by ``_coerce`` to decide which
            side to promote when operands have different types. Higher means
            "more general". The current hierarchy is ``SparseDiaImpl = 0``,
            ``SparseBCOOImpl = 1``, ``DenseImpl = 2``, so dense wins over
            both sparse formats.
    """

    PROMOTION_ORDER: int = 0  # override in subclasses; higher = more general
    # Current hierarchy: SparseDiaImpl=0, SparseBCOOImpl=1, DenseImpl=2

    @abstractmethod
    def get_data(self) -> Array:
        """Return the underlying raw data array."""
        pass

    @property
    def data(self) -> Array:
        """The underlying raw data array."""
        return self.get_data()

    @property
    def impl_type(self) -> "QarrayImplType":
        """The ``QarrayImplType`` member corresponding to this instance."""
        return QarrayImplType.from_impl_class(type(self))

    @classmethod
    @abstractmethod
    def from_data(cls, data) -> "QarrayImpl":
        """Wrap raw data in this impl type.

        Args:
            data: Raw array data (dense ``jnp.ndarray``, ``sparse.BCOO``, or a
                backend-specific container).

        Returns:
            A new instance of this implementation wrapping *data*.
        """
        pass

    @abstractmethod
    def matmul(self, other: "QarrayImpl") -> "QarrayImpl":
        """Matrix multiplication with *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self @ other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def add(self, other: "QarrayImpl") -> "QarrayImpl":
        """Element-wise addition with *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self + other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def sub(self, other: "QarrayImpl") -> "QarrayImpl":
        """Element-wise subtraction of *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self - other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def mul(self, scalar) -> "QarrayImpl":
        """Scalar multiplication.

        Args:
            scalar: Scalar value to multiply by.

        Returns:
            Result of ``scalar * self`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def dag(self) -> "QarrayImpl":
        """Conjugate transpose.

        Returns:
            The conjugate transpose of this array as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl``.

        Returns:
            A ``DenseImpl`` wrapping the same data.
        """
        pass

    @abstractmethod
    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` (BCOO).

        Returns:
            A ``SparseBCOOImpl`` wrapping the same data.
        """
        pass

    def to_sparse_dia(self) -> "QarrayImpl":
        """Convert to a ``SparseDiaImpl``.

        The default implementation goes through dense and auto-detects diagonals.
        Subclasses may override it for a more direct path.

        Returns:
            A ``SparseDiaImpl`` wrapping the same data.
        """
        # Import here to avoid circular imports at module load time
        from jaxquantum.core.sparse_dia import SparseDiaImpl
        return SparseDiaImpl.from_data(self.to_dense()._data)

    @abstractmethod
    def shape(self) -> tuple:
        """Shape of the underlying data array.

        Returns:
            Tuple of dimension sizes.
        """
        pass

    @abstractmethod
    def dtype(self):
        """Data type of the underlying array.

        Returns:
            A numpy/JAX dtype object.
        """
        pass

    @abstractmethod
    def __deepcopy__(self, memo=None):
        """Return a deep copy of this implementation instance.

        Args:
            memo: The ``copy.deepcopy`` memo dictionary.

        Returns:
            A new instance of the same impl class holding copied data.
        """
        pass

    @abstractmethod
    def tidy_up(self, atol):
        """Zero out values whose magnitude is below *atol*.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``QarrayImpl`` with small values zeroed.
        """
        pass

    @abstractmethod
    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker (tensor) product with another implementation.

        Args:
            other: Right-hand operand. Mixed-type pairs are handled by
                ``_coerce``; the result has the type with the higher
                ``PROMOTION_ORDER`` (dense wins over sparse).

        Returns:
            A new ``QarrayImpl`` containing the Kronecker product.
        """
        pass

    @classmethod
    @abstractmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create identity matrix data of size n.

        Args:
            n: Matrix size.
            dtype: Optional data type for the identity entries.

        Returns:
            Raw identity matrix data in the format appropriate for this impl.
        """
        pass

    @classmethod
    @abstractmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True if *arr* is a raw data type natively handled by this impl.

        Used by the module-level :func:`dag_data` dispatcher to route raw
        arrays to the correct backend, so that no isinstance chain is needed
        outside the impl classes.

        Args:
            arr: Raw array — e.g. ``jnp.ndarray`` for ``DenseImpl`` or
                ``sparse.BCOO`` for ``SparseBCOOImpl``.

        Returns:
            True if this impl can operate on *arr* without conversion.
        """
        pass

    @classmethod
    @abstractmethod
    def dag_data(cls, arr):
        """Conjugate transpose of raw data in this impl's native format.

        Implementations must handle batched arrays (the last two axes are
        swapped) and must not densify sparse arrays.

        Args:
            arr: Raw array in this impl's native format.

        Returns:
            Conjugate transpose with the last two axes swapped.
        """
        pass

    def _promote_to(self, target_cls: type) -> "QarrayImpl":
        """Convert this impl to *target_cls* by passing through dense.

        Args:
            target_cls: The destination ``QarrayImpl`` subclass.

        Returns:
            *self* if it is already a *target_cls*; otherwise a new
            *target_cls* instance holding equivalent data.
        """
        if isinstance(self, target_cls):
            return self
        return target_cls.from_data(self.to_dense()._data)

    def _coerce(self, other: "QarrayImpl") -> "tuple[QarrayImpl, QarrayImpl]":
        """Coerce *self* and *other* to the same implementation type.

        The impl type with the higher ``PROMOTION_ORDER`` wins, and the other
        side is promoted via :meth:`_promote_to`. When the two orders are equal
        but the types differ, *self*'s type wins.

        Args:
            other: The other operand.

        Returns:
            A pair ``(a, b)`` of the same ``QarrayImpl`` subclass, suitable
            for a binary operation.
        """
        if type(self) is type(other):
            return self, other
        if self.PROMOTION_ORDER >= other.PROMOTION_ORDER:
            return self, other._promote_to(type(self))
        return self._promote_to(type(other)), other

397 

398 

@struct.dataclass
class DenseImpl(QarrayImpl):
    """Dense implementation using JAX dense arrays.

    Attributes:
        _data: The underlying ``jnp.ndarray``.
    """

    _data: Array

    PROMOTION_ORDER = 2  # noqa: RUF012 — not a struct field; no annotation intentional

    @classmethod
    def from_data(cls, data) -> "DenseImpl":
        """Wrap *data* in a new ``DenseImpl``.

        A ``sparse.BCOO`` input is densified first, so that ``_data`` always
        holds a dense array. ``robust_asarray`` passes BCOO through unchanged.
        Without this step, dense-only operations such as ``jnp.real`` in
        :meth:`tidy_up` or ``jnp.kron`` would be applied to a sparse object.

        Args:
            data: Array-like input data or a ``sparse.BCOO``.

        Returns:
            A ``DenseImpl`` wrapping a dense version of *data*.
        """
        if isinstance(data, sparse.BCOO):
            data = data.todense()
        # NOTE(review): a SparseDiaData input is still passed through unchanged
        # by ``robust_asarray``; no densification path for it is visible here.
        # Confirm that callers never pass one.
        return cls(_data=robust_asarray(data))

    def get_data(self) -> Array:
        """Return the underlying dense array."""
        return self._data

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiply ``self @ other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the matrix product.
        """
        a, b = self._coerce(other)
        if a is not self:
            # Only reached if *other* has a higher promotion order than dense.
            return a.matmul(b)
        return DenseImpl(self._data @ b._data)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition ``self + other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the sum.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.add(b)
        return DenseImpl(self._data + b._data)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction ``self - other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the difference.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.sub(b)
        return DenseImpl(self._data - b._data)

    def mul(self, scalar) -> QarrayImpl:
        """Scalar multiplication.

        Args:
            scalar: Scalar value.

        Returns:
            A ``DenseImpl`` with each element multiplied by *scalar*.
        """
        return DenseImpl(scalar * self._data)

    def dag(self) -> QarrayImpl:
        """Conjugate transpose.

        Delegates to :meth:`dag_data`, so 1-D data is only conjugated. A bare
        ``moveaxis(-1, -2)`` would raise on 1-D data.

        Returns:
            A ``DenseImpl`` containing the conjugate transpose.
        """
        return DenseImpl(DenseImpl.dag_data(self._data))

    def to_dense(self) -> "DenseImpl":
        """Return self (already dense).

        Returns:
            This ``DenseImpl`` instance unchanged.
        """
        return self

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` via ``BCOO.fromdense``.

        Returns:
            A ``SparseBCOOImpl`` wrapping a BCOO conversion of this array.
        """
        # Lazy import: sparse_bcoo imports this module.
        from jaxquantum.core.sparse_bcoo import SparseBCOOImpl
        return SparseBCOOImpl(sparse.BCOO.fromdense(self._data))

    def shape(self) -> tuple:
        """Shape of the underlying dense array.

        Returns:
            Tuple of dimension sizes.
        """
        return self._data.shape

    def dtype(self):
        """Data type of the underlying dense array.

        Returns:
            The dtype of ``_data``.
        """
        return self._data.dtype

    def frobenius_norm(self) -> float:
        """Compute the Frobenius norm.

        The sum runs over every element, including any batch axes.

        Returns:
            The Frobenius norm as a scalar.
        """
        return jnp.sqrt(jnp.sum(jnp.abs(self._data) ** 2))

    def real(self) -> QarrayImpl:
        """Element-wise real part.

        Returns:
            A ``DenseImpl`` containing the real parts.
        """
        return DenseImpl(jnp.real(self._data))

    def imag(self) -> QarrayImpl:
        """Element-wise imaginary part.

        Returns:
            A ``DenseImpl`` containing the imaginary parts.
        """
        return DenseImpl(jnp.imag(self._data))

    def conj(self) -> QarrayImpl:
        """Element-wise complex conjugate.

        Returns:
            A ``DenseImpl`` containing the complex-conjugated values.
        """
        return DenseImpl(jnp.conj(self._data))

    def __deepcopy__(self, memo=None):
        """Return a ``DenseImpl`` holding a deep copy of ``_data``."""
        return DenseImpl(
            _data=deepcopy(self._data, memo)
        )

    def tidy_up(self, atol):
        """Zero out real/imaginary parts whose magnitude is below *atol*.

        The real and imaginary parts are thresholded independently. The result
        is always complex, because it is rebuilt as ``re + 1j * im``.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``DenseImpl`` with small values zeroed.
        """
        data = self._data
        data_re = jnp.real(data)
        data_im = jnp.imag(data)
        data_re_mask = jnp.abs(data_re) > atol
        data_im_mask = jnp.abs(data_im) > atol
        data_new = data_re * data_re_mask + 1j * data_im * data_im_mask

        return DenseImpl(
            _data=data_new
        )

    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker product using ``jnp.kron``.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the Kronecker product.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.kron(b)
        return DenseImpl(jnp.kron(self._data, b._data))

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create an ``n x n`` identity matrix as a dense JAX array.

        Args:
            n: Matrix size.
            dtype: Optional data type.

        Returns:
            A ``jnp.ndarray`` identity matrix of shape ``(n, n)``.
        """
        return jnp.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True for any non-BCOO, non-SparseDIA array.

        ``SparseDiaData`` objects carry a ``_is_sparse_dia`` marker. This lets
        us exclude them without a direct type import, which would be circular.

        Args:
            arr: Raw array.

        Returns:
            True when *arr* is a plain dense array (not BCOO, not SparseDiaData).
        """
        return not isinstance(arr, sparse.BCOO) and not getattr(arr, "_is_sparse_dia", False)

    @classmethod
    def dag_data(cls, arr) -> Array:
        """Conjugate transpose for dense arrays.

        Swaps the last two axes via :func:`jnp.moveaxis` and conjugates all
        elements. For 1-D inputs only conjugation is applied.

        Args:
            arr: Dense array.

        Returns:
            Conjugate transpose with the last two axes swapped.
        """
        if len(arr.shape) == 1:
            return jnp.conj(arr)
        return jnp.moveaxis(jnp.conj(arr), -1, -2)

636 

637 

# Register implementation classes with the enum registry.
# Only DenseImpl is registered here. SparseBCOOImpl is registered in
# sparse_bcoo.py after import. NOTE(review): SparseDiaImpl is presumably
# registered in sparse_dia.py the same way — confirm.
QarrayImplType.register(DenseImpl, QarrayImplType.DENSE)

641 

642 

643@struct.dataclass 

644class Qarray(Generic[ImplT]): 

645 """Quantum array with a pluggable storage backend. 

646 

647 ``Qarray`` wraps a ``QarrayImpl`` together with quantum-mechanical 

648 dimension metadata (``_qdims``) and optional batch dimensions 

649 (``_bdims``). The default backend is dense (``DenseImpl``); pass 

650 ``implementation="sparse_bcoo"`` (or ``QarrayImplType.SPARSE_BCOO``) to 

651 store data as a JAX BCOO sparse array. 

652 

653 Attributes: 

654 _impl: The storage backend holding the raw data. 

655 _qdims: Quantum dimension metadata (bra/ket structure, Hilbert space 

656 sizes). 

657 _bdims: Tuple of batch dimension sizes (empty tuple = non-batched). 

658 

659 Example: 

660 >>> import jaxquantum as jqt 

661 >>> a = jqt.destroy(10, implementation="sparse_bcoo") 

662 >>> a.is_sparse_bcoo 

663 True 

664 """ 

665 

666 _impl: ImplT 

667 _qdims: Qdims = struct.field(pytree_node=False) 

668 _bdims: tuple[int] = struct.field(pytree_node=False) 

669 

670 # Initialization ---- 

671 @classmethod 

672 @overload 

673 def create(cls, data, dims=None, bdims=None, implementation: Literal[QarrayImplType.DENSE] = QarrayImplType.DENSE) -> "Qarray[DenseImpl]": 

674 ... 

675 

676 @classmethod 

677 @overload 

678 def create(cls, data, dims=None, bdims=None, implementation: Literal[QarrayImplType.SPARSE_BCOO] = ...) -> "Qarray[SparseBCOOImpl]": 

679 ... 

680 

681 @classmethod 

682 @overload 

683 def create(cls, data, dims=None, bdims=None, implementation=...) -> "Qarray[DenseImpl]": 

684 ... 

685 

686 @classmethod 

687 def create(cls, data, dims=None, bdims=None, implementation=QarrayImplType.DENSE): 

688 """Create a ``Qarray`` from raw data. 

689 

690 Handles shape normalisation, dimension inference, and tidying of small 

691 values. 

692 

693 Args: 

694 data: Input data array (dense array-like or ``sparse.BCOO``). 

695 dims: Quantum dimensions as ``((row_dims...), (col_dims...))``. 

696 Inferred from *data* shape when ``None``. 

697 bdims: Tuple of batch dimension sizes. Inferred from the leading 

698 dimensions of *data* when ``None``. 

699 implementation: Storage backend — ``QarrayImplType.DENSE`` 

700 (default) or ``QarrayImplType.SPARSE_BCOO``, or the equivalent 

701 string ``"dense"`` / ``"sparse_bcoo"``. 

702 

703 Returns: 

704 A new ``Qarray`` backed by the requested implementation. 

705 """ 

706 # Step 1: Prepare data ---- 

707 data = robust_asarray(data) 

708 

709 if len(data.shape) == 1 and data.shape[0] > 0: 

710 data = data.reshape(data.shape[0], 1) 

711 

712 if len(data.shape) >= 2: 

713 if data.shape[-2] != data.shape[-1] and not ( 

714 data.shape[-2] == 1 or data.shape[-1] == 1 

715 ): 

716 data = data.reshape(*data.shape[:-1], data.shape[-1], 1) 

717 

718 if bdims is not None: 

719 if len(data.shape) - len(bdims) == 1: 

720 data = data.reshape(*data.shape[:-1], data.shape[-1], 1) 

721 # ---- 

722 

723 # Step 2: Prepare dimensions ---- 

724 if bdims is None: 

725 bdims = tuple(data.shape[:-2]) 

726 

727 if dims is None: 

728 dims = ((data.shape[-2],), (data.shape[-1],)) 

729 

730 if not isinstance(dims[0], (list, tuple)): 

731 # This handles the case where only the hilbert space dimensions are sent in. 

732 if data.shape[-1] == 1: 

733 dims = (tuple(dims), tuple([1 for _ in dims])) 

734 elif data.shape[-2] == 1: 

735 dims = (tuple([1 for _ in dims]), tuple(dims)) 

736 else: 

737 dims = (tuple(dims), tuple(dims)) 

738 else: 

739 dims = (tuple(dims[0]), tuple(dims[1])) 

740 

741 check_dims(dims, bdims, data.shape) 

742 

743 qdims = Qdims(dims) 

744 

745 # NOTE: Constantly tidying up on Qarray creation might be a bit overkill. 

746 # It increases the compilation time, but only very slightly 

747 # increased the runtime of the jit compiled function. 

748 # We could instead use this tidy up where we think we need it. 

749 

750 impl_class = QarrayImplType(implementation).get_impl_class() 

751 impl = impl_class.from_data(data) 

752 impl = impl.tidy_up(SETTINGS["auto_tidyup_atol"]) 

753 

754 return cls(impl, qdims, bdims) 

755 

756 @classmethod 

757 @overload 

758 def from_sparse_bcoo(cls, data, dims=None, bdims=None) -> "Qarray[SparseBCOOImpl]": 

759 ... 

760 

761 @classmethod 

762 def from_sparse_bcoo(cls, data, dims=None, bdims=None): 

763 """Create a ``Qarray`` directly from a sparse BCOO array without densifying. 

764 

765 Args: 

766 data: A ``sparse.BCOO`` or array-like to store as sparse BCOO. 

767 dims: Quantum dimensions. Inferred when ``None``. 

768 bdims: Batch dimensions. Inferred when ``None``. 

769 

770 Returns: 

771 A ``Qarray[SparseBCOOImpl]``. 

772 """ 

773 return cls.create(data, dims=dims, bdims=bdims, implementation=QarrayImplType.SPARSE_BCOO) 

774 

775 @classmethod 

776 def from_sparse_dia(cls, data, dims=None, bdims=None) -> "Qarray": 

777 """Create a SparseDIA-backed ``Qarray``. 

778 

779 Accepts either a dense array-like (diagonals are auto-detected) or a 

780 :class:`~jaxquantum.core.sparse_dia.SparseDiaData` container. 

781 

782 Args: 

783 data: Dense array of shape (*batch, n, n) or a ``SparseDiaData``. 

784 dims: Quantum dimensions ``((row_dims,), (col_dims,))``. 

785 bdims: Batch dimension sizes. 

786 

787 Returns: 

788 A ``Qarray`` backed by ``SparseDiaImpl``. 

789 """ 

790 return cls.create(data, dims=dims, bdims=bdims, implementation=QarrayImplType.SPARSE_DIA) 

791 

792 @classmethod 

793 @overload 

794 def from_list(cls, qarr_list: List["Qarray[DenseImpl]"]) -> "Qarray[DenseImpl]": 

795 ... 

796 

797 @classmethod 

798 @overload 

799 def from_list(cls, qarr_list: List["Qarray[SparseBCOOImpl]"]) -> "Qarray[SparseBCOOImpl]": 

800 ... 

801 

802 @classmethod 

803 def from_list(cls, qarr_list: List[Qarray]) -> Qarray: 

804 """Create a batched ``Qarray`` from a list of same-shaped ``Qarray`` objects. 

805 

806 The output implementation is determined by the element with the highest 

807 ``PROMOTION_ORDER``: if all inputs are sparse the result is sparse; if 

808 any input is dense (or types are mixed) all inputs are promoted to dense 

809 and the result is dense. 

810 

811 Args: 

812 qarr_list: List of ``Qarray`` objects with identical ``dims`` and 

813 ``bdims``. May be empty. 

814 

815 Returns: 

816 A ``Qarray`` with an extra leading batch dimension of size 

817 ``len(qarr_list)``. 

818 

819 Raises: 

820 ValueError: If the elements have mismatched ``dims`` or ``bdims``. 

821 """ 

822 if len(qarr_list) == 0: 

823 dims = ((), ()) 

824 bdims = (0,) 

825 return cls.create(jnp.array([]), dims=dims, bdims=bdims) 

826 

827 dims = qarr_list[0].dims 

828 bdims = qarr_list[0].bdims 

829 

830 if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list): 

831 raise ValueError("All Qarrays in the list must have the same dimensions.") 

832 

833 new_bdims = (len(qarr_list),) + bdims 

834 

835 # Pick the target type: highest PROMOTION_ORDER wins (dense beats sparse). 

836 target_impl_type = max( 

837 (q.impl_type for q in qarr_list), 

838 key=lambda t: t.get_impl_class().PROMOTION_ORDER, 

839 ) 

840 

841 if target_impl_type == QarrayImplType.SPARSE_DIA: 

842 # All inputs are SparseDIA — batch without densifying. 

843 # Compute union of offsets across all operators, then remap each 

844 # operator's _diags rows into the union shape and stack. 

845 from jaxquantum.core.sparse_dia import SparseDiaData # lazy to avoid circular 

846 union_offsets = tuple(sorted( 

847 set().union(*[set(q._impl._offsets) for q in qarr_list]) 

848 )) 

849 union_idx = {k: i for i, k in enumerate(union_offsets)} 

850 n = qarr_list[0]._impl._diags.shape[-1] 

851 dtype = jnp.result_type(*[q._impl._diags.dtype for q in qarr_list]) 

852 remapped = [] 

853 for q in qarr_list: 

854 row = jnp.zeros((len(union_offsets), n), dtype=dtype) 

855 for i_src, k in enumerate(q._impl._offsets): 

856 row = row.at[union_idx[k], :].set(q._impl._diags[i_src, :]) 

857 remapped.append(row) 

858 stacked = jnp.stack(remapped, axis=0) # (n_ops, n_union_diags, N) 

859 raw = SparseDiaData(offsets=union_offsets, diags=stacked) 

860 return cls.create(raw, dims=dims, bdims=new_bdims, implementation=QarrayImplType.SPARSE_DIA) 

861 

862 if target_impl_type == QarrayImplType.SPARSE_BCOO: 

863 # All inputs are sparse BCOO — stack via dense intermediates then re-sparsify. 

864 data = jnp.array([q.data.todense() for q in qarr_list]) 

865 return cls.create(data, dims=dims, bdims=new_bdims, implementation=QarrayImplType.SPARSE_BCOO) 

866 

867 # Target is dense: promote any sparse inputs before stacking. 

868 data = jnp.array([q.to_dense().data for q in qarr_list]) 

869 return cls.create(data, dims=dims, bdims=new_bdims, implementation=QarrayImplType.DENSE) 

870 

871 @classmethod 

872 @overload 

873 def from_array(cls, qarr_arr: "Qarray[DenseImpl]") -> "Qarray[DenseImpl]": 

874 ... 

875 

876 @classmethod 

877 @overload 

878 def from_array(cls, qarr_arr: "Qarray[SparseBCOOImpl]") -> "Qarray[SparseBCOOImpl]": 

879 ... 

880 

881 @classmethod 

882 def from_array(cls, qarr_arr) -> Qarray: 

883 """Create a ``Qarray`` from a (possibly nested) list of ``Qarray`` objects. 

884 

885 Args: 

886 qarr_arr: A ``Qarray`` (returned as-is) or a nested list of 

887 ``Qarray`` objects. 

888 

889 Returns: 

890 A ``Qarray`` with batch dimensions matching the nesting structure 

891 of *qarr_arr*. 

892 """ 

893 if isinstance(qarr_arr, Qarray): 

894 return qarr_arr 

895 

896 bdims = () 

897 lvl = qarr_arr 

898 while not isinstance(lvl, Qarray): 

899 bdims = bdims + (len(lvl),) 

900 if len(lvl) > 0: 

901 lvl = lvl[0] 

902 else: 

903 break 

904 

905 def flat(lis): 

906 flatList = [] 

907 for element in lis: 

908 if type(element) is list: 

909 flatList += flat(element) 

910 else: 

911 flatList.append(element) 

912 return flatList 

913 

914 qarr_list = flat(qarr_arr) 

915 qarr = cls.from_list(qarr_list) 

916 qarr = qarr.reshape_bdims(*bdims) 

917 return qarr 

918 

919 # Properties ---- 

920 @property 

921 def qtype(self): 

922 """Quantum type of this array (ket, bra, or operator).""" 

923 return self._qdims.qtype 

924 

925 @property 

926 def dtype(self): 

927 """Data type of the underlying storage array.""" 

928 return self._impl.dtype() 

929 

930 @property 

931 def dims(self): 

932 """Quantum dimensions as ``((row_dims...), (col_dims...))``.""" 

933 return self._qdims.dims 

934 

935 @property 

936 def bdims(self): 

937 """Tuple of batch dimension sizes (empty tuple = non-batched).""" 

938 return self._bdims 

939 

940 @property 

941 def qdims(self): 

942 """The ``Qdims`` metadata object for this array.""" 

943 return self._qdims 

944 

945 @property 

946 def space_dims(self): 

947 """Hilbert space dimensions for the relevant side (ket row / bra col).""" 

948 if self.qtype in [Qtypes.oper, Qtypes.ket]: 

949 return self.dims[0] 

950 elif self.qtype == Qtypes.bra: 

951 return self.dims[1] 

952 else: 

953 # TODO: not reached for some reason 

954 raise ValueError("Unsupported qtype.") 

955 

956 @property 

957 def data(self): 

958 """The raw underlying data (dense ``jnp.ndarray`` or ``sparse.BCOO``).""" 

959 return self._impl.data 

960 

961 @property 

962 def shaped_data(self): 

963 """Data reshaped to ``bdims + dims[0] + dims[1]``.""" 

964 return self.data.reshape(self.bdims + self.dims[0] + self.dims[1]) 

965 

966 @property 

967 def shape(self): 

968 """Shape of the underlying data array.""" 

969 return self.data.shape 

970 

971 @property 

972 def is_batched(self): 

973 """True if this array has one or more batch dimensions.""" 

974 return len(self.bdims) > 0 

975 

976 @property 

977 def is_sparse_bcoo(self): 

978 """True if the storage backend is ``SparseBCOOImpl`` (BCOO).""" 

979 return self._impl.impl_type == QarrayImplType.SPARSE_BCOO 

980 

981 @property 

982 def is_dense(self): 

983 """True if the storage backend is ``DenseImpl``.""" 

984 return self._impl.impl_type == QarrayImplType.DENSE 

985 

986 @property 

987 def is_sparse_dia(self): 

988 """True if the storage backend is ``SparseDiaImpl``.""" 

989 return self._impl.impl_type == QarrayImplType.SPARSE_DIA 

990 

991 @property 

992 def impl_type(self): 

993 """The ``QarrayImplType`` member of the current storage backend.""" 

994 return self._impl.impl_type 

995 

996 def to_sparse_bcoo(self) -> "Qarray[SparseBCOOImpl]": 

997 """Return a BCOO-sparse-backed copy of this array. 

998 

999 If the array is already sparse BCOO, returns self unchanged. 

1000 

1001 Returns: 

1002 A ``Qarray[SparseBCOOImpl]``. 

1003 """ 

1004 if self.is_sparse_bcoo: 

1005 return self 

1006 new_impl = self._impl.to_sparse_bcoo() 

1007 return Qarray(new_impl, self._qdims, self._bdims) 

1008 

1009 def to_sparse_dia(self) -> "Qarray": 

1010 """Return a SparseDIA-backed copy of this array. 

1011 

1012 If the array is already SparseDIA, returns self unchanged. 

1013 

1014 Returns: 

1015 A ``Qarray[SparseDiaImpl]``. 

1016 """ 

1017 if self.is_sparse_dia: 

1018 return self 

1019 new_impl = self._impl.to_sparse_dia() 

1020 return Qarray(new_impl, self._qdims, self._bdims) 

1021 

1022 def to_dense(self) -> "Qarray[DenseImpl]": 

1023 """Return a dense-backed copy of this array. 

1024 

1025 If the array is already dense, returns self unchanged. 

1026 

1027 Returns: 

1028 A ``Qarray[DenseImpl]``. 

1029 """ 

1030 if self.is_dense: 

1031 return self 

1032 new_impl = self._impl.to_dense() 

1033 return Qarray(new_impl, self._qdims, self._bdims) 

1034 

1035 def __getitem__(self, index): 

1036 if len(self.bdims) > 0: 

1037 return Qarray.create( 

1038 self.data[index], 

1039 dims=self.dims, 

1040 implementation=self.impl_type, 

1041 ) 

1042 else: 

1043 raise ValueError("Cannot index a non-batched Qarray.") 

1044 

1045 def reshape_bdims(self, *args): 

1046 """Reshape the batch dimensions of this ``Qarray``. 

1047 

1048 Args: 

1049 *args: New batch dimension sizes. 

1050 

1051 Returns: 

1052 A new ``Qarray`` with the requested batch shape. 

1053 """ 

1054 new_bdims = tuple(args) 

1055 

1056 if prod(new_bdims) == 0: 

1057 new_shape = new_bdims 

1058 else: 

1059 new_shape = new_bdims + (prod(self.dims[0]),) + (-1,) 

1060 

1061 # Preserve implementation type 

1062 implementation = self.impl_type 

1063 return Qarray.create( 

1064 self.data.reshape(new_shape), 

1065 dims=self.dims, 

1066 bdims=new_bdims, 

1067 implementation=implementation, 

1068 ) 

1069 

1070 def space_to_qdims(self, space_dims: List[int]): 

1071 """Convert Hilbert space dimensions to full quantum dims tuple. 

1072 

1073 Args: 

1074 space_dims: Sequence of per-subsystem Hilbert space sizes, or a 

1075 full ``((row_dims), (col_dims))`` tuple (returned unchanged). 

1076 

1077 Returns: 

1078 A ``((row_dims...), (col_dims...))`` tuple. 

1079 

1080 Raises: 

1081 ValueError: If ``self.qtype`` is not ket, bra, or oper. 

1082 """ 

1083 if isinstance(space_dims[0], (list, tuple)): 

1084 return space_dims 

1085 

1086 if self.qtype in [Qtypes.oper, Qtypes.ket]: 

1087 return (tuple(space_dims), tuple([1 for _ in range(len(space_dims))])) 

1088 elif self.qtype == Qtypes.bra: 

1089 return (tuple([1 for _ in range(len(space_dims))]), tuple(space_dims)) 

1090 else: 

1091 raise ValueError("Unsupported qtype for space_to_qdims conversion.") 

1092 

1093 def reshape_qdims(self, *args): 

1094 """Reshape the quantum dimensions of the Qarray. 

1095 

1096 Note that this does not take in qdims but rather the new Hilbert space 

1097 dimensions. 

1098 

1099 Args: 

1100 *args: New Hilbert dimensions for the Qarray. 

1101 

1102 Returns: 

1103 Qarray: reshaped Qarray. 

1104 """ 

1105 

1106 new_space_dims = tuple(args) 

1107 current_space_dims = self.space_dims 

1108 assert prod(new_space_dims) == prod(current_space_dims) 

1109 

1110 new_qdims = self.space_to_qdims(new_space_dims) 

1111 new_bdims = self.bdims 

1112 

1113 # Preserve implementation type 

1114 implementation = self.impl_type 

1115 return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims, implementation=implementation) 

1116 

1117 def resize(self, new_shape): 

1118 """Resize the Qarray to a new shape. 

1119 

1120 TODO: review and maybe deprecate this method. 

1121 

1122 Args: 

1123 new_shape: Target shape tuple. 

1124 

1125 Returns: 

1126 A new ``Qarray`` with data resized via ``jnp.resize``. 

1127 """ 

1128 dims = self.dims 

1129 data = jnp.resize(self.data, new_shape) 

1130 # Preserve implementation type 

1131 implementation = self.impl_type 

1132 return Qarray.create( 

1133 data, 

1134 dims=dims, 

1135 implementation=implementation, 

1136 ) 

1137 

1138 def __len__(self): 

1139 """Length along the first batch dimension. 

1140 

1141 Returns: 

1142 Size of the leading batch dimension. 

1143 

1144 Raises: 

1145 ValueError: If the array is not batched. 

1146 """ 

1147 if len(self.bdims) > 0: 

1148 return self.data.shape[0] 

1149 else: 

1150 raise ValueError("Cannot get length of a non-batched Qarray.") 

1151 

1152 def __eq__(self, other): 

1153 if not isinstance(other, Qarray): 

1154 raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.") 

1155 

1156 if self.dims != other.dims: 

1157 return False 

1158 

1159 if self.bdims != other.bdims: 

1160 return False 

1161 

1162 if self.is_sparse_bcoo and other.is_sparse_bcoo: 

1163 # Fast structural path: same sparsity pattern → compare values only (no todense) 

1164 if (self.data.indices.shape == other.data.indices.shape 

1165 and bool(jnp.all(self.data.indices == other.data.indices))): 

1166 return bool(jnp.allclose(self.data.data, other.data.data)) 

1167 # Different patterns: fall back to dense comparison (unavoidable) 

1168 return bool(jnp.all(self.data.todense() == other.data.todense())) 

1169 

1170 # At least one dense: convert sparse side to dense for comparison 

1171 self_data = self.data.todense() if hasattr(self.data, 'todense') else self.data 

1172 other_data = other.data.todense() if hasattr(other.data, 'todense') else other.data 

1173 return bool(jnp.all(self_data == other_data)) 

1174 

1175 def __ne__(self, other): 

1176 return not self.__eq__(other) 

1177 

1178 # Elementary Math ---- 

1179 def __matmul__(self, other): 

1180 if not isinstance(other, Qarray): 

1181 return NotImplemented 

1182 

1183 _qdims_new = self._qdims @ other._qdims 

1184 new_impl = self._impl.matmul(other._impl) 

1185 

1186 return Qarray.create( 

1187 new_impl.data, 

1188 dims=_qdims_new.dims, 

1189 implementation=new_impl.impl_type, 

1190 ) 

1191 

1192 def __mul__(self, other): 

1193 if isinstance(other, Qarray): 

1194 return self.__matmul__(other) 

1195 

1196 other = other + 0.0j 

1197 if not robust_isscalar(other) and len(other.shape) > 0: # not a scalar 

1198 other = other.reshape(other.shape + (1, 1)) 

1199 

1200 new_impl = self._impl.mul(other) 

1201 return Qarray.create( 

1202 new_impl.data, 

1203 dims=self._qdims.dims, 

1204 implementation=new_impl.impl_type, 

1205 ) 

1206 

1207 def __rmul__(self, other): 

1208 return self.__mul__(other) 

1209 

1210 def __neg__(self): 

1211 return self.__mul__(-1) 

1212 

1213 def __truediv__(self, other): 

1214 """Divide by a scalar. 

1215 

1216 Args: 

1217 other: Scalar divisor. 

1218 

1219 Returns: 

1220 A new ``Qarray`` with all elements divided by *other*. 

1221 

1222 Raises: 

1223 ValueError: If *other* is a ``Qarray``. 

1224 """ 

1225 if isinstance(other, Qarray): 

1226 raise ValueError("Cannot divide a Qarray by another Qarray.") 

1227 

1228 return self.__mul__(1 / other) 

1229 

1230 def __add__(self, other): 

1231 if isinstance(other, Qarray): 

1232 if self.dims != other.dims: 

1233 msg = ( 

1234 "Dimensions are incompatible: " 

1235 + repr(self.dims) 

1236 + " and " 

1237 + repr(other.dims) 

1238 ) 

1239 raise ValueError(msg) 

1240 new_impl = self._impl.add(other._impl) 

1241 return Qarray.create( 

1242 new_impl.data, 

1243 dims=self.dims, 

1244 implementation=new_impl.impl_type, 

1245 ) 

1246 

1247 if robust_isscalar(other) and other == 0: 

1248 return self.copy() 

1249 

1250 if self.data.shape[-2] == self.data.shape[-1]: 

1251 other = other + 0.0j 

1252 if not robust_isscalar(other) and len(other.shape) > 0: # not a scalar 

1253 other = other.reshape(other.shape + (1, 1)) 

1254 eye_data = self._impl._eye_data(self.data.shape[-2], dtype=self.data.dtype) 

1255 other = Qarray.create( 

1256 other * eye_data, 

1257 dims=self.dims, 

1258 implementation=self.impl_type 

1259 ) 

1260 return self.__add__(other) 

1261 

1262 return NotImplemented 

1263 

1264 def __radd__(self, other): 

1265 return self.__add__(other) 

1266 

1267 def __sub__(self, other): 

1268 if isinstance(other, Qarray): 

1269 if self.dims != other.dims: 

1270 msg = ( 

1271 "Dimensions are incompatible: " 

1272 + repr(self.dims) 

1273 + " and " 

1274 + repr(other.dims) 

1275 ) 

1276 raise ValueError(msg) 

1277 new_impl = self._impl.sub(other._impl) 

1278 return Qarray.create( 

1279 new_impl.data, 

1280 dims=self.dims, 

1281 implementation=new_impl.impl_type, 

1282 ) 

1283 

1284 if robust_isscalar(other) and other == 0: 

1285 return self.copy() 

1286 

1287 if self.data.shape[-2] == self.data.shape[-1]: 

1288 other = other + 0.0j 

1289 

1290 if not robust_isscalar(other) and len(other.shape) > 0: # not a scalar 

1291 other = other.reshape(other.shape + (1, 1)) 

1292 eye_data = self._impl._eye_data(self.data.shape[-2], dtype=self.data.dtype) 

1293 other = Qarray.create( 

1294 other * eye_data, 

1295 dims=self.dims, 

1296 implementation=self.impl_type 

1297 ) 

1298 return self.__sub__(other) 

1299 

1300 return NotImplemented 

1301 

1302 def __rsub__(self, other): 

1303 return self.__neg__().__add__(other) 

1304 

1305 def __xor__(self, other): 

1306 if not isinstance(other, Qarray): 

1307 return NotImplemented 

1308 return tensor(self, other) 

1309 

1310 def __rxor__(self, other): 

1311 if not isinstance(other, Qarray): 

1312 return NotImplemented 

1313 return tensor(other, self) 

1314 

1315 def __pow__(self, other): 

1316 if not isinstance(other, int): 

1317 return NotImplemented 

1318 

1319 return powm(self, other) 

1320 

1321 # String Representation ---- 

1322 def _str_header(self): 

1323 """Build the one-line header string for ``__str__`` and ``__repr__``.""" 

1324 impl_type = self.impl_type.value 

1325 out = ", ".join( 

1326 [ 

1327 "Quantum array: dims = " + str(self.dims), 

1328 "bdims = " + str(self.bdims), 

1329 "shape = " + str(self.data.shape), 

1330 "type = " + str(self.qtype), 

1331 "impl = " + impl_type, 

1332 ] 

1333 ) 

1334 return out 

1335 

1336 def __str__(self): 

1337 return self._str_header() + "\nQarray data =\n" + str(self.data) 

1338 

1339 @property 

1340 def header(self): 

1341 """One-line header string describing dimensions, shape, and backend.""" 

1342 return self._str_header() 

1343 

1344 def __repr__(self): 

1345 return self.__str__() 

1346 

1347 # Utilities ---- 

1348 def copy(self, memo=None): 

1349 """Return a deep copy of this ``Qarray``. 

1350 

1351 Args: 

1352 memo: Optional memo dict forwarded to ``deepcopy``. 

1353 

1354 Returns: 

1355 A new ``Qarray`` with independent copies of all data. 

1356 """ 

1357 return self.__deepcopy__(memo) 

1358 

1359 def __deepcopy__(self, memo): 

1360 """Need to override this when defining __getattr__.""" 

1361 

1362 return Qarray( 

1363 _impl=deepcopy(self._impl, memo=memo), 

1364 _qdims=deepcopy(self._qdims, memo=memo), 

1365 _bdims=deepcopy(self._bdims, memo=memo), 

1366 ) 

1367 

1368 def __getattr__(self, method_name): 

1369 if "__" == method_name[:2]: 

1370 # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__ 

1371 return lambda *args, **kwargs: NotImplemented 

1372 

1373 modules = [jnp, jnp.linalg, jsp, jsp.linalg] 

1374 

1375 method_f = None 

1376 for mod in modules: 

1377 method_f = getattr(mod, method_name, None) 

1378 if method_f is not None: 

1379 break 

1380 

1381 if method_f is None: 

1382 raise NotImplementedError( 

1383 f"Method {method_name} does not exist. No backup method found in {modules}." 

1384 ) 

1385 

1386 def func(*args, **kwargs): 

1387 # For operations that might not be supported in sparse, convert to dense 

1388 if self.is_sparse_bcoo: 

1389 dense_self = self.to_dense() 

1390 res = method_f(dense_self.data, *args, **kwargs) 

1391 else: 

1392 res = method_f(self.data, *args, **kwargs) 

1393 

1394 if getattr(res, "shape", None) is None or res.shape != self.data.shape: 

1395 return res 

1396 else: 

1397 # Preserve implementation type 

1398 return Qarray.create(res, dims=self._qdims.dims, implementation=self.impl_type) 

1399 

1400 return func 

1401 

1402 # Conversions / Reshaping ---- 

1403 def dag(self): 

1404 """Conjugate transpose of this array.""" 

1405 return dag(self) 

1406 

1407 def to_dm(self): 

1408 """Convert a ket to a density matrix via outer product.""" 

1409 return ket2dm(self) 

1410 

1411 def is_dm(self): 

1412 """Return True if this array is an operator (density-matrix type).""" 

1413 return self.qtype == Qtypes.oper 

1414 

1415 def is_vec(self): 

1416 """Return True if this array is a ket or bra.""" 

1417 return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra 

1418 

1419 def to_ket(self): 

1420 """Convert a bra to a ket (no-op for kets).""" 

1421 return to_ket(self) 

1422 

1423 def transpose(self, *args): 

1424 """Transpose subsystem indices.""" 

1425 return transpose(self, *args) 

1426 

1427 def keep_only_diag_elements(self): 

1428 """Zero out all off-diagonal elements.""" 

1429 return keep_only_diag_elements(self) 

1430 

1431 # Math Functions ---- 

1432 def unit(self): 

1433 """Return the normalised (unit-norm) version of this array.""" 

1434 return unit(self) 

1435 

1436 def norm(self): 

1437 """Compute the norm of this array.""" 

1438 return norm(self) 

1439 

1440 def frobenius_norm(self): 

1441 """Compute the Frobenius norm directly from the implementation. 

1442 

1443 Returns: 

1444 The Frobenius norm as a scalar. 

1445 """ 

1446 return self._impl.frobenius_norm() 

1447 

1448 def real(self): 

1449 """Element-wise real part. 

1450 

1451 Returns: 

1452 A new ``Qarray`` containing the real parts of each element. 

1453 """ 

1454 new_impl = self._impl.real() 

1455 return Qarray.create( 

1456 new_impl.data, 

1457 dims=self.dims, 

1458 implementation=new_impl.impl_type, 

1459 ) 

1460 

1461 def imag(self): 

1462 """Element-wise imaginary part. 

1463 

1464 Returns: 

1465 A new ``Qarray`` containing the imaginary parts of each element. 

1466 """ 

1467 new_impl = self._impl.imag() 

1468 

1469 return Qarray.create( 

1470 new_impl.data, 

1471 dims=self.dims, 

1472 implementation=new_impl.impl_type, 

1473 ) 

1474 

1475 def conj(self): 

1476 """Element-wise complex conjugate. 

1477 

1478 Returns: 

1479 A new ``Qarray`` containing the complex-conjugated elements. 

1480 """ 

1481 new_impl = self._impl.conj() 

1482 return Qarray.create( 

1483 new_impl.data, 

1484 dims=self.dims, 

1485 implementation=new_impl.impl_type, 

1486 ) 

1487 

1488 def expm(self): 

1489 """Matrix exponential.""" 

1490 return expm(self) 

1491 

1492 def powm(self, n): 

1493 """Matrix power. 

1494 

1495 Args: 

1496 n: Exponent (integer or float). 

1497 

1498 Returns: 

1499 This array raised to the *n*-th matrix power. 

1500 """ 

1501 return powm(self, n) 

1502 

1503 def cosm(self): 

1504 """Matrix cosine.""" 

1505 return cosm(self) 

1506 

1507 def sinm(self): 

1508 """Matrix sine.""" 

1509 return sinm(self) 

1510 

1511 def tr(self, **kwargs): 

1512 """Full trace.""" 

1513 return tr(self, **kwargs) 

1514 

1515 def trace(self, **kwargs): 

1516 """Full trace (alias for :meth:`tr`).""" 

1517 return tr(self, **kwargs) 

1518 

1519 def ptrace(self, indx): 

1520 """Partial trace over subsystem *indx*. 

1521 

1522 Args: 

1523 indx: Index of the subsystem to trace out. 

1524 

1525 Returns: 

1526 Reduced density matrix. 

1527 """ 

1528 return ptrace(self, indx) 

1529 

1530 def eigenstates(self): 

1531 """Eigenvalues and eigenstates of this operator.""" 

1532 return eigenstates(self) 

1533 

1534 def eigenenergies(self): 

1535 """Eigenvalues of this operator.""" 

1536 return eigenenergies(self) 

1537 

1538 def eigenvalues(self): 

1539 """Eigenvalues of this operator (alias for :meth:`eigenenergies`).""" 

1540 return eigenenergies(self) 

1541 

1542 def collapse(self, mode="sum"): 

1543 """Collapse batch dimensions. 

1544 

1545 Args: 

1546 mode: Collapse strategy — currently only ``"sum"`` is supported. 

1547 

1548 Returns: 

1549 A non-batched ``Qarray``. 

1550 """ 

1551 return collapse(self, mode=mode) 

1552 

1553 

1554# Qarray operations --------------------------------------------------------------------- 

1555 

def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join Qarrays along an axis of their raw data.

    Entries whose data has length zero are dropped. If nothing remains,
    ``Qarray.from_list([])`` is returned. The result takes the dims of the
    first remaining entry and is built with ``Qarray.create``'s default
    implementation.

    Args:
        qarr_list: Qarrays to join.
        axis: Data axis to concatenate along (default 0).

    Returns:
        The concatenated Qarray.
    """
    kept = [qarr for qarr in qarr_list if len(qarr.data) != 0]
    if not kept:
        return Qarray.from_list([])

    joined = jnp.concatenate([qarr.data for qarr in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)

1578 

1579 

def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse the batch dimensions of *qarr*.

    Args:
        qarr: Quantum array with optional batch dimensions.
        mode: Collapse strategy. Only ``"sum"`` is currently supported.

    Returns:
        A non-batched ``Qarray`` obtained by summing over all batch axes, or
        *qarr* itself when it has no batch axes. The storage backend is kept.

    Raises:
        ValueError: If *mode* is not ``"sum"``.
    """
    if mode != "sum":
        # Previously an unsupported mode fell off the end of the function and
        # silently returned None.
        raise ValueError(f"Unsupported collapse mode {mode!r}; only 'sum' is supported.")

    if len(qarr.bdims) == 0:
        return qarr

    batch_axes = tuple(range(len(qarr.bdims)))

    # Preserve implementation type
    return Qarray.create(
        jnp.sum(qarr.data, axis=batch_axes),
        dims=qarr.dims,
        implementation=qarr.impl_type,
    )

1600 

1601 

def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems of a quantum array.

    Row and column subsystem axes are permuted together; batch axes stay
    in place. The permutation is computed on dense data, and the result is
    stored with the input's original backend.

    Args:
        qarr: Input quantum array.
        indices: New ordering of subsystem indices.

    Returns:
        Transposed ``Qarray`` with correspondingly permuted ``dims``.
    """
    # Capture the backend BEFORE densifying. Reading it after ``to_dense()``
    # always yielded DENSE, which defeated "preserve implementation type".
    implementation = qarr.impl_type
    qarr = qarr.to_dense()

    indices = list(indices)
    dims = qarr.dims
    n_batch = len(qarr.bdims)
    n_sub = len(dims[0])

    # Axis layout of shaped_data: batch axes, row subsystems, column subsystems.
    row_axes = [n_batch + j for j in indices]
    col_axes = [n_batch + n_sub + j for j in indices]
    perm = list(range(n_batch)) + row_axes + col_axes

    shaped_data = qarr.shaped_data.transpose(perm)
    new_dims = (
        tuple(dims[0][j] for j in indices),
        tuple(dims[1][j] for j in indices),
    )

    full_data = shaped_data.reshape(*qarr.bdims, prod(dims[0]), -1)
    return Qarray.create(full_data, dims=new_dims, implementation=implementation)

1636 

1637 

def unit(qarr: Qarray) -> Qarray:
    """Scale *qarr* so that its :func:`norm` is one.

    Args:
        qarr: Input quantum array.

    Returns:
        ``qarr / qarr.norm()``.
    """
    magnitude = qarr.norm()
    return qarr / magnitude

1648 

1649 

def norm(qarr: Qarray) -> float:
    """Norm of a quantum array.

    Sparse paths (no densification):

    * BCOO ket / bra: L2 norm via ``SparseBCOOImpl.l2_norm_batched``
      (batch dimensions handled).
    * BCOO or SparseDIA operator: the real part of the trace. This equals
      the trace (nuclear) norm only for positive semi-definite operators
      such as density matrices; convert general operators to dense first.

    Dense path:

    * operator: trace norm, i.e. the sum of singular values, computed as
      ``sum(sqrt(|eig(A @ A^dagger)|))``.
    * ket / bra: ``jnp.linalg.norm`` of the data.

    Batched inputs are flattened to one batch axis, processed with
    ``vmap`` and reshaped back to ``bdims``. Other qtypes return ``None``.

    Args:
        qarr: Input quantum array.

    Returns:
        The norm as a scalar (or batched array of scalars).
    """
    is_vector = qarr.qtype in [Qtypes.ket, Qtypes.bra]
    is_operator = qarr.qtype == Qtypes.oper

    if qarr.is_sparse_bcoo:
        if is_vector:
            return qarr._impl.l2_norm_batched(qarr.bdims)
        if is_operator:
            # jnp.real strips any floating-point imaginary artefact.
            return jnp.real(qarr._impl.trace())
    elif qarr.is_sparse_dia and is_operator:
        return jnp.real(qarr._impl.trace())

    dense = qarr.to_dense()
    qdata = dense.data
    bdims = dense.bdims
    batched = len(bdims) > 0

    if dense.qtype == Qtypes.oper:
        qdata_dag = dense.dag().data
        if not batched:
            evals, _ = jnp.linalg.eigh(qdata @ qdata_dag)
            return jnp.sum(jnp.sqrt(jnp.abs(evals)))

        flat = qdata.reshape(-1, qdata.shape[-2], qdata.shape[-1])
        flat_dag = qdata_dag.reshape(-1, qdata_dag.shape[-2], qdata_dag.shape[-1])
        evals, _ = vmap(jnp.linalg.eigh)(flat @ flat_dag)
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1).reshape(*bdims)

    if dense.qtype in [Qtypes.ket, Qtypes.bra]:
        if not batched:
            return jnp.linalg.norm(qdata)
        flat = qdata.reshape(-1, qdata.shape[-2], qdata.shape[-1])
        return vmap(jnp.linalg.norm)(flat).reshape(*bdims)

1705 

1706 

def tensor(*args, **kwargs) -> Qarray:
    """Tensor (Kronecker) product of two or more ``Qarray`` objects.

    Args:
        *args: ``Qarray`` objects, combined left to right.
        **kwargs: ``parallel=True`` selects an einsum-based outer product
            that broadcasts over batch dimensions instead of each backend's
            ``kron``. Other keyword arguments are ignored.

    Returns:
        The tensor product as a ``Qarray`` whose row/column dims are the
        concatenation of the inputs' dims.

    Note:
        With ``parallel=True`` the einsum always runs on dense data. The
        result is stored with the input backend that has the highest
        ``PROMOTION_ORDER`` (all-sparse inputs stay sparse; any dense input
        gives a dense result). On the default path each implementation's
        ``kron`` decides the output backend; mixed inputs are expected to
        promote to dense inside the implementations.
    """
    if kwargs.pop("parallel", False):
        return _tensor_einsum(args)

    # Default path: fold the implementations' own kron left to right.
    acc_impl = args[0]._impl
    row_dims = args[0].dims[0]
    col_dims = args[0].dims[1]
    for nxt in args[1:]:
        acc_impl = acc_impl.kron(nxt._impl)
        row_dims = row_dims + nxt.dims[0]
        col_dims = col_dims + nxt.dims[1]
    return Qarray.create(acc_impl.data, dims=(row_dims, col_dims),
                         implementation=acc_impl.impl_type)


def _tensor_einsum(qarrs) -> Qarray:
    """Batch-broadcasting einsum tensor product used by ``tensor(..., parallel=True)``."""
    # Highest PROMOTION_ORDER wins: all-sparse -> sparse, any dense -> dense.
    target_impl_type = max(
        (q.impl_type for q in qarrs),
        key=lambda t: t.get_impl_class().PROMOTION_ORDER,
    )
    dense_qarrs = [q.to_dense() for q in qarrs]

    data = dense_qarrs[0].data
    row_dims = dense_qarrs[0].dims[0]
    col_dims = dense_qarrs[0].dims[1]
    for q in dense_qarrs[1:]:
        a, b = data, q.data
        ndim_a, ndim_b = len(a.shape), len(b.shape)
        if ndim_a == ndim_b:
            # Same rank: keep the batch shape with more elements (b on ties).
            batch_dim = a.shape[:-2] if prod(a.shape[:-2]) > prod(b.shape[:-2]) else b.shape[:-2]
        else:
            batch_dim = (a if ndim_a > ndim_b else b).shape[:-2]

        # NOTE: implementation einsum should be used when available
        data = jnp.einsum("...ij,...kl->...ikjl", a, b).reshape(
            *batch_dim, a.shape[-2] * b.shape[-2], -1
        )
        row_dims = row_dims + q.dims[0]
        col_dims = col_dims + q.dims[1]

    return Qarray.create(data, dims=(row_dims, col_dims), implementation=target_impl_type)

1770 

1771 

def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace of *qarr*.

    For BCOO and SparseDIA ``Qarray`` objects the trace is computed by the
    storage implementation (no densification). Custom axis arguments are
    ignored there: the last two dimensions are always the matrix dimensions
    in jaxquantum's convention.

    Args:
        qarr: Input quantum array.
        **kwargs: Forwarded to ``jnp.trace`` for dense arrays (e.g.
            ``axis1``, ``axis2``, ``offset``). ``axis1``/``axis2`` default to
            ``-2``/``-1``.

    Returns:
        The trace as a scalar (or batched array of scalars).
    """
    if qarr.is_sparse_bcoo or qarr.is_sparse_dia:
        return qarr._impl.trace()
    # Pop (not get): leaving axis1/axis2 in kwargs while also passing them
    # explicitly made jnp.trace raise "got multiple values for argument".
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)

1795 

1796 

def trace(qarr: Qarray, **kwargs) -> Array:
    """Alias of :func:`tr`.

    Args:
        qarr: Input quantum array.
        **kwargs: Passed through to :func:`tr` unchanged.

    Returns:
        The trace as a scalar (or batched array of scalars).
    """
    total = tr(qarr, **kwargs)
    return total

1808 

1809 

def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of a raw (dense) array.

    Args:
        data: Dense matrix array.
        **kwargs: Passed through to ``jax.scipy.linalg.expm``.

    Returns:
        ``exp(data)`` in the matrix sense.
    """
    exponential = jsp.linalg.expm(data, **kwargs)
    return exponential

1821 

1822 

def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a ``Qarray``.

    The input is densified, exponentiated with :func:`expm_data`, and
    wrapped with the input's dims using ``Qarray.create``'s default backend.

    Args:
        qarr: Input quantum array.
        **kwargs: Passed through to ``jax.scipy.linalg.expm``.

    Returns:
        A ``Qarray`` holding the matrix exponential.
    """
    original_dims = qarr.dims
    exponential = expm_data(qarr.to_dense().data, **kwargs)
    return Qarray.create(exponential, dims=original_dims)

1838 

1839 

def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power of a ``Qarray``.

    Args:
        qarr: Input quantum array (batched inputs supported).
        n: Exponent. Integer powers use ``jnp.linalg.matrix_power``; float
            powers diagonalise the matrix as ``V diag(lambda**n) V^-1``.
        clip_eigvals: When ``True``, clip negative eigenvalues to zero before
            applying the float power (useful for nearly-PSD matrices).

    Returns:
        The *n*-th matrix power as a ``Qarray``. A SparseDIA input with a
        non-negative integer exponent stays SparseDIA; every other case
        returns the default backend of ``Qarray.create``.

    Raises:
        ValueError: If *n* is a float and the matrix has negative eigenvalues
            (and *clip_eigvals* is ``False``).
    """
    # SparseDIA fast path: binary exponentiation stays in SparseDIA format.
    if qarr.is_sparse_dia and isinstance(n, int) and n >= 0:
        new_impl = qarr._impl.powm(n)
        return Qarray.create(new_impl.data, dims=qarr.dims, implementation=new_impl.impl_type)

    dense_qarr = qarr.to_dense()

    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(dense_qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(dense_qarr.data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        elif not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Scale the columns of V by lambda**n. ``[..., None, :]`` lays the
        # eigenvalues along the last axis so this also broadcasts correctly
        # for batched (..., N, N) input. A bare (..., N) vector only lines up
        # when there is no batch axis.
        scaled_evectors = evectors * jnp.pow(evalues, n)[..., None, :]
        data_res = scaled_evectors @ jnp.linalg.inv(evectors)

    return Qarray.create(data_res, dims=qarr.dims)

1882 

1883 

def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of a raw (dense) array.

    Uses Euler's formula: ``cos(A) = (exp(iA) + exp(-iA)) / 2``.

    Args:
        data: Dense matrix array.
        **kwargs: Accepted for API symmetry and ignored.

    Returns:
        The (complex-valued) matrix cosine.
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward + backward) / 2

1895 

1896 

def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a ``Qarray``.

    The input is densified, passed through :func:`cosm_data`, and wrapped
    with the input's dims using ``Qarray.create``'s default backend.

    Args:
        qarr: Input quantum array.

    Returns:
        A ``Qarray`` holding the matrix cosine.
    """
    original_dims = qarr.dims
    cosine = cosm_data(qarr.to_dense().data)
    return Qarray.create(cosine, dims=original_dims)

1911 

1912 

def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of a raw (dense) array.

    Uses Euler's formula: ``sin(A) = (exp(iA) - exp(-iA)) / (2i)``.

    Args:
        data: Dense matrix array.
        **kwargs: Accepted for API symmetry and ignored.

    Returns:
        The (complex-valued) matrix sine.
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward - backward) / (2j)

1924 

1925 

def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a ``Qarray``.

    The input is densified, passed through :func:`sinm_data`, and wrapped
    with the input's dims using ``Qarray.create``'s default backend.

    Args:
        qarr: Input quantum array.

    Returns:
        A ``Qarray`` holding the matrix sine.
    """
    original_dims = qarr.dims
    sine = sinm_data(qarr.to_dense().data)
    return Qarray.create(sine, dims=original_dims)

1940 

1941 

def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Zero every off-diagonal element of a non-batched *qarr*.

    * BCOO: the implementation's ``keep_only_diag`` (stays BCOO).
    * SparseDIA: a new DIA array holding only offset 0 (an all-zero main
      diagonal when offset 0 is not stored).
    * Dense: ``jnp.diag(jnp.diag(data))``.

    Args:
        qarr: Non-batched input quantum array.

    Returns:
        A ``Qarray`` with only diagonal entries non-zero.

    Raises:
        ValueError: If *qarr* has batch dimensions.
    """
    if len(qarr.bdims) > 0:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    dims = qarr.dims

    if qarr.is_sparse_bcoo:
        diag_impl = qarr._impl.keep_only_diag()
        return Qarray.create(diag_impl.data, dims=dims, implementation=QarrayImplType.SPARSE_BCOO)

    if qarr.is_sparse_dia:
        from jaxquantum.core.sparse_dia import SparseDiaImpl

        impl = qarr._impl
        offsets, diags = impl._offsets, impl._diags
        if 0 in offsets:
            k = offsets.index(0)
            # Slice k:k+1 keeps the "number of diagonals" axis (size 1).
            main_diag = diags[..., k:k + 1, :]
        else:
            main_diag = jnp.zeros((*diags.shape[:-2], 1, diags.shape[-1]), dtype=diags.dtype)
        diag_impl = SparseDiaImpl(_offsets=(0,), _diags=main_diag)
        return Qarray.create(diag_impl.get_data(), dims=dims, implementation=QarrayImplType.SPARSE_DIA)

    diagonal_only = jnp.diag(jnp.diag(qarr.data))
    return Qarray.create(diagonal_only, dims=dims)

1977 

1978 

def to_ket(qarr: Qarray) -> Qarray:
    """Return *qarr* as a ket.

    Args:
        qarr: A ket (returned unchanged) or a bra (adjointed with ``dag``).

    Returns:
        The ket form of *qarr*.

    Raises:
        ValueError: If *qarr* is neither a ket nor a bra.
    """
    qtype = qarr.qtype
    if qtype == Qtypes.bra:
        return qarr.dag()
    if qtype == Qtypes.ket:
        return qarr
    raise ValueError("Can only get ket from a ket or bra.")

1997 

1998 

def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigen-decomposition of a Hermitian quantum array.

    The return annotation previously claimed a single ``Qarray``; the
    function returns a tuple.

    Args:
        qarr: Hermitian operator (converted to dense internally).

    Returns:
        A tuple ``(eigenvalues, eigenstates_qarray)``. Eigenvalues are sorted
        in ascending order along the last axis. ``eigenstates_qarray`` holds
        the matching eigenvectors as kets, with the eigenvector index as an
        extra (last) batch dimension.
    """
    # Convert to dense for eigenstates
    dense_qarr = qarr.to_dense()

    evals, evecs = jnp.linalg.eigh(dense_qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # numpy returns [batch, :, i] as the i-th eigenvector
    # we want [batch, i, :] as the i-th eigenvector
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs

2031 

2032 

2033def eigenenergies(qarr: Qarray) -> Array: 

2034 """Eigenvalues of a quantum array. 

2035 

2036 Args: 

2037 qarr: Hermitian operator (converted to dense internally). 

2038 

2039 Returns: 

2040 Sorted eigenvalues as a JAX array. 

2041 """ 

2042 # Convert to dense for eigenenergies 

2043 dense_qarr = qarr.to_dense() 

2044 evals = jnp.linalg.eigvalsh(dense_qarr.data) 

2045 return evals 

2046 

2047 

2048def ptrace(qarr: Qarray, indx) -> Qarray: 

2049 """Partial trace over subsystem *indx*. 

2050 

2051 Args: 

2052 qarr: Input quantum array (converted to dense internally). 

2053 indx: Index of the subsystem to trace out. 

2054 

2055 Returns: 

2056 Reduced density matrix as a ``Qarray``. 

2057 """ 

2058 # Convert to dense for ptrace 

2059 dense_qarr = qarr.to_dense() 

2060 dense_qarr = ket2dm(dense_qarr) 

2061 rho = dense_qarr.shaped_data 

2062 dims = dense_qarr.dims 

2063 

2064 Nq = len(dims[0]) 

2065 

2066 indxs = [indx, indx + Nq] 

2067 for j in range(Nq): 

2068 if j == indx: 

2069 continue 

2070 indxs.append(j) 

2071 indxs.append(j + Nq) 

2072 

2073 bdims = dense_qarr.bdims 

2074 len_bdims = len(bdims) 

2075 bdims_indxs = list(range(len_bdims)) 

2076 indxs = bdims_indxs + [j + len_bdims for j in indxs] 

2077 rho = rho.transpose(indxs) 

2078 

2079 for j in range(Nq - 1): 

2080 rho = jnp.trace(rho, axis1=2 + len_bdims, axis2=3 + len_bdims) 

2081 

2082 return Qarray.create(rho) 

2083 

2084 

2085def dag(qarr: Qarray) -> Qarray: 

2086 """Conjugate transpose of *qarr*. 

2087 

2088 Args: 

2089 qarr: Input quantum array. 

2090 

2091 Returns: 

2092 The conjugate transpose with swapped ``dims``. 

2093 """ 

2094 dims = qarr.dims[::-1] 

2095 new_impl = qarr._impl.dag() 

2096 return Qarray.create( 

2097 new_impl.data, 

2098 dims=dims, 

2099 implementation=new_impl.impl_type, 

2100 ) 

2101 

2102 

2103def dag_data(arr) -> Array: 

2104 """Conjugate transpose of a raw array, dispatching to the right backend. 

2105 

2106 Iterates through registered :class:`QarrayImpl` subclasses and delegates 

2107 to the first one whose :meth:`~QarrayImpl.can_handle_data` returns True. 

2108 Adding a new backend automatically extends this function — no changes 

2109 required here. 

2110 

2111 Args: 

2112 arr: Input array (``jnp.ndarray``, ``sparse.BCOO``, or any type 

2113 handled by a registered impl). For 1-D dense arrays only 

2114 conjugation is applied (no transpose). 

2115 

2116 Returns: 

2117 Conjugate transpose with the last two axes swapped. 

2118 

2119 Raises: 

2120 TypeError: If no registered impl can handle *arr*. 

2121 """ 

2122 for impl_class in _IMPL_REGISTRY: 

2123 if impl_class.can_handle_data(arr): 

2124 return impl_class.dag_data(arr) 

2125 raise TypeError(f"dag_data: no registered impl can handle type {type(arr)}") 

2126 

2127 

2128def ket2dm(qarr: Qarray) -> Qarray: 

2129 """Convert a ket to a density matrix via outer product. 

2130 

2131 Args: 

2132 qarr: Ket, bra, or operator. Operators are returned unchanged. 

2133 

2134 Returns: 

2135 Density matrix ``|ψ⟩⟨ψ|``. 

2136 """ 

2137 if qarr.qtype == Qtypes.oper: 

2138 return qarr 

2139 

2140 if qarr.qtype == Qtypes.bra: 

2141 qarr = qarr.dag() 

2142 

2143 return qarr @ qarr.dag() 

2144 

2145 

2146# Data level operations 

2147def is_dm_data(data: Array) -> bool: 

2148 """Check whether *data* has the shape of a density matrix (square matrix). 

2149 

2150 Args: 

2151 data: Array to check. 

2152 

2153 Returns: 

2154 True if the last two dimensions are equal. 

2155 """ 

2156 return data.shape[-2] == data.shape[-1] 

2157 

2158 

2159def powm_data(data: Array, n: int) -> Array: 

2160 """Integer matrix power of a raw array. 

2161 

2162 Args: 

2163 data: Dense square matrix array. 

2164 n: Integer exponent. 

2165 

2166 Returns: 

2167 The *n*-th matrix power. 

2168 """ 

2169 return jnp.linalg.matrix_power(data, n) 

2170 

2171 

2172# Type aliases for readability 

2173DenseQarray = Qarray[DenseImpl] 

2174# SparseBCOOQarray and SparseDIAQarray are defined lazily (impls imported at runtime) 

2175# Use Qarray[SparseBCOOImpl] / Qarray[SparseDiaImpl] once those modules are imported. 

2176 

2177ARRAY_TYPES = (Array, ndarray, Qarray)