Coverage for jaxquantum / core / sparse_bcoo.py: 92%
149 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""Sparse BCOO backend for Qarray.
3Implements the JAX experimental BCOO sparse format as a Qarray storage backend.
4"""
6from __future__ import annotations
8from flax import struct
9from jax import Array
10from copy import deepcopy
11import jax.numpy as jnp
12from jax.experimental import sparse
14# QarrayImpl and friends are imported below (after qarray.py is fully loaded)
15# to match the pattern used by sparse_dia.py and avoid circular imports.
16from jaxquantum.core.qarray import QarrayImpl, DenseImpl, QarrayImplType # noqa: E402
@struct.dataclass
class SparseBCOOImpl(QarrayImpl):
    """Sparse implementation using JAX experimental BCOO sparse arrays.

    A BCOO array may store several entries at the same index; their values
    are implicitly summed. JAX's sparse-sparse addition/subtraction (used by
    :meth:`add` and :meth:`sub`) concatenates the operands' stored entries,
    so such duplicates can occur. Methods that are linear in the stored
    values (``trace``, ``real``, ``conj``, ``keep_only_diag``, ...) are
    unaffected. Methods that are not linear (the norms and ``tidy_up``)
    merge duplicates first via :meth:`_sum_duplicates`.

    Attributes:
        _data: The underlying ``sparse.BCOO`` array.
    """

    _data: sparse.BCOO

    PROMOTION_ORDER = 1  # noqa: RUF012 — not a struct field; no annotation intentional

    @classmethod
    def from_data(cls, data) -> "SparseBCOOImpl":
        """Wrap *data* in a new ``SparseBCOOImpl``, converting to BCOO if needed.

        Args:
            data: A ``sparse.BCOO`` or array-like input.

        Returns:
            A ``SparseBCOOImpl`` wrapping a BCOO representation of *data*.
        """
        return cls(_data=cls._to_sparse(data))

    def get_data(self) -> sparse.BCOO:
        """Return the underlying BCOO sparse array."""
        return self._data

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiply ``self @ other``.

        When *other* is a ``DenseImpl``, JAX's native BCOO @ dense path is
        used (no self-densification). When *other* is also a
        ``SparseBCOOImpl``, a sparse @ sparse product is performed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` (sparse @ dense) or ``SparseBCOOImpl`` (sparse @
            sparse) containing the matrix product.
        """
        if isinstance(other, DenseImpl):
            return DenseImpl(self._data @ other._data)
        a, b = self._coerce(other)
        if a is not self:
            return a.matmul(b)
        return SparseBCOOImpl(self._data @ b._data)

    @classmethod
    def _match_index_dtype(cls, ref: sparse.BCOO, arr: sparse.BCOO) -> sparse.BCOO:
        """Return *arr* with its indices cast to *ref*'s index dtype.

        The sparse sum/difference combines both operands' index arrays, so
        their dtypes are aligned first. Casting does not change which indices
        are stored, so *arr*'s sortedness/uniqueness flags are kept.

        Args:
            ref: BCOO array whose index dtype is the target.
            arr: BCOO array to align.

        Returns:
            *arr* unchanged if the dtypes already match, else a re-indexed copy.
        """
        if arr.indices.dtype == ref.indices.dtype:
            return arr
        return sparse.BCOO(
            (arr.data, arr.indices.astype(ref.indices.dtype)),
            shape=arr.shape,
            indices_sorted=arr.indices_sorted,
            unique_indices=arr.unique_indices,
        )

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition ``self + other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed) sum.
            A sparse result may contain duplicate indices.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.add(b)
        x = self._data
        y = self._match_index_dtype(x, b._data)
        return SparseBCOOImpl(x + y)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction ``self - other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed)
            difference. A sparse result may contain duplicate indices.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.sub(b)
        x = self._data
        y = self._match_index_dtype(x, b._data)
        return SparseBCOOImpl(x - y)

    def mul(self, scalar) -> QarrayImpl:
        """Scalar multiplication.

        Args:
            scalar: Scalar value.

        Returns:
            A ``SparseBCOOImpl`` with each stored value multiplied by *scalar*.
        """
        return SparseBCOOImpl(scalar * self._data)

    def dag(self) -> QarrayImpl:
        """Conjugate transpose without densifying.

        Delegates to :meth:`dag_data`, which swaps the last two dimensions
        (when there are at least two) and conjugates the stored values.

        Returns:
            A ``SparseBCOOImpl`` containing the conjugate transpose.
        """
        return SparseBCOOImpl(self.dag_data(self._data))

    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl`` via ``todense()``.

        Returns:
            A ``DenseImpl`` with the same values as this sparse array.
        """
        return DenseImpl(self._data.todense())

    @classmethod
    def _to_sparse(cls, data) -> sparse.BCOO:
        """Convert *data* to a ``sparse.BCOO``, returning it unchanged if already sparse.

        Args:
            data: A ``sparse.BCOO`` or array-like.

        Returns:
            A ``sparse.BCOO`` representation of *data*.
        """
        if isinstance(data, sparse.BCOO):
            return data
        return sparse.BCOO.fromdense(data)

    @classmethod
    def _rebuild(cls, arr: sparse.BCOO, values) -> sparse.BCOO:
        """Return a BCOO with *arr*'s indices and shape but new stored *values*.

        The indices are reused unchanged, so *arr*'s ``indices_sorted`` and
        ``unique_indices`` flags remain truthful and are carried over. This
        lets :meth:`_sum_duplicates` skip work on arrays already known to be
        unique.

        Args:
            arr: Source BCOO array providing indices, shape and flags.
            values: Replacement stored values, aligned with ``arr.indices``.

        Returns:
            A new ``sparse.BCOO``.
        """
        return sparse.BCOO(
            (values, arr.indices),
            shape=arr.shape,
            indices_sorted=arr.indices_sorted,
            unique_indices=arr.unique_indices,
        )

    @classmethod
    def _sum_duplicates(cls, arr: sparse.BCOO) -> sparse.BCOO:
        """Merge stored entries that share an index.

        ``nse`` is pinned to the input's ``nse`` so the call does not depend
        on array contents and stays usable under ``jax.jit``. Slots freed by
        merging become padding entries.

        Args:
            arr: A BCOO array, possibly with duplicate indices.

        Returns:
            *arr* itself when its indices are already flagged unique or it
            stores no entries; otherwise an equivalent BCOO whose indices are
            unique.
        """
        if arr.unique_indices or arr.nse == 0:
            return arr
        return sparse.bcoo_sum_duplicates(arr, nse=arr.nse)

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Return self (already sparse BCOO).

        Returns:
            This ``SparseBCOOImpl`` instance unchanged.
        """
        return self

    def shape(self) -> tuple:
        """Shape of the underlying BCOO array.

        Returns:
            Tuple of dimension sizes.
        """
        return self._data.shape

    def dtype(self):
        """Data type of the underlying BCOO array.

        Returns:
            The dtype of ``_data``.
        """
        return self._data.dtype

    def frobenius_norm(self) -> float:
        """Compute the Frobenius norm directly from stored values.

        Duplicate indices are merged first: ``|a + b|**2`` differs from
        ``|a|**2 + |b|**2``, so summing squares of raw duplicates would be
        wrong.

        Returns:
            The Frobenius norm as a scalar.
        """
        values = self._sum_duplicates(self._data).data
        return jnp.sqrt(jnp.sum(jnp.abs(values) ** 2))

    @classmethod
    def _real(cls, data):
        """Return a BCOO array with only the real parts of the stored values."""
        return cls._rebuild(data, jnp.real(data.data))

    def real(self) -> QarrayImpl:
        """Element-wise real part.

        Returns:
            A ``SparseBCOOImpl`` containing the real parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._real(self._data))

    @classmethod
    def _imag(cls, data):
        """Return a BCOO array with only the imaginary parts of the stored values."""
        return cls._rebuild(data, jnp.imag(data.data))

    def imag(self) -> QarrayImpl:
        """Element-wise imaginary part.

        Returns:
            A ``SparseBCOOImpl`` containing the imaginary parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._imag(self._data))

    @classmethod
    def _conj(cls, data):
        """Return a BCOO array with complex-conjugated stored values."""
        return cls._rebuild(data, jnp.conj(data.data))

    def conj(self) -> QarrayImpl:
        """Element-wise complex conjugate.

        Returns:
            A ``SparseBCOOImpl`` containing the complex-conjugated stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._conj(self._data))

    @classmethod
    def _abs(cls, data):
        """Return a BCOO array with absolute values of stored entries."""
        return sparse.sparsify(jnp.abs)(data)

    def abs(self) -> QarrayImpl:
        """Element-wise absolute value.

        Returns:
            A ``SparseBCOOImpl`` containing the absolute values of stored entries.
        """
        return SparseBCOOImpl(SparseBCOOImpl._abs(self._data))

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create an ``n x n`` identity matrix as a sparse BCOO with O(n) memory.

        Args:
            n: Matrix size.
            dtype: Optional data type.

        Returns:
            A ``sparse.BCOO`` identity matrix of shape ``(n, n)``.
        """
        return sparse.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True when *arr* is a ``sparse.BCOO`` array.

        Args:
            arr: Raw array.

        Returns:
            True if *arr* is a ``sparse.BCOO`` instance.
        """
        return isinstance(arr, sparse.BCOO)

    @classmethod
    def dag_data(cls, arr: sparse.BCOO) -> sparse.BCOO:
        """Conjugate transpose for BCOO sparse arrays without densifying.

        The last two dimensions are swapped while leading dimensions keep
        their order. Arrays with fewer than two dimensions have no matrix
        axes to swap and are only conjugated.

        Args:
            arr: A ``sparse.BCOO`` array.

        Returns:
            A ``sparse.BCOO`` containing the conjugate transpose.
        """
        ndim = arr.ndim
        if ndim >= 2:
            permutation = (*range(ndim - 2), ndim - 1, ndim - 2)
            arr = sparse.bcoo_transpose(arr, permutation=permutation)
        return cls._rebuild(arr, jnp.conj(arr.data))

    @classmethod
    def _flat_batch_index(cls, batch_indices, batch_shape):
        """Flatten per-entry batch coordinates in row-major order.

        The total batch size is computed in pure Python from the static
        shape. ``int(jnp.prod(...))`` would fail on a tracer under ``jax.jit``.

        Args:
            batch_indices: Integer array whose last axis holds one coordinate
                per batch dimension.
            batch_shape: Static sequence of batch dimension sizes.

        Returns:
            ``(flat_idx, size)``: the flattened index per entry and the
            Python-int number of batch elements.
        """
        strides = []
        size = 1
        for dim in reversed(batch_shape):
            strides.append(size)
            size *= int(dim)
        strides.reverse()
        flat_idx = jnp.sum(
            batch_indices * jnp.array(strides, dtype=jnp.int32), axis=-1
        )
        return flat_idx, size

    def trace(self) -> Array:
        """Compute the trace of the last two matrix dimensions without densifying.

        The trace is linear in the stored values, so duplicate indices need
        no merging.

        Returns:
            Trace value(s).
        """
        indices = self._data.indices
        values = self._data.data
        # Keep only entries whose row and column index coincide.
        diag_values = values * (indices[:, -2] == indices[:, -1])

        if indices.shape[-1] == 2:
            return jnp.sum(diag_values)

        batch_shape = self._data.shape[:-2]
        flat_idx, size = self._flat_batch_index(indices[:, :-2], batch_shape)
        result = jnp.zeros(size, dtype=values.dtype).at[flat_idx].add(diag_values)
        return result.reshape(batch_shape)

    def keep_only_diag(self) -> "SparseBCOOImpl":
        """Zero out off-diagonal stored entries without densifying.

        Returns:
            A ``SparseBCOOImpl`` with only diagonal entries non-zero.
        """
        indices = self._data.indices
        is_diag = indices[:, -2] == indices[:, -1]
        return SparseBCOOImpl(self._rebuild(self._data, self._data.data * is_diag))

    def l2_norm_batched(self, bdims: tuple) -> Array:
        """Compute the L2 norm per batch element without densifying.

        Duplicate indices are merged first, because squaring is not linear.

        Args:
            bdims: Tuple of batch dimension sizes.

        Returns:
            Scalar or array of L2 norms.
        """
        data = self._sum_duplicates(self._data)
        sq = jnp.abs(data.data) ** 2

        if len(bdims) == 0:
            return jnp.sqrt(jnp.sum(sq))

        flat_idx, size = self._flat_batch_index(data.indices[:, : len(bdims)], bdims)
        sum_sq = jnp.zeros(size, dtype=jnp.float64).at[flat_idx].add(sq)
        return jnp.sqrt(sum_sq).reshape(bdims)

    def __deepcopy__(self, memo=None):
        return SparseBCOOImpl(_data=deepcopy(self._data, memo))

    def tidy_up(self, atol) -> "SparseBCOOImpl":
        """Zero out stored values whose real or imaginary magnitude is below *atol*.

        Duplicate indices are merged first so the threshold applies to each
        element's actual value, not to its individual stored contributions.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``SparseBCOOImpl`` with small values zeroed.
        """
        data = self._sum_duplicates(self._data)
        re = jnp.real(data.data)
        im = jnp.imag(data.data)
        new_values = re * (jnp.abs(re) > atol) + 1j * im * (jnp.abs(im) > atol)
        return SparseBCOOImpl(self._rebuild(data, new_values))

    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker product using ``sparsify(jnp.kron)`` — stays sparse.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` containing the Kronecker product when both
            operands are sparse; a ``DenseImpl`` when types differ.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.kron(b)
        sparse_kron = sparse.sparsify(jnp.kron)
        return SparseBCOOImpl(sparse_kron(self._data, b._data))
# Public API of this module: only the BCOO backend class is exported.
__all__ = ["SparseBCOOImpl"]

# Import-time side effect: associate this backend with the SPARSE_BCOO
# member of QarrayImplType.
QarrayImplType.register(SparseBCOOImpl, QarrayImplType.SPARSE_BCOO)