Coverage for jaxquantum / core / qarray.py: 81%
774 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""New Qarray implementation with sparse support."""
3from __future__ import annotations
5from abc import ABC, abstractmethod
6from flax import struct
7from jax import Array, config, vmap
8from typing import TYPE_CHECKING, List, Union, TypeVar, Generic, overload, Literal
10if TYPE_CHECKING:
11 from jaxquantum.core.sparse_bcoo import SparseBCOOImpl
12import jax.numpy as jnp
13import jax.scipy as jsp
14from jax.experimental import sparse
15from numpy import ndarray
16from copy import deepcopy
17from math import prod
18from enum import Enum
20from jaxquantum.core.settings import SETTINGS
21from jaxquantum.utils.utils import robust_isscalar
22from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims
# Turn on double precision (float64 / complex128) for JAX globally; this is a
# deliberate import-time side effect of this module.
config.update("jax_enable_x64", True)

# Type variable bound to ``QarrayImpl`` so that ``Qarray[...]`` can be
# parameterised by its storage backend for static type checking.
ImplT = TypeVar("ImplT", bound="QarrayImpl")

# Registry: concrete impl class -> ``QarrayImplType`` member.
# Filled in through ``QarrayImplType.register``.
_IMPL_REGISTRY: dict = dict()
class QarrayImplType(Enum):
    """Enumeration of available Qarray storage backends.

    Each member maps one-to-one with a concrete ``QarrayImpl`` subclass.
    New backends should call ``QarrayImplType.register(MyImpl, QarrayImplType.MY_TYPE)``
    immediately after defining their impl class.

    Besides the canonical value lookup (``QarrayImplType("dense")``), the
    constructor also resolves member names/values case-insensitively
    (``QarrayImplType("DENSE")``) and registered impl classes
    (``QarrayImplType(DenseImpl)``), i.e. exactly the inputs :meth:`has`
    reports as valid.

    Members:
        DENSE: Standard JAX dense array (``jnp.ndarray``).
        SPARSE_BCOO: JAX experimental BCOO sparse array.
        SPARSE_DIA: Diagonal sparse array.
    """

    DENSE = "dense"
    SPARSE_BCOO = "sparse_bcoo"
    SPARSE_DIA = "sparse_dia"

    @classmethod
    def _missing_(cls, value):
        """Resolve constructor inputs that are not exact member values.

        Called by ``Enum`` only after the exact value lookup fails.

        Args:
            value: A string (matched case-insensitively against member values
                and names) or a registered implementation class.

        Returns:
            The matching member, or ``None`` so that ``Enum`` raises its usual
            ``ValueError``.
        """
        if isinstance(value, str):
            lowered = value.lower()
            for member in cls:
                if lowered in (member.value, member.name.lower()):
                    return member
            return None
        try:
            return _IMPL_REGISTRY.get(value)
        except TypeError:
            # Unhashable values can never be registry keys.
            return None

    @classmethod
    def register(cls, impl_class, member):
        """Register an implementation class with a QarrayImplType member.

        Args:
            impl_class: The concrete ``QarrayImpl`` subclass to register.
            member: The ``QarrayImplType`` enum member to associate with it.
        """
        _IMPL_REGISTRY[impl_class] = member

    @classmethod
    def has(cls, x) -> bool:
        """Return True if x corresponds to a member of QarrayImplType.

        Accepts an existing ``QarrayImplType`` member, a string equal to the
        member name or value (case-insensitive), or an implementation class
        (e.g. ``DenseImpl``, ``SparseBCOOImpl``) that has been registered.

        Args:
            x: Value to test — a ``QarrayImplType``, ``str``, or impl class.

        Returns:
            True if ``x`` maps to a known ``QarrayImplType`` member.
        """
        if isinstance(x, cls):
            return True

        if isinstance(x, str):
            xl = x.lower()
            return any(xl == member.value or xl == member.name.lower() for member in cls)

        # Try mapping from an implementation class to an enum member.
        # ValueError: not registered; TypeError: unhashable (cannot be a key).
        try:
            cls.from_impl_class(x)
        except (ValueError, TypeError):
            return False
        return True

    @classmethod
    def from_impl_class(cls, impl_class) -> "QarrayImplType":
        """Return the ``QarrayImplType`` member associated with *impl_class*.

        Args:
            impl_class: A concrete ``QarrayImpl`` subclass that has been
                registered via :meth:`register`.

        Returns:
            The corresponding ``QarrayImplType`` member.

        Raises:
            ValueError: If *impl_class* is not in the registry.
        """
        if impl_class in _IMPL_REGISTRY:
            return _IMPL_REGISTRY[impl_class]
        raise ValueError(f"Unknown implementation class: {impl_class}")

    def get_impl_class(self):
        """Return the implementation class registered for this member.

        Returns:
            The concrete ``QarrayImpl`` subclass associated with this member
            (the first one registered, if several map to it).

        Raises:
            ValueError: If no class has been registered for this member.
        """
        for cls_key, member in _IMPL_REGISTRY.items():
            if member is self:
                return cls_key
        raise ValueError(f"No impl class registered for {self}")
def robust_asarray(data) -> Union[Array, sparse.BCOO]:
    """Coerce *data* to a JAX array unless it is already a sparse container.

    ``sparse.BCOO`` instances and objects flagged with a truthy
    ``_is_sparse_dia`` attribute (``SparseDiaData``) are returned as-is;
    everything else goes through ``jnp.asarray``.

    Args:
        data: Array-like input, a ``sparse.BCOO``, or a ``SparseDiaData``.

    Returns:
        The untouched sparse container, or a dense ``jax.Array``.
    """
    already_sparse = isinstance(data, sparse.BCOO) or getattr(data, "_is_sparse_dia", False)
    if already_sparse:
        return data
    return jnp.asarray(data)
class QarrayImpl(ABC):
    """Abstract base class defining the interface every storage backend must implement.

    A ``QarrayImpl`` wraps a raw data container (dense ``jnp.ndarray``, sparse
    ``BCOO``, or diagonal ``SparseDiaData``) and provides the mathematical
    primitives used by ``Qarray``. Concrete subclasses must implement every
    ``@abstractmethod``.

    Attributes:
        PROMOTION_ORDER: Integer priority used by ``_coerce`` to decide which
            side to promote when operands have different types. Higher means
            "more general". Current hierarchy:
            ``SparseDiaImpl = 0 < SparseBCOOImpl = 1 < DenseImpl = 2``.
    """

    PROMOTION_ORDER: int = 0  # override in subclasses; higher = more general

    @abstractmethod
    def get_data(self) -> Array:
        """Return the underlying raw data array."""
        pass

    @property
    def data(self) -> Array:
        """The underlying raw data array."""
        return self.get_data()

    @property
    def impl_type(self) -> "QarrayImplType":
        """The ``QarrayImplType`` member corresponding to this instance."""
        return QarrayImplType.from_impl_class(type(self))

    @classmethod
    @abstractmethod
    def from_data(cls, data) -> "QarrayImpl":
        """Wrap raw data in this impl type.

        Args:
            data: Raw data (dense ``jnp.ndarray``, ``sparse.BCOO``, or a
                backend-specific container).

        Returns:
            A new instance of this implementation wrapping *data*.
        """
        pass

    @abstractmethod
    def matmul(self, other: "QarrayImpl") -> "QarrayImpl":
        """Matrix multiplication with *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self @ other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def add(self, other: "QarrayImpl") -> "QarrayImpl":
        """Element-wise addition with *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self + other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def sub(self, other: "QarrayImpl") -> "QarrayImpl":
        """Element-wise subtraction of *other*.

        Args:
            other: Right-hand operand.

        Returns:
            Result of ``self - other`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def mul(self, scalar) -> "QarrayImpl":
        """Scalar multiplication.

        Args:
            scalar: Scalar value to multiply by.

        Returns:
            Result of ``scalar * self`` as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def dag(self) -> "QarrayImpl":
        """Conjugate transpose.

        Returns:
            The conjugate transpose of this array as a ``QarrayImpl``.
        """
        pass

    @abstractmethod
    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl``.

        Returns:
            A ``DenseImpl`` wrapping the same data.
        """
        pass

    @abstractmethod
    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` (BCOO).

        Returns:
            A ``SparseBCOOImpl`` wrapping the same data.
        """
        pass

    def to_sparse_dia(self) -> "QarrayImpl":
        """Convert to a ``SparseDiaImpl``.

        Default implementation goes through dense and lets
        ``SparseDiaImpl.from_data`` pick the diagonals. Subclasses may
        override for a more direct path.

        Returns:
            A ``SparseDiaImpl`` wrapping the same data.
        """
        # Imported lazily to avoid a circular import at module load time.
        from jaxquantum.core.sparse_dia import SparseDiaImpl
        return SparseDiaImpl.from_data(self.to_dense()._data)

    @abstractmethod
    def shape(self) -> tuple:
        """Shape of the underlying data array.

        Returns:
            Tuple of dimension sizes.
        """
        pass

    @abstractmethod
    def dtype(self):
        """Data type of the underlying array.

        Returns:
            A numpy/JAX dtype object.
        """
        pass

    @abstractmethod
    def __deepcopy__(self, memo=None):
        """Return a deep copy of this implementation (``copy.deepcopy`` hook)."""
        pass

    @abstractmethod
    def tidy_up(self, atol):
        """Zero out values whose magnitude is below *atol*.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``QarrayImpl`` with small values zeroed.
        """
        pass

    @abstractmethod
    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker (tensor) product with another implementation.

        Args:
            other: Right-hand operand. Mixed-type pairs are handled by
                ``_coerce`` — the result has the higher ``PROMOTION_ORDER``
                type (dense wins over sparse).

        Returns:
            A new ``QarrayImpl`` containing the Kronecker product.
        """
        pass

    @classmethod
    @abstractmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create identity matrix data of size n.

        Args:
            n: Matrix size.
            dtype: Optional data type for the identity entries.

        Returns:
            Raw identity matrix data in the format appropriate for this impl.
        """
        pass

    @classmethod
    @abstractmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True if *arr* is a raw data type natively handled by this impl.

        Used by a module-level ``dag_data`` dispatcher to route raw arrays to
        the correct backend without any isinstance chain outside the impl
        classes.

        Args:
            arr: Raw array — e.g. ``jnp.ndarray`` for ``DenseImpl`` or
                ``sparse.BCOO`` for ``SparseBCOOImpl``.

        Returns:
            True if this impl can operate on *arr* without conversion.
        """
        pass

    @classmethod
    @abstractmethod
    def dag_data(cls, arr):
        """Conjugate transpose of raw data in this impl's native format.

        Implementations must handle batched arrays (last two axes are
        swapped) and must not densify sparse arrays.

        Args:
            arr: Raw array in this impl's native format.

        Returns:
            Conjugate transpose with the last two axes swapped.
        """
        pass

    def _promote_to(self, target_cls: type) -> "QarrayImpl":
        """Convert this impl to *target_cls* by passing through dense.

        Args:
            target_cls: The destination ``QarrayImpl`` subclass.

        Returns:
            ``self`` if it already is a *target_cls*, otherwise a new
            *target_cls* instance holding equivalent data.
        """
        if isinstance(self, target_cls):
            return self
        return target_cls.from_data(self.to_dense()._data)

    def _coerce(self, other: "QarrayImpl") -> "tuple[QarrayImpl, QarrayImpl]":
        """Coerce *self* and *other* to the same implementation type.

        The impl type with the higher ``PROMOTION_ORDER`` wins; on a tie
        between different types, *self*'s type wins. The losing side is
        promoted via :meth:`_promote_to`.

        Args:
            other: The other operand.

        Returns:
            A pair ``(a, b)`` of the same ``QarrayImpl`` subclass, suitable
            for a binary operation.
        """
        if type(self) is type(other):
            return self, other
        if self.PROMOTION_ORDER >= other.PROMOTION_ORDER:
            return self, other._promote_to(type(self))
        return self._promote_to(type(other)), other
@struct.dataclass
class DenseImpl(QarrayImpl):
    """Dense implementation using JAX dense arrays.

    Attributes:
        _data: The underlying ``jnp.ndarray``.
    """

    _data: Array

    PROMOTION_ORDER = 2  # noqa: RUF012 — not a struct field; no annotation intentional

    @classmethod
    def from_data(cls, data) -> "DenseImpl":
        """Wrap *data* in a new ``DenseImpl``.

        Args:
            data: Array-like input data.

        Returns:
            A ``DenseImpl`` wrapping ``robust_asarray(data)``.
        """
        return cls(_data=robust_asarray(data))

    def get_data(self) -> Array:
        """Return the underlying dense array."""
        return self._data

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiply ``self @ other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the matrix product (or the coerced
            type's result if *other* has a higher ``PROMOTION_ORDER``).
        """
        a, b = self._coerce(other)
        if a is not self:
            # self was promoted to another backend; delegate to it.
            return a.matmul(b)
        return DenseImpl(self._data @ b._data)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition ``self + other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the sum.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.add(b)
        return DenseImpl(self._data + b._data)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction ``self - other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the difference.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.sub(b)
        return DenseImpl(self._data - b._data)

    def mul(self, scalar) -> QarrayImpl:
        """Scalar multiplication.

        Args:
            scalar: Scalar value.

        Returns:
            A ``DenseImpl`` with each element multiplied by *scalar*.
        """
        return DenseImpl(scalar * self._data)

    def dag(self) -> QarrayImpl:
        """Conjugate transpose.

        Delegates to :meth:`dag_data` so that 1-D data is only conjugated
        (``moveaxis(-1, -2)`` would fail on it) and data with two or more
        axes has its last two axes swapped.

        Returns:
            A ``DenseImpl`` containing the conjugate transpose.
        """
        return DenseImpl(self.dag_data(self._data))

    def to_dense(self) -> "DenseImpl":
        """Return self (already dense).

        Returns:
            This ``DenseImpl`` instance unchanged.
        """
        return self

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` via ``BCOO.fromdense``.

        Returns:
            A ``SparseBCOOImpl`` wrapping a BCOO conversion of this array.
        """
        # Lazy import: sparse_bcoo imports this module.
        from jaxquantum.core.sparse_bcoo import SparseBCOOImpl
        return SparseBCOOImpl(sparse.BCOO.fromdense(self._data))

    def shape(self) -> tuple:
        """Shape of the underlying dense array.

        Returns:
            Tuple of dimension sizes.
        """
        return self._data.shape

    def dtype(self):
        """Data type of the underlying dense array.

        Returns:
            The dtype of ``_data``.
        """
        return self._data.dtype

    def frobenius_norm(self) -> float:
        """Compute the Frobenius norm.

        The sum runs over every element of ``_data``, including any batch
        axes.

        Returns:
            The Frobenius norm as a scalar.
        """
        return jnp.sqrt(jnp.sum(jnp.abs(self._data) ** 2))

    def real(self) -> QarrayImpl:
        """Element-wise real part.

        Returns:
            A ``DenseImpl`` containing the real parts.
        """
        return DenseImpl(jnp.real(self._data))

    def imag(self) -> QarrayImpl:
        """Element-wise imaginary part.

        Returns:
            A ``DenseImpl`` containing the imaginary parts.
        """
        return DenseImpl(jnp.imag(self._data))

    def conj(self) -> QarrayImpl:
        """Element-wise complex conjugate.

        Returns:
            A ``DenseImpl`` containing the complex-conjugated values.
        """
        return DenseImpl(jnp.conj(self._data))

    def __deepcopy__(self, memo=None):
        """Return a ``DenseImpl`` wrapping a deep copy of ``_data``."""
        return DenseImpl(
            _data=deepcopy(self._data, memo)
        )

    def tidy_up(self, atol):
        """Zero out real/imaginary parts whose magnitude is below *atol*.

        The real and imaginary parts are masked independently. The result is
        always complex-typed, because the parts are recombined as
        ``re + 1j * im``.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``DenseImpl`` with small values zeroed.
        """
        data = self._data
        data_re = jnp.real(data)
        data_im = jnp.imag(data)
        data_re_mask = jnp.abs(data_re) > atol
        data_im_mask = jnp.abs(data_im) > atol
        data_new = data_re * data_re_mask + 1j * data_im * data_im_mask

        return DenseImpl(
            _data=data_new
        )

    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker product using ``jnp.kron``.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` containing the Kronecker product.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.kron(b)
        return DenseImpl(jnp.kron(self._data, b._data))

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create an ``n x n`` identity matrix as a dense JAX array.

        Args:
            n: Matrix size.
            dtype: Optional data type.

        Returns:
            A ``jnp.ndarray`` identity matrix of shape ``(n, n)``.
        """
        return jnp.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True for any non-BCOO, non-SparseDIA array.

        ``SparseDiaData`` objects carry a ``_is_sparse_dia`` marker so we can
        exclude them without a direct type import (which would be circular).

        Args:
            arr: Raw array.

        Returns:
            True when *arr* is a plain dense array (not BCOO, not SparseDiaData).
        """
        return not isinstance(arr, sparse.BCOO) and not getattr(arr, "_is_sparse_dia", False)

    @classmethod
    def dag_data(cls, arr) -> Array:
        """Conjugate transpose for dense arrays.

        Swaps the last two axes via :func:`jnp.moveaxis` and conjugates all
        elements. For 1-D inputs only conjugation is applied.

        Args:
            arr: Dense array.

        Returns:
            Conjugate transpose with the last two axes swapped.
        """
        if len(arr.shape) == 1:
            return jnp.conj(arr)
        return jnp.moveaxis(jnp.conj(arr), -1, -2)
# Hook the dense backend into the impl-class -> enum-member registry.
# Sparse backends do the same in their own modules (SparseBCOOImpl in
# sparse_bcoo.py, after it is imported).
QarrayImplType.register(impl_class=DenseImpl, member=QarrayImplType.DENSE)
643@struct.dataclass
644class Qarray(Generic[ImplT]):
645 """Quantum array with a pluggable storage backend.
647 ``Qarray`` wraps a ``QarrayImpl`` together with quantum-mechanical
648 dimension metadata (``_qdims``) and optional batch dimensions
649 (``_bdims``). The default backend is dense (``DenseImpl``); pass
650 ``implementation="sparse_bcoo"`` (or ``QarrayImplType.SPARSE_BCOO``) to
651 store data as a JAX BCOO sparse array.
653 Attributes:
654 _impl: The storage backend holding the raw data.
655 _qdims: Quantum dimension metadata (bra/ket structure, Hilbert space
656 sizes).
657 _bdims: Tuple of batch dimension sizes (empty tuple = non-batched).
659 Example:
660 >>> import jaxquantum as jqt
661 >>> a = jqt.destroy(10, implementation="sparse_bcoo")
662 >>> a.is_sparse_bcoo
663 True
664 """
666 _impl: ImplT
667 _qdims: Qdims = struct.field(pytree_node=False)
668 _bdims: tuple[int] = struct.field(pytree_node=False)
# Initialization ----
@classmethod
@overload
def create(cls, data, dims=None, bdims=None, implementation: Literal[QarrayImplType.DENSE] = QarrayImplType.DENSE) -> "Qarray[DenseImpl]":
    ...

@classmethod
@overload
def create(cls, data, dims=None, bdims=None, implementation: Literal[QarrayImplType.SPARSE_BCOO] = ...) -> "Qarray[SparseBCOOImpl]":
    ...

@classmethod
@overload
def create(cls, data, dims=None, bdims=None, implementation=...) -> "Qarray[DenseImpl]":
    ...

@classmethod
def create(cls, data, dims=None, bdims=None, implementation=QarrayImplType.DENSE):
    """Build a ``Qarray`` from raw data.

    The data shape is normalised, missing dimension metadata is inferred,
    the data is wrapped in the requested backend, and entries below
    ``SETTINGS["auto_tidyup_atol"]`` are zeroed.

    Shape normalisation, applied in this order:

    1. a non-empty 1-D array of length ``n`` becomes an ``(n, 1)`` column;
    2. if the trailing two axes are neither square nor contain a size-1
       axis, a trailing axis of size 1 is appended (the last axis is then
       read as a column vector);
    3. if *bdims* is given and the data has exactly one axis more than
       ``len(bdims)``, a trailing axis of size 1 is appended.

    Args:
        data: Input data (dense array-like, ``sparse.BCOO``, or
            ``SparseDiaData``).
        dims: Either ``((row_dims...), (col_dims...))`` or a flat sequence
            of Hilbert-space sizes. A flat sequence becomes a ket layout if
            the last axis has size 1, a bra layout if the second-to-last has
            size 1, and an operator layout otherwise. Inferred from the data
            shape when ``None``.
        bdims: Batch dimension sizes. When ``None``, every axis before the
            last two is a batch axis.
        implementation: Storage backend — a ``QarrayImplType`` member
            (default ``DENSE``) or its string value (``"dense"``,
            ``"sparse_bcoo"``, ``"sparse_dia"``).

    Returns:
        A new ``Qarray`` backed by the requested implementation.
    """
    arr = robust_asarray(data)

    # Rule 1: bare vector -> column.
    if len(arr.shape) == 1 and arr.shape[0] > 0:
        arr = arr.reshape(arr.shape[0], 1)

    # Rule 2: non-square, non-vector trailing block -> append a unit axis.
    if len(arr.shape) >= 2:
        n_rows, n_cols = arr.shape[-2], arr.shape[-1]
        has_unit_axis = n_rows == 1 or n_cols == 1
        if n_rows != n_cols and not has_unit_axis:
            arr = arr.reshape(*arr.shape, 1)

    # Rule 3: one axis beyond the declared batch axes -> append a unit axis.
    if bdims is not None and len(arr.shape) - len(bdims) == 1:
        arr = arr.reshape(*arr.shape, 1)

    if bdims is None:
        bdims = tuple(arr.shape[:-2])

    if dims is None:
        dims = ((arr.shape[-2],), (arr.shape[-1],))

    if isinstance(dims[0], (list, tuple)):
        dims = (tuple(dims[0]), tuple(dims[1]))
    else:
        # Only Hilbert-space sizes were given: choose a ket/bra/oper layout.
        space = tuple(dims)
        unit = tuple(1 for _ in dims)
        if arr.shape[-1] == 1:
            dims = (space, unit)
        elif arr.shape[-2] == 1:
            dims = (unit, space)
        else:
            dims = (space, space)

    check_dims(dims, bdims, arr.shape)
    qdims = Qdims(dims)

    # NOTE: tidying on every creation costs a little compile time but barely
    # affects the runtime of jit-compiled functions; it could instead be
    # applied only where needed.
    backend_cls = QarrayImplType(implementation).get_impl_class()
    impl = backend_cls.from_data(arr).tidy_up(SETTINGS["auto_tidyup_atol"])

    return cls(impl, qdims, bdims)
# A lone ``@overload`` is rejected by type checkers ("Single overload
# definition, multiple required"), so the return type lives on the
# implementation itself.
@classmethod
def from_sparse_bcoo(cls, data, dims=None, bdims=None) -> "Qarray[SparseBCOOImpl]":
    """Create a ``Qarray`` stored as a sparse BCOO array.

    Args:
        data: A ``sparse.BCOO`` or array-like to store as sparse BCOO.
        dims: Quantum dimensions. Inferred when ``None``.
        bdims: Batch dimensions. Inferred when ``None``.

    Returns:
        A ``Qarray[SparseBCOOImpl]`` built via :meth:`create` with
        ``implementation=QarrayImplType.SPARSE_BCOO``.
    """
    return cls.create(data, dims=dims, bdims=bdims, implementation=QarrayImplType.SPARSE_BCOO)
@classmethod
def from_sparse_dia(cls, data, dims=None, bdims=None) -> "Qarray":
    """Build a ``Qarray`` whose storage backend is ``SparseDiaImpl``.

    *data* may be a dense array-like (diagonals are detected automatically)
    or a ready-made :class:`~jaxquantum.core.sparse_dia.SparseDiaData`.

    Args:
        data: Dense array of shape (*batch, n, n) or a ``SparseDiaData``.
        dims: Quantum dimensions ``((row_dims,), (col_dims,))``.
        bdims: Batch dimension sizes.

    Returns:
        A SparseDIA-backed ``Qarray``.
    """
    backend = QarrayImplType.SPARSE_DIA
    return cls.create(data, dims=dims, bdims=bdims, implementation=backend)
@classmethod
@overload
def from_list(cls, qarr_list: List["Qarray[DenseImpl]"]) -> "Qarray[DenseImpl]":
    ...

@classmethod
@overload
def from_list(cls, qarr_list: List["Qarray[SparseBCOOImpl]"]) -> "Qarray[SparseBCOOImpl]":
    ...

@classmethod
def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
    """Create a batched ``Qarray`` from a list of same-shaped ``Qarray`` objects.

    The output backend is the input backend with the highest
    ``PROMOTION_ORDER`` (``DenseImpl`` > ``SparseBCOOImpl`` >
    ``SparseDiaImpl``):

    * all inputs SparseDIA -> SparseDIA result, batched without densifying;
    * otherwise each input is densified through its own backend, the dense
      arrays are stacked, and the stack is wrapped in the winning backend.
      A BCOO/SparseDIA mix therefore yields BCOO, and any dense input
      yields a dense result.

    Args:
        qarr_list: List of ``Qarray`` objects with identical ``dims`` and
            ``bdims``. May be empty.

    Returns:
        A ``Qarray`` with an extra leading batch dimension of size
        ``len(qarr_list)``.

    Raises:
        ValueError: If the elements have mismatched ``dims`` or ``bdims``.
    """
    if len(qarr_list) == 0:
        dims = ((), ())
        bdims = (0,)
        return cls.create(jnp.array([]), dims=dims, bdims=bdims)

    dims = qarr_list[0].dims
    bdims = qarr_list[0].bdims

    if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
        raise ValueError("All Qarrays in the list must have the same dimensions.")

    new_bdims = (len(qarr_list),) + bdims

    # Pick the target type: highest PROMOTION_ORDER wins.
    target_impl_type = max(
        (q.impl_type for q in qarr_list),
        key=lambda t: t.get_impl_class().PROMOTION_ORDER,
    )

    if target_impl_type == QarrayImplType.SPARSE_DIA:
        # All inputs are SparseDIA (it has the lowest PROMOTION_ORDER), so
        # batch without densifying: compute the union of offsets, remap each
        # operator's diagonal rows into that union, then stack.
        # NOTE(review): ``_diags[i_src, :]`` assumes each input is unbatched —
        # confirm the behaviour for batched SparseDIA inputs.
        from jaxquantum.core.sparse_dia import SparseDiaData  # lazy to avoid circular
        union_offsets = tuple(sorted(
            set().union(*[set(q._impl._offsets) for q in qarr_list])
        ))
        union_idx = {k: i for i, k in enumerate(union_offsets)}
        n = qarr_list[0]._impl._diags.shape[-1]
        dtype = jnp.result_type(*[q._impl._diags.dtype for q in qarr_list])
        remapped = []
        for q in qarr_list:
            row = jnp.zeros((len(union_offsets), n), dtype=dtype)
            for i_src, k in enumerate(q._impl._offsets):
                row = row.at[union_idx[k], :].set(q._impl._diags[i_src, :])
            remapped.append(row)
        stacked = jnp.stack(remapped, axis=0)  # (n_ops, n_union_diags, N)
        raw = SparseDiaData(offsets=union_offsets, diags=stacked)
        return cls.create(raw, dims=dims, bdims=new_bdims, implementation=QarrayImplType.SPARSE_DIA)

    # BCOO or dense target. Densify through each input's own backend (valid
    # for dense, BCOO and SparseDIA alike — raw SparseDiaData need not have
    # ``todense``), stack, then let ``create`` wrap the stack in the target
    # backend (re-sparsifying for BCOO).
    data = jnp.array([q.to_dense().data for q in qarr_list])
    return cls.create(data, dims=dims, bdims=new_bdims, implementation=target_impl_type)
@classmethod
@overload
def from_array(cls, qarr_arr: "Qarray[DenseImpl]") -> "Qarray[DenseImpl]":
    ...

@classmethod
@overload
def from_array(cls, qarr_arr: "Qarray[SparseBCOOImpl]") -> "Qarray[SparseBCOOImpl]":
    ...

@classmethod
def from_array(cls, qarr_arr) -> Qarray:
    """Create a ``Qarray`` from a (possibly nested) list of ``Qarray`` objects.

    The batch shape is taken from the nesting structure by following the
    first element at each level. Nested lists and tuples are both accepted,
    both for shape inference and for flattening.

    Args:
        qarr_arr: A ``Qarray`` (returned as-is) or a nested list/tuple of
            ``Qarray`` objects.

    Returns:
        A ``Qarray`` with batch dimensions matching the nesting structure
        of *qarr_arr*.
    """
    if isinstance(qarr_arr, Qarray):
        return qarr_arr

    # Infer the batch shape from the length of each nesting level.
    bdims = ()
    lvl = qarr_arr
    while not isinstance(lvl, Qarray):
        bdims = bdims + (len(lvl),)
        if len(lvl) == 0:
            break
        lvl = lvl[0]

    def _flatten(nested):
        """Depth-first flatten of nested lists/tuples into a flat list."""
        flat = []
        for element in nested:
            # Flatten tuples too: the shape walk above descends into them.
            if isinstance(element, (list, tuple)):
                flat.extend(_flatten(element))
            else:
                flat.append(element)
        return flat

    qarr = cls.from_list(_flatten(qarr_arr))
    return qarr.reshape_bdims(*bdims)
# Properties ----
@property
def qtype(self):
    """Quantum type (ket, bra, or operator) read from the dims metadata."""
    meta = self._qdims
    return meta.qtype

@property
def dtype(self):
    """Element dtype reported by the storage backend."""
    backend = self._impl
    return backend.dtype()

@property
def dims(self):
    """Quantum dimensions in the ``((row_dims...), (col_dims...))`` form."""
    meta = self._qdims
    return meta.dims

@property
def bdims(self):
    """Batch dimension sizes; an empty tuple means the array is not batched."""
    batch_shape = self._bdims
    return batch_shape

@property
def qdims(self):
    """The raw ``Qdims`` metadata object."""
    meta = self._qdims
    return meta
@property
def space_dims(self):
    """Hilbert-space sizes of the non-trivial side.

    Row dims for operators and kets, column dims for bras.

    Raises:
        ValueError: For any other qtype.
    """
    kind = self.qtype
    if kind in [Qtypes.oper, Qtypes.ket]:
        return self.dims[0]
    if kind == Qtypes.bra:
        return self.dims[1]
    # TODO: not reached for some reason
    raise ValueError("Unsupported qtype.")
@property
def data(self):
    """Raw backend data (dense ``jnp.ndarray``, ``sparse.BCOO``, ...)."""
    backend = self._impl
    return backend.data

@property
def shaped_data(self):
    """Raw data reshaped to ``bdims + dims[0] + dims[1]``."""
    target_shape = self.bdims + self.dims[0] + self.dims[1]
    return self.data.reshape(target_shape)

@property
def shape(self):
    """Shape of the raw backend data."""
    raw = self.data
    return raw.shape
@property
def is_batched(self):
    """Whether at least one batch dimension is present."""
    return len(self.bdims) != 0

@property
def is_sparse_bcoo(self):
    """Whether the storage backend is ``SparseBCOOImpl`` (BCOO)."""
    return self.impl_type == QarrayImplType.SPARSE_BCOO

@property
def is_dense(self):
    """Whether the storage backend is ``DenseImpl``."""
    return self.impl_type == QarrayImplType.DENSE

@property
def is_sparse_dia(self):
    """Whether the storage backend is ``SparseDiaImpl``."""
    return self.impl_type == QarrayImplType.SPARSE_DIA

@property
def impl_type(self):
    """``QarrayImplType`` member of the current storage backend."""
    backend = self._impl
    return backend.impl_type
def to_sparse_bcoo(self) -> "Qarray[SparseBCOOImpl]":
    """Return this array stored as BCOO sparse data.

    ``self`` is returned untouched when it is already BCOO-backed; otherwise
    a new ``Qarray`` with the same dims/bdims wraps the converted backend.

    Returns:
        A ``Qarray[SparseBCOOImpl]``.
    """
    if not self.is_sparse_bcoo:
        converted = self._impl.to_sparse_bcoo()
        return Qarray(converted, self._qdims, self._bdims)
    return self

def to_sparse_dia(self) -> "Qarray":
    """Return this array stored as SparseDIA data.

    ``self`` is returned untouched when it is already SparseDIA-backed.

    Returns:
        A ``Qarray[SparseDiaImpl]``.
    """
    if not self.is_sparse_dia:
        converted = self._impl.to_sparse_dia()
        return Qarray(converted, self._qdims, self._bdims)
    return self

def to_dense(self) -> "Qarray[DenseImpl]":
    """Return this array stored as dense data.

    ``self`` is returned untouched when it is already dense-backed.

    Returns:
        A ``Qarray[DenseImpl]``.
    """
    if not self.is_dense:
        converted = self._impl.to_dense()
        return Qarray(converted, self._qdims, self._bdims)
    return self
def __getitem__(self, index):
    """Index along the batch dimensions, keeping dims and backend.

    Raises:
        ValueError: If the array has no batch dimensions.
    """
    if len(self.bdims) == 0:
        raise ValueError("Cannot index a non-batched Qarray.")
    selected = self.data[index]
    return Qarray.create(
        selected,
        dims=self.dims,
        implementation=self.impl_type,
    )
def reshape_bdims(self, *args):
    """Give this ``Qarray`` a new batch shape.

    Args:
        *args: New batch dimension sizes.

    Returns:
        A new ``Qarray`` with the requested batch shape, the same dims and
        the same storage backend.
    """
    new_bdims = tuple(args)

    if prod(new_bdims) == 0:
        # Empty batch: the data itself is empty, so reshape to the batch only.
        target_shape = new_bdims
    else:
        target_shape = (*new_bdims, prod(self.dims[0]), -1)

    return Qarray.create(
        self.data.reshape(target_shape),
        dims=self.dims,
        bdims=new_bdims,
        implementation=self.impl_type,
    )
def space_to_qdims(self, space_dims: List[int]):
    """Convert Hilbert space dimensions to a full quantum dims tuple.

    The layout follows ``self.qtype``:

    * operator -> ``(space, space)`` (square, as in ``Qarray.create``);
    * ket      -> ``(space, (1, ...))``;
    * bra      -> ``((1, ...), space)``.

    Args:
        space_dims: Sequence of per-subsystem Hilbert space sizes, or a
            full ``((row_dims), (col_dims))`` tuple (returned unchanged).

    Returns:
        A ``((row_dims...), (col_dims...))`` tuple.

    Raises:
        ValueError: If ``self.qtype`` is not ket, bra, or oper.
    """
    if isinstance(space_dims[0], (list, tuple)):
        return space_dims

    space = tuple(space_dims)
    ones = tuple(1 for _ in space)

    # Operators are square: previously they received ket-style column dims,
    # which made ``reshape_qdims`` produce dims inconsistent with the data.
    if self.qtype == Qtypes.oper:
        return (space, space)
    if self.qtype == Qtypes.ket:
        return (space, ones)
    if self.qtype == Qtypes.bra:
        return (ones, space)
    raise ValueError("Unsupported qtype for space_to_qdims conversion.")
def reshape_qdims(self, *args):
    """Re-split the quantum dimensions into new Hilbert-space factors.

    The arguments are the new per-subsystem Hilbert-space sizes (not a full
    qdims tuple); their product must equal that of the current space dims.

    Args:
        *args: New Hilbert dimensions for the Qarray.

    Returns:
        Qarray: reshaped Qarray with the same data, bdims and backend.
    """
    target_space = tuple(args)
    assert prod(target_space) == prod(self.space_dims)

    return Qarray.create(
        self.data,
        dims=self.space_to_qdims(target_space),
        bdims=self.bdims,
        implementation=self.impl_type,
    )
def resize(self, new_shape):
    """Resize the Qarray to a new shape.

    TODO: review and maybe deprecate this method.

    ``jnp.resize`` only works on dense arrays, so the data is densified
    first. ``Qarray.create`` then converts the result back to this array's
    storage backend. For dense arrays the behaviour is unchanged.

    Args:
        new_shape: Target shape tuple.

    Returns:
        A new ``Qarray`` with data resized via ``jnp.resize``, the same dims
        and the same implementation type. Batch dims are re-inferred from
        the new shape.
    """
    dense_data = self.to_dense().data
    data = jnp.resize(dense_data, new_shape)
    return Qarray.create(
        data,
        dims=self.dims,
        implementation=self.impl_type,
    )
def __len__(self):
    """Return the size of the leading batch axis.

    Raises:
        ValueError: If the array has no batch dimensions.
    """
    if len(self.bdims) == 0:
        raise ValueError("Cannot get length of a non-batched Qarray.")
    return self.data.shape[0]
def __eq__(self, other):
    """Structural equality: same dims, same bdims and equal entries.

    Raises:
        ValueError: If *other* is not a ``Qarray``.

    NOTE(review): when both sides are BCOO with identical index patterns the
    stored values are compared with ``jnp.allclose`` (approximate); every
    other path compares with exact ``==``.
    """
    if not isinstance(other, Qarray):
        raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

    if self.dims != other.dims or self.bdims != other.bdims:
        return False

    if self.is_sparse_bcoo and other.is_sparse_bcoo:
        lhs, rhs = self.data, other.data
        same_pattern = lhs.indices.shape == rhs.indices.shape and bool(
            jnp.all(lhs.indices == rhs.indices)
        )
        if same_pattern:
            # Identical sparsity structure: only the stored values can differ.
            return bool(jnp.allclose(lhs.data, rhs.data))
        return bool(jnp.all(lhs.todense() == rhs.todense()))

    def _as_dense(arr):
        # Sparse containers expose todense(); dense arrays are used as-is.
        return arr.todense() if hasattr(arr, "todense") else arr

    return bool(jnp.all(_as_dense(self.data) == _as_dense(other.data)))
def __ne__(self, other):
    """Logical negation of :meth:`__eq__` (same exceptions propagate)."""
    equal = self.__eq__(other)
    return not equal
1178 # Elementary Math ----
def __matmul__(self, other):
    """Matrix product ``self @ other`` for two ``Qarray`` objects.

    The dims are combined first (``Qdims.__matmul__``), then the backends
    multiply the data. Non-``Qarray`` operands yield ``NotImplemented``.
    """
    if isinstance(other, Qarray):
        product_dims = self._qdims @ other._qdims
        product_impl = self._impl.matmul(other._impl)
        return Qarray.create(
            product_impl.data,
            dims=product_dims.dims,
            implementation=product_impl.impl_type,
        )
    return NotImplemented
def __mul__(self, other):
    """``self * other``.

    A ``Qarray`` operand means matrix multiplication. Anything else is
    promoted to complex and used as an element-wise factor; a non-scalar
    factor gets two trailing singleton axes so it broadcasts per matrix.
    """
    if isinstance(other, Qarray):
        return self.__matmul__(other)

    factor = other + 0.0j
    if not robust_isscalar(factor) and len(factor.shape) > 0:
        factor = factor.reshape(factor.shape + (1, 1))

    scaled = self._impl.mul(factor)
    return Qarray.create(
        scaled.data,
        dims=self._qdims.dims,
        implementation=scaled.impl_type,
    )
def __rmul__(self, other):
    """Right-hand ``other * self``; delegated to :meth:`__mul__`."""
    product = self.__mul__(other)
    return product
def __neg__(self):
    """Unary minus, implemented as multiplication by ``-1``."""
    minus_one = -1
    return self.__mul__(minus_one)
def __truediv__(self, other):
    """Divide every element by a scalar (multiplies by ``1 / other``).

    Args:
        other: Scalar divisor.

    Returns:
        A new ``Qarray``.

    Raises:
        ValueError: If *other* is a ``Qarray``.
    """
    if not isinstance(other, Qarray):
        return self.__mul__(1 / other)
    raise ValueError("Cannot divide a Qarray by another Qarray.")
def __add__(self, other):
    """Sum ``self + other``.

    * ``Qarray`` operand: dims must match exactly; backends' ``add`` is used.
    * scalar ``0``: returns a copy of ``self``.
    * other scalar/array with a square ``self``: added as ``other * I``
      (a non-scalar *other* broadcasts over the batch axes).
    * anything else: ``NotImplemented``.

    Raises:
        ValueError: If two ``Qarray`` operands have different dims.
    """
    if not isinstance(other, Qarray):
        if robust_isscalar(other) and other == 0:
            return self.copy()

        n_rows, n_cols = self.data.shape[-2], self.data.shape[-1]
        if n_rows != n_cols:
            return NotImplemented

        coeff = other + 0.0j
        if not robust_isscalar(coeff) and len(coeff.shape) > 0:
            # Trailing (1, 1) lets a batched coefficient scale each identity.
            coeff = coeff.reshape(coeff.shape + (1, 1))
        identity = self._impl._eye_data(n_rows, dtype=self.data.dtype)
        shift = Qarray.create(
            coeff * identity, dims=self.dims, implementation=self.impl_type
        )
        return self.__add__(shift)

    if self.dims != other.dims:
        raise ValueError(
            f"Dimensions are incompatible: {self.dims!r} and {other.dims!r}"
        )
    summed = self._impl.add(other._impl)
    return Qarray.create(
        summed.data, dims=self.dims, implementation=summed.impl_type
    )
def __radd__(self, other):
    """Right-hand ``other + self``; addition commutes, so use :meth:`__add__`."""
    total = self.__add__(other)
    return total
def __sub__(self, other):
    """Difference ``self - other``.

    * ``Qarray`` operand: dims must match exactly; backends' ``sub`` is used.
    * scalar ``0``: returns a copy of ``self``.
    * other scalar/array with a square ``self``: subtracts ``other * I``
      (a non-scalar *other* broadcasts over the batch axes).
    * anything else: ``NotImplemented``.

    Raises:
        ValueError: If two ``Qarray`` operands have different dims.
    """
    if not isinstance(other, Qarray):
        if robust_isscalar(other) and other == 0:
            return self.copy()

        n_rows, n_cols = self.data.shape[-2], self.data.shape[-1]
        if n_rows != n_cols:
            return NotImplemented

        coeff = other + 0.0j
        if not robust_isscalar(coeff) and len(coeff.shape) > 0:
            # Trailing (1, 1) lets a batched coefficient scale each identity.
            coeff = coeff.reshape(coeff.shape + (1, 1))
        identity = self._impl._eye_data(n_rows, dtype=self.data.dtype)
        shift = Qarray.create(
            coeff * identity, dims=self.dims, implementation=self.impl_type
        )
        return self.__sub__(shift)

    if self.dims != other.dims:
        raise ValueError(
            f"Dimensions are incompatible: {self.dims!r} and {other.dims!r}"
        )
    difference = self._impl.sub(other._impl)
    return Qarray.create(
        difference.data, dims=self.dims, implementation=difference.impl_type
    )
def __rsub__(self, other):
    """Right-hand ``other - self``, computed as ``(-self) + other``."""
    negated = self.__neg__()
    return negated.__add__(other)
def __xor__(self, other):
    """``self ^ other`` is the tensor product ``self ⊗ other``."""
    return tensor(self, other) if isinstance(other, Qarray) else NotImplemented
def __rxor__(self, other):
    """Right-hand ``other ^ self``: the tensor product ``other ⊗ self``."""
    return tensor(other, self) if isinstance(other, Qarray) else NotImplemented
def __pow__(self, other):
    """``self ** n`` for integer *n* (matrix power); otherwise ``NotImplemented``."""
    return powm(self, other) if isinstance(other, int) else NotImplemented
1321 # String Representation ----
def _str_header(self):
    """One-line summary (dims, bdims, shape, qtype, backend) for printing."""
    backend = self.impl_type.value
    fields = (
        f"Quantum array: dims = {self.dims!s}",
        f"bdims = {self.bdims!s}",
        f"shape = {self.data.shape!s}",
        f"type = {self.qtype!s}",
        "impl = " + backend,
    )
    return ", ".join(fields)
def __str__(self):
    """Header line followed by the printed data."""
    return f"{self._str_header()}\nQarray data =\n{self.data!s}"
@property
def header(self):
    """Summary line with dims, bdims, shape, qtype and backend name."""
    summary = self._str_header()
    return summary
def __repr__(self):
    """Same text as ``str(self)``."""
    return str(self)
1347 # Utilities ----
def copy(self, memo=None):
    """Deep-copy this ``Qarray`` via :meth:`__deepcopy__`.

    Args:
        memo: Optional ``deepcopy`` memo dict, passed straight through.

    Returns:
        An independent ``Qarray``.
    """
    duplicate = self.__deepcopy__(memo)
    return duplicate
def __deepcopy__(self, memo):
    """Rebuild from deep copies of the three fields.

    Overridden explicitly because ``__getattr__`` is customised on this
    class and would otherwise intercept the ``copy`` protocol lookups.
    """
    fields = {
        "_impl": self._impl,
        "_qdims": self._qdims,
        "_bdims": self._bdims,
    }
    return Qarray(
        **{name: deepcopy(value, memo=memo) for name, value in fields.items()}
    )
def __getattr__(self, method_name):
    """Fallback lookup into ``jnp``, ``jnp.linalg``, ``jsp``, ``jsp.linalg``.

    Dunder names resolve to a callable returning ``NotImplemented`` so that
    Python's binary-operator protocol (and probes like ``__jax_array__``)
    fall through cleanly.

    The returned wrapper applies the found function to the data. BCOO data
    is densified first. A result with the same shape as ``self.data`` is
    re-wrapped as a ``Qarray`` on the original backend; anything else is
    returned raw.

    Raises:
        NotImplementedError: If none of the modules defines *method_name*.
    """
    if method_name.startswith("__"):
        return lambda *args, **kwargs: NotImplemented

    modules = [jnp, jnp.linalg, jsp, jsp.linalg]
    candidates = (getattr(mod, method_name, None) for mod in modules)
    method_f = next((fn for fn in candidates if fn is not None), None)

    if method_f is None:
        raise NotImplementedError(
            f"Method {method_name} does not exist. No backup method found in {modules}."
        )

    def func(*args, **kwargs):
        operand = self.to_dense().data if self.is_sparse_bcoo else self.data
        res = method_f(operand, *args, **kwargs)

        if getattr(res, "shape", None) is None or res.shape != self.data.shape:
            return res
        return Qarray.create(res, dims=self._qdims.dims, implementation=self.impl_type)

    return func
1402 # Conversions / Reshaping ----
def dag(self):
    """Hermitian adjoint (conjugate transpose) via the module-level ``dag``."""
    adjoint = dag(self)
    return adjoint
def to_dm(self):
    """Density matrix ``|ψ⟩⟨ψ|`` built with the module-level ``ket2dm``."""
    density = ket2dm(self)
    return density
def is_dm(self):
    """True when ``qtype`` is ``Qtypes.oper``."""
    is_operator = self.qtype == Qtypes.oper
    return is_operator
def is_vec(self):
    """True when ``qtype`` is a ket or a bra."""
    kind = self.qtype
    return kind == Qtypes.ket or kind == Qtypes.bra
def to_ket(self):
    """Ket form of this array (bras are adjointed; kets returned as-is)."""
    ket = to_ket(self)
    return ket
def transpose(self, *args):
    """Reorder subsystems; arguments go to the module-level ``transpose``."""
    reordered = transpose(self, *args)
    return reordered
def keep_only_diag_elements(self):
    """Copy with every off-diagonal entry set to zero."""
    diagonal_only = keep_only_diag_elements(self)
    return diagonal_only
1431 # Math Functions ----
def unit(self):
    """This array divided by its :func:`norm`."""
    normalised = unit(self)
    return normalised
def norm(self):
    """Norm of this array as computed by the module-level ``norm``."""
    magnitude = norm(self)
    return magnitude
def frobenius_norm(self):
    """Frobenius norm, delegated to the storage backend without densifying.

    Returns:
        Whatever ``self._impl.frobenius_norm()`` returns (a scalar).
    """
    backend = self._impl
    return backend.frobenius_norm()
def real(self):
    """Element-wise real part, computed by the backend.

    Returns:
        A new ``Qarray`` with the same dims on the backend the impl returns.
    """
    real_impl = self._impl.real()
    return Qarray.create(real_impl.data, dims=self.dims, implementation=real_impl.impl_type)
def imag(self):
    """Element-wise imaginary part, computed by the backend.

    Returns:
        A new ``Qarray`` with the same dims on the backend the impl returns.
    """
    imag_impl = self._impl.imag()
    return Qarray.create(imag_impl.data, dims=self.dims, implementation=imag_impl.impl_type)
def conj(self):
    """Element-wise complex conjugate, computed by the backend.

    Returns:
        A new ``Qarray`` with the same dims on the backend the impl returns.
    """
    conj_impl = self._impl.conj()
    return Qarray.create(conj_impl.data, dims=self.dims, implementation=conj_impl.impl_type)
def expm(self):
    """Matrix exponential via the module-level ``expm`` (dense result)."""
    exponential = expm(self)
    return exponential
def powm(self, n):
    """Matrix power ``self ** n``.

    Args:
        n: Integer or float exponent, forwarded to the module-level ``powm``.

    Returns:
        The *n*-th matrix power as a ``Qarray``.
    """
    powered = powm(self, n)
    return powered
def cosm(self):
    """Matrix cosine via the module-level ``cosm``."""
    cosine = cosm(self)
    return cosine
def sinm(self):
    """Matrix sine via the module-level ``sinm``."""
    sine = sinm(self)
    return sine
def tr(self, **kwargs):
    """Full trace; keyword arguments go to the module-level ``tr``."""
    total = tr(self, **kwargs)
    return total
def trace(self, **kwargs):
    """Alias of :meth:`tr` with the same keyword arguments."""
    total = tr(self, **kwargs)
    return total
def ptrace(self, indx):
    """Reduced density matrix of subsystem *indx*.

    Every subsystem except *indx* is traced out (the QuTiP ``ptrace(sel)``
    convention), so the result is a ``d_indx x d_indx`` operator.

    Args:
        indx: Index of the subsystem to KEEP.

    Returns:
        Reduced density matrix as a ``Qarray``.
    """
    reduced = ptrace(self, indx)
    return reduced
def eigenstates(self):
    """``(eigenvalues, eigenstates)`` of this operator, eigenvalues ascending."""
    evals_and_evecs = eigenstates(self)
    return evals_and_evecs
def eigenenergies(self):
    """Eigenvalues of this operator (``jnp.linalg.eigvalsh`` on dense data)."""
    energies = eigenenergies(self)
    return energies
def eigenvalues(self):
    """Alias of :meth:`eigenenergies`."""
    values = eigenenergies(self)
    return values
def collapse(self, mode="sum"):
    """Reduce away all batch dimensions.

    Args:
        mode: Reduction strategy, forwarded to the module-level
            ``collapse`` (only ``"sum"`` is supported there).

    Returns:
        A ``Qarray`` without batch dimensions.
    """
    collapsed = collapse(self, mode=mode)
    return collapsed
1554# Qarray operations ---------------------------------------------------------------------
def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join Qarrays along *axis*, skipping entries whose data has length 0.

    Args:
        qarr_list: Qarrays to join.
        axis: Axis passed to ``jnp.concatenate``. Defaults to 0.

    Returns:
        The joined Qarray, labelled with the dims of the first non-empty
        input; ``Qarray.from_list([])`` if every input is empty.

    NOTE(review): no ``implementation`` is passed to ``Qarray.create``, so
    the input backend is not forwarded explicitly.
    """
    kept = [qarr for qarr in qarr_list if len(qarr.data) != 0]
    if not kept:
        return Qarray.from_list([])

    joined = jnp.concatenate([qarr.data for qarr in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)
def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse the batch dimensions of *qarr*.

    Args:
        qarr: Quantum array with optional batch dimensions.
        mode: Collapse strategy. Only ``"sum"`` is currently supported.

    Returns:
        A non-batched ``Qarray`` obtained by summing over all batch axes,
        on the same backend as *qarr*. Non-batched input is returned as-is.

    Raises:
        ValueError: If *mode* is not supported. Previously this silently
            returned ``None``.
    """
    if mode != "sum":
        raise ValueError(
            f"Unsupported collapse mode {mode!r}; only 'sum' is supported."
        )

    if len(qarr.bdims) == 0:
        return qarr

    # Batch axes always lead the data array.
    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(
        jnp.sum(qarr.data, axis=batch_axes),
        dims=qarr.dims,
        implementation=qarr.impl_type,
    )
def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Transpose subsystem indices of the quantum array.

    The permutation is done on dense data. The result is then re-wrapped in
    the caller's original backend.

    Args:
        qarr: Input quantum array.
        indices: New ordering of subsystem indices.

    Returns:
        Transposed ``Qarray`` with permuted ``dims`` on the same backend as
        *qarr*.
    """
    # Capture the backend BEFORE densifying. The original read it after
    # rebinding qarr to its dense copy, so it always came out DENSE.
    implementation = qarr.impl_type
    qarr = qarr.to_dense()

    indices = list(indices)

    shaped_data = qarr.shaped_data
    dims = qarr.dims
    bdims_indxs = list(range(len(qarr.bdims)))

    # shaped_data axes: (*batch, *row_subsystems, *col_subsystems); permute
    # row and column subsystem axes identically, leaving batch axes alone.
    reshape_indices = indices + [j + len(dims[0]) for j in indices]
    reshape_indices = bdims_indxs + [j + len(bdims_indxs) for j in reshape_indices]

    shaped_data = shaped_data.transpose(reshape_indices)
    new_dims = (
        tuple(dims[0][j] for j in indices),
        tuple(dims[1][j] for j in indices),
    )

    full_dims = prod(dims[0])
    full_data = shaped_data.reshape(*qarr.bdims, full_dims, -1)

    return Qarray.create(full_data, dims=new_dims, implementation=implementation)
def unit(qarr: Qarray) -> Qarray:
    """Scale *qarr* so that its :func:`norm` becomes one.

    Args:
        qarr: Input quantum array.

    Returns:
        ``qarr / qarr.norm()``.
    """
    magnitude = qarr.norm()
    return qarr / magnitude
def norm(qarr: Qarray) -> float:
    """Norm of a quantum array.

    * ket/bra: Euclidean (L2) norm. BCOO data uses
      ``SparseBCOOImpl.l2_norm_batched`` without densifying.
    * operator: trace (nuclear) norm, the sum of singular values computed
      as ``sum(sqrt|eig(A A†)|)`` on dense data. On BCOO and SparseDIA
      backends it is shortcut to ``Re(tr A)``. That is only exact for
      positive-semidefinite operators such as density matrices; densify
      first for general operators.

    Args:
        qarr: Input quantum array.

    Returns:
        The norm as a scalar, or an array of shape ``bdims`` when batched.
    """
    is_vector = qarr.qtype in [Qtypes.ket, Qtypes.bra]
    is_operator = qarr.qtype == Qtypes.oper

    # Sparse fast paths (no densification).
    if qarr.is_sparse_bcoo:
        if is_vector:
            return qarr._impl.l2_norm_batched(qarr.bdims)
        if is_operator:
            # jnp.real strips floating-point imaginary residue of the trace.
            return jnp.real(qarr._impl.trace())
    elif is_operator and qarr.is_sparse_dia:
        return jnp.real(qarr._impl.trace())

    dense = qarr.to_dense()
    data = dense.data
    batch_shape = dense.bdims

    if dense.qtype == Qtypes.oper:
        data_dag = dense.dag().data
        if len(batch_shape) == 0:
            evals, _ = jnp.linalg.eigh(data @ data_dag)
            return jnp.sum(jnp.sqrt(jnp.abs(evals)))

        flat = data.reshape(-1, data.shape[-2], data.shape[-1])
        flat_dag = data_dag.reshape(-1, data_dag.shape[-2], data_dag.shape[-1])
        evals, _ = vmap(jnp.linalg.eigh)(flat @ flat_dag)
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1).reshape(*batch_shape)

    if dense.qtype in [Qtypes.ket, Qtypes.bra]:
        if len(batch_shape) == 0:
            return jnp.linalg.norm(data)
        flat = data.reshape(-1, data.shape[-2], data.shape[-1])
        return vmap(jnp.linalg.norm)(flat).reshape(*batch_shape)
def tensor(*args, **kwargs) -> Qarray:
    """Tensor (Kronecker) product of two or more ``Qarray`` objects.

    Args:
        *args: ``Qarray`` objects to tensor together (left to right).
        **kwargs: Optional keyword arguments. Pass ``parallel=True`` to use
            an einsum-based batched outer product instead of each backend's
            ``kron``.

    Returns:
        The tensor product as a ``Qarray``. With ``parallel=True`` the
        output backend is the input ``impl_type`` with the highest
        ``PROMOTION_ORDER``. Otherwise it is whatever the chained backend
        ``kron`` calls return.

    Raises:
        ValueError: If no ``Qarray`` is given.

    Note:
        ``parallel=True`` computes the outer product on dense data and then
        wraps it in the target backend. The leading batch axes of the two
        operands are broadcast against each other, and the result's batch
        shape is that broadcast shape.
    """
    parallel = kwargs.pop("parallel", False)

    if not args:
        raise ValueError("tensor() requires at least one Qarray.")

    if parallel:
        target_impl_type = max(
            (arg.impl_type for arg in args),
            key=lambda t: t.get_impl_class().PROMOTION_ORDER,
        )
        dense_args = [arg.to_dense() for arg in args]
        data = dense_args[0].data
        dims_0 = dense_args[0].dims[0]
        dims_1 = dense_args[0].dims[1]
        for arg in dense_args[1:]:
            a, b = data, arg.data
            # NOTE: implementation einsum should be used when available
            outer = jnp.einsum("...ij,...kl->...ikjl", a, b)
            # einsum broadcasts the batch axes of a and b, so take the batch
            # shape from its output. Picking one operand's batch shape breaks
            # partial broadcasts such as (2, 1) x (3,) -> (2, 3).
            batch_dim = outer.shape[:-4]
            data = outer.reshape(
                *batch_dim, a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]
            )
            dims_0 = dims_0 + arg.dims[0]
            dims_1 = dims_1 + arg.dims[1]
        return Qarray.create(data, dims=(dims_0, dims_1), implementation=target_impl_type)

    # Non-parallel: chain each backend's kron method.
    current_impl = args[0]._impl
    dims_0 = args[0].dims[0]
    dims_1 = args[0].dims[1]
    for arg in args[1:]:
        current_impl = current_impl.kron(arg._impl)
        dims_0 = dims_0 + arg.dims[0]
        dims_1 = dims_1 + arg.dims[1]
    return Qarray.create(
        current_impl.data,
        dims=(dims_0, dims_1),
        implementation=current_impl.impl_type,
    )
def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace of *qarr*.

    For sparse ``Qarray`` objects (BCOO or DIA) the backend's own ``trace``
    is used and all keyword arguments are ignored. The last two dimensions
    are always the matrix dimensions in jaxquantum's convention.

    Args:
        qarr: Input quantum array.
        **kwargs: Forwarded to ``jnp.trace`` for dense arrays (e.g.
            ``axis1``, ``axis2``, ``offset``). ``axis1``/``axis2`` default
            to ``-2``/``-1``.

    Returns:
        The trace as a scalar (or batched array of scalars).
    """
    if qarr.is_sparse_bcoo or qarr.is_sparse_dia:
        return qarr._impl.trace()
    # pop (not get): leaving axis1/axis2 in kwargs would pass them twice
    # to jnp.trace and raise TypeError.
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)
def trace(qarr: Qarray, **kwargs) -> Array:
    """Alias of :func:`tr`.

    Args:
        qarr: Input quantum array.
        **kwargs: Passed unchanged to :func:`tr`.

    Returns:
        The trace (scalar, or one value per batch element).
    """
    total = tr(qarr, **kwargs)
    return total
def expm_data(data: Array, **kwargs) -> Array:
    """``jax.scipy.linalg.expm`` applied to a raw dense array.

    Args:
        data: Dense (possibly batched) square matrix array.
        **kwargs: Extra options for ``jsp.linalg.expm``.

    Returns:
        ``exp(data)`` as a matrix exponential.
    """
    exponential = jsp.linalg.expm(data, **kwargs)
    return exponential
def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a ``Qarray`` (computed on dense data).

    Args:
        qarr: Input quantum array.
        **kwargs: Extra options for ``jsp.linalg.expm``.

    Returns:
        A dense ``Qarray`` with the same dims as *qarr*.
    """
    exponential = expm_data(qarr.to_dense().data, **kwargs)
    return Qarray.create(exponential, dims=qarr.dims)
def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power of a ``Qarray``.

    Args:
        qarr: Input quantum array.
        n: Exponent. Integer powers use ``jnp.linalg.matrix_power``; float
            powers diagonalise the matrix with ``jnp.linalg.eig`` and use
            ``V diag(λ**n) V^{-1}``.
        clip_eigvals: When ``True``, clip negative eigenvalues to zero before
            applying the float power (useful for nearly-PSD matrices).

    Returns:
        The *n*-th matrix power as a ``Qarray``. It stays SparseDIA for
        integer non-negative exponents when the input is SparseDIA, and is
        dense otherwise.

    Raises:
        ValueError: If *n* is a float and the matrix has a negative
            eigenvalue (and *clip_eigvals* is ``False``).

    NOTE(review): ``jnp.linalg.eig`` returns complex eigenvalues, so the
    ``>= 0`` check relies on JAX's ordering of complex numbers; tiny
    round-off components may trip it. Consider ``clip_eigvals=True``.
    """
    # SparseDIA fast path: binary exponentiation stays in SparseDIA format.
    if qarr.is_sparse_dia and isinstance(n, int) and n >= 0:
        new_impl = qarr._impl.powm(n)
        return Qarray.create(new_impl.data, dims=qarr.dims, implementation=new_impl.impl_type)

    dense_data = qarr.to_dense().data

    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(dense_data, n)
    else:
        evalues, evectors = jnp.linalg.eig(dense_data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        elif not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Scaling the columns of V by λ**n equals V @ diag(λ**n).
        data_res = (evectors * jnp.pow(evalues, n)) @ jnp.linalg.inv(evectors)

    return Qarray.create(data_res, dims=qarr.dims)
def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of a raw dense array.

    Uses the identity ``cos(A) = (exp(iA) + exp(-iA)) / 2``.

    Args:
        data: Dense square matrix array.
        **kwargs: Unused; accepted for API symmetry.

    Returns:
        The matrix cosine.
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward + backward) / 2
def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a ``Qarray`` (computed on dense data).

    Args:
        qarr: Input quantum array.

    Returns:
        A dense ``Qarray`` with the same dims as *qarr*.
    """
    cosine = cosm_data(qarr.to_dense().data)
    return Qarray.create(cosine, dims=qarr.dims)
def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of a raw dense array.

    Uses the identity ``sin(A) = (exp(iA) - exp(-iA)) / (2i)``.

    Args:
        data: Dense square matrix array.
        **kwargs: Unused; accepted for API symmetry.

    Returns:
        The matrix sine.
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward - backward) / (2j)
def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a ``Qarray`` (computed on dense data).

    Args:
        qarr: Input quantum array.

    Returns:
        A dense ``Qarray`` with the same dims as *qarr*.
    """
    sine = sinm_data(qarr.to_dense().data)
    return Qarray.create(sine, dims=qarr.dims)
def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Zero every off-diagonal element of a non-batched *qarr*.

    * BCOO: the backend's ``keep_only_diag`` is used (no densification).
    * SparseDIA: only the offset-0 band is kept. If it is absent, an
      all-zero main diagonal is created.
    * dense: ``jnp.diag(jnp.diag(data))``.

    Args:
        qarr: Non-batched input quantum array.

    Returns:
        A ``Qarray`` with the same dims and only diagonal entries non-zero.

    Raises:
        ValueError: If *qarr* has batch dimensions.
    """
    if len(qarr.bdims) > 0:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    dims = qarr.dims

    if qarr.is_sparse_bcoo:
        diag_impl = qarr._impl.keep_only_diag()
        return Qarray.create(diag_impl.data, dims=dims, implementation=QarrayImplType.SPARSE_BCOO)

    if qarr.is_sparse_dia:
        from jaxquantum.core.sparse_dia import SparseDiaImpl

        bands = qarr._impl._diags
        offsets = qarr._impl._offsets
        if 0 in offsets:
            k = offsets.index(0)
            main_diag = bands[..., k:k + 1, :]
        else:
            main_diag = jnp.zeros((*bands.shape[:-2], 1, bands.shape[-1]), dtype=bands.dtype)
        diag_impl = SparseDiaImpl(_offsets=(0,), _diags=main_diag)
        return Qarray.create(diag_impl.get_data(), dims=dims, implementation=QarrayImplType.SPARSE_DIA)

    return Qarray.create(jnp.diag(jnp.diag(qarr.data)), dims=dims)
def to_ket(qarr: Qarray) -> Qarray:
    """Return the ket form of *qarr*.

    Args:
        qarr: A ket (returned unchanged) or a bra (adjointed).

    Returns:
        A ket.

    Raises:
        ValueError: For operators (or any other qtype).
    """
    kind = qarr.qtype
    if kind == Qtypes.ket:
        return qarr
    if kind != Qtypes.bra:
        raise ValueError("Can only get ket from a ket or bra.")
    return qarr.dag()
def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigen-decomposition of a Hermitian quantum array.

    Args:
        qarr: Hermitian operator (converted to dense internally).

    Returns:
        A tuple ``(eigenvalues, eigenstates_qarray)``. The eigenvalues are
        sorted ascending along the last axis. ``eigenstates_qarray`` holds
        the matching kets, with the eigenvector index as the last batch
        dimension (``bdims = evecs.shape[:-1]``).
    """
    # Convert to dense for eigenstates
    dense_qarr = qarr.to_dense()

    evals, evecs = jnp.linalg.eigh(dense_qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # eigh returns [batch, :, i] as the i-th eigenvector;
    # we want [batch, i, :] as the i-th eigenvector.
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs
def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a Hermitian quantum array.

    Args:
        qarr: Hermitian operator; densified before diagonalisation.

    Returns:
        The eigenvalues from ``jnp.linalg.eigvalsh`` (ascending order).
    """
    matrix = qarr.to_dense().data
    return jnp.linalg.eigvalsh(matrix)
def ptrace(qarr: Qarray, indx) -> Qarray:
    """Reduced density matrix of subsystem *indx*.

    Kets/bras are first turned into density matrices with :func:`ket2dm`.
    Every subsystem EXCEPT *indx* is then traced out (the QuTiP
    ``ptrace(sel)`` convention), so the result is ``d_indx x d_indx``
    (per batch element).

    Args:
        qarr: Input quantum array (converted to dense internally).
        indx: Index of the subsystem to KEEP. Negative values count from
            the last subsystem.

    Returns:
        Reduced density matrix as a ``Qarray``.
    """
    dense_qarr = ket2dm(qarr.to_dense())
    rho = dense_qarr.shaped_data
    dims = dense_qarr.dims

    n_sub = len(dims[0])
    if indx < 0:
        # Without this, negative indices duplicate axes in the transpose.
        indx += n_sub

    # shaped_data axes: (*batch, *row_subsystems, *col_subsystems).
    # Move the kept (row, col) pair first, followed by the (row, col)
    # pairs of every other subsystem.
    order = [indx, indx + n_sub]
    for j in range(n_sub):
        if j != indx:
            order += [j, j + n_sub]

    n_batch = len(dense_qarr.bdims)
    rho = rho.transpose(list(range(n_batch)) + [j + n_batch for j in order])

    # Each trace contracts the pair sitting right after the kept pair.
    for _ in range(n_sub - 1):
        rho = jnp.trace(rho, axis1=2 + n_batch, axis2=3 + n_batch)

    return Qarray.create(rho)
def dag(qarr: Qarray) -> Qarray:
    """Hermitian adjoint of *qarr*.

    Args:
        qarr: Input quantum array.

    Returns:
        The conjugate transpose from the backend's ``dag``, with the row and
        column dims swapped.
    """
    adjoint_impl = qarr._impl.dag()
    return Qarray.create(
        adjoint_impl.data,
        dims=qarr.dims[::-1],
        implementation=adjoint_impl.impl_type,
    )
def dag_data(arr) -> Array:
    """Conjugate transpose of a raw array via the backend registry.

    The registered ``QarrayImpl`` classes are tried in registration order.
    The first whose ``can_handle_data(arr)`` is truthy performs the work, so
    new backends are picked up without editing this function.

    Args:
        arr: Raw array of any type a registered backend accepts.

    Returns:
        The result of that backend's ``dag_data(arr)``.

    Raises:
        TypeError: If no registered backend accepts *arr*.
    """
    handler = next(
        (impl_class for impl_class in _IMPL_REGISTRY if impl_class.can_handle_data(arr)),
        None,
    )
    if handler is None:
        raise TypeError(f"dag_data: no registered impl can handle type {type(arr)}")
    return handler.dag_data(arr)
def ket2dm(qarr: Qarray) -> Qarray:
    """Density matrix ``|ψ⟩⟨ψ|`` of a state.

    Args:
        qarr: A ket, a bra (adjointed first) or an operator (returned as-is).

    Returns:
        The outer product of the ket with its adjoint.
    """
    if qarr.qtype == Qtypes.oper:
        return qarr
    ket = qarr.dag() if qarr.qtype == Qtypes.bra else qarr
    return ket @ ket.dag()
2146# Data level operations
def is_dm_data(data: Array) -> bool:
    """Whether *data* is square in its last two axes (density-matrix shaped).

    Args:
        data: Array with at least two dimensions.

    Returns:
        ``True`` when the row and column counts match.
    """
    n_rows = data.shape[-2]
    n_cols = data.shape[-1]
    return n_rows == n_cols
def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of a raw dense array.

    Args:
        data: Dense square (possibly batched) matrix array.
        n: Integer exponent.

    Returns:
        ``jnp.linalg.matrix_power(data, n)``.
    """
    powered = jnp.linalg.matrix_power(data, n)
    return powered
# Type aliases for readability
# Qarray specialised to the dense backend (DenseImpl is defined outside
# this view; presumably earlier in the file — confirm).
DenseQarray = Qarray[DenseImpl]
# SparseBCOOQarray and SparseDIAQarray are defined lazily (impls imported at runtime)
# Use Qarray[SparseBCOOImpl] / Qarray[SparseDiaImpl] once those modules are imported.

# Tuple of array-like types (JAX Array, NumPy ndarray, Qarray), suitable
# for isinstance checks.
ARRAY_TYPES = (Array, ndarray, Qarray)