Skip to content

sparse_bcoo

Sparse BCOO backend for Qarray.

Implements the JAX experimental BCOO sparse format as a Qarray storage backend.

SparseBCOOImpl

Bases: QarrayImpl

Sparse implementation using JAX experimental BCOO sparse arrays.

Attributes:

Name Type Description
_data BCOO The underlying sparse.BCOO array.

Source code in jaxquantum/core/sparse_bcoo.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@struct.dataclass
class SparseBCOOImpl(QarrayImpl):
    """Sparse implementation using JAX experimental BCOO sparse arrays.

    Most operations act directly on the stored values (``_data.data``) and
    coordinates (``_data.indices``) so the array is never densified. The
    coordinate-based helpers (``trace``, ``keep_only_diag``,
    ``l2_norm_batched``) index ``indices[:, k]``. They therefore assume a BCOO
    without BCOO batch dimensions, i.e. ``indices`` of shape ``(nse, ndim)``
    (TODO confirm no caller builds a batched-BCOO layout).

    Attributes:
        _data: The underlying ``sparse.BCOO`` array.
    """

    _data: sparse.BCOO

    PROMOTION_ORDER = 1  # noqa: RUF012 — not a struct field; no annotation intentional

    @classmethod
    def from_data(cls, data) -> "SparseBCOOImpl":
        """Wrap *data* in a new ``SparseBCOOImpl``, converting to BCOO if needed.

        Args:
            data: A ``sparse.BCOO`` or array-like input.

        Returns:
            A ``SparseBCOOImpl`` wrapping a BCOO representation of *data*.
        """
        return cls(_data=cls._to_sparse(data))

    def get_data(self) -> sparse.BCOO:
        """Return the underlying BCOO sparse array (not a dense ``Array``)."""
        return self._data

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiply ``self @ other``.

        When *other* is a ``DenseImpl``, JAX's native BCOO @ dense path is
        used (no self-densification). When *other* is also a
        ``SparseBCOOImpl``, a sparse @ sparse product is performed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` (sparse @ dense) or ``SparseBCOOImpl`` (sparse @
            sparse) containing the matrix product.
        """
        if isinstance(other, DenseImpl):
            return DenseImpl(self._data @ other._data)
        a, b = self._coerce(other)
        if a is not self:
            # Backends differ: delegate to the coerced operand's implementation.
            return a.matmul(b)
        return SparseBCOOImpl(self._data @ b._data)

    @staticmethod
    def _align_index_dtype(x: sparse.BCOO, y: sparse.BCOO) -> sparse.BCOO:
        """Return *y* with its indices cast to *x*'s index dtype.

        Used so that both operands of BCOO ``+`` / ``-`` share one index dtype.
        *y* is returned unchanged when the dtypes already match.
        """
        if x.indices.dtype == y.indices.dtype:
            return y
        return sparse.BCOO((y.data, y.indices.astype(x.indices.dtype)), shape=y.shape)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition ``self + other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed) sum.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.add(b)
        return SparseBCOOImpl(self._data + self._align_index_dtype(self._data, b._data))

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction ``self - other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed) difference.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.sub(b)
        return SparseBCOOImpl(self._data - self._align_index_dtype(self._data, b._data))

    def mul(self, scalar) -> QarrayImpl:
        """Scalar multiplication.

        Args:
            scalar: Scalar value.

        Returns:
            A ``SparseBCOOImpl`` with each stored value multiplied by *scalar*.
        """
        return SparseBCOOImpl(scalar * self._data)

    def dag(self) -> QarrayImpl:
        """Conjugate transpose without densifying.

        For ``ndim >= 2`` the last two dimensions are swapped and the stored
        values conjugated (via ``dag_data``). Lower-rank arrays are only
        conjugated.

        Returns:
            A ``SparseBCOOImpl`` containing the conjugate transpose.
        """
        if self._data.ndim >= 2:
            return SparseBCOOImpl(self.dag_data(self._data))
        return SparseBCOOImpl(self._conj(self._data))

    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl`` via ``todense()``.

        Returns:
            A ``DenseImpl`` with the same values as this sparse array.
        """
        return DenseImpl(self._data.todense())

    @classmethod
    def _to_sparse(cls, data) -> sparse.BCOO:
        """Convert *data* to a ``sparse.BCOO``, returning it unchanged if already sparse.

        Args:
            data: A ``sparse.BCOO`` or array-like.

        Returns:
            A ``sparse.BCOO`` representation of *data*.
        """
        if isinstance(data, sparse.BCOO):
            return data
        return sparse.BCOO.fromdense(data)

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Return self (already sparse BCOO).

        Returns:
            This ``SparseBCOOImpl`` instance unchanged.
        """
        return self

    def shape(self) -> tuple:
        """Shape of the underlying BCOO array.

        Returns:
            Tuple of dimension sizes.
        """
        return self._data.shape

    def dtype(self):
        """Data type of the underlying BCOO array.

        Returns:
            The dtype of ``_data``.
        """
        return self._data.dtype

    def frobenius_norm(self) -> float:
        """Compute the Frobenius norm directly from stored values.

        Returns:
            The Frobenius norm as a JAX scalar (summed over all batch elements).
        """
        return jnp.sqrt(jnp.sum(jnp.abs(self._data.data) ** 2))

    @classmethod
    def _real(cls, data):
        """Return a BCOO array with only the real parts of the stored values."""
        return sparse.BCOO((jnp.real(data.data), data.indices), shape=data.shape)

    def real(self) -> QarrayImpl:
        """Element-wise real part.

        Returns:
            A ``SparseBCOOImpl`` containing the real parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._real(self._data))

    @classmethod
    def _imag(cls, data):
        """Return a BCOO array with only the imaginary parts of the stored values."""
        return sparse.BCOO((jnp.imag(data.data), data.indices), shape=data.shape)

    def imag(self) -> QarrayImpl:
        """Element-wise imaginary part.

        Returns:
            A ``SparseBCOOImpl`` containing the imaginary parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._imag(self._data))

    @classmethod
    def _conj(cls, data):
        """Return a BCOO array with complex-conjugated stored values."""
        return sparse.BCOO((jnp.conj(data.data), data.indices), shape=data.shape)

    def conj(self) -> QarrayImpl:
        """Element-wise complex conjugate.

        Returns:
            A ``SparseBCOOImpl`` containing the complex-conjugated stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._conj(self._data))

    @classmethod
    def _abs(cls, data):
        """Return a BCOO array with absolute values of stored entries."""
        return sparse.sparsify(jnp.abs)(data)

    def abs(self) -> QarrayImpl:
        """Element-wise absolute value.

        Returns:
            A ``SparseBCOOImpl`` containing the absolute values of stored entries.
        """
        return SparseBCOOImpl(SparseBCOOImpl._abs(self._data))

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create an ``n x n`` identity matrix as a sparse BCOO.

        Args:
            n: Matrix size.
            dtype: Optional data type.

        Returns:
            A ``sparse.BCOO`` identity matrix of shape ``(n, n)``.
        """
        return sparse.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True when *arr* is a ``sparse.BCOO`` array.

        Args:
            arr: Raw array.

        Returns:
            True if *arr* is a ``sparse.BCOO`` instance.
        """
        return isinstance(arr, sparse.BCOO)

    @classmethod
    def dag_data(cls, arr: sparse.BCOO) -> sparse.BCOO:
        """Conjugate transpose for BCOO sparse arrays without densifying.

        Args:
            arr: A ``sparse.BCOO`` array with ``ndim >= 2``.

        Returns:
            A ``sparse.BCOO`` containing the conjugate transpose.
        """
        ndim = arr.ndim
        # Keep leading (batch) axes in place; swap only the last two.
        permutation = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
        transposed = sparse.bcoo_transpose(arr, permutation=permutation)
        return sparse.BCOO(
            (jnp.conj(transposed.data), transposed.indices),
            shape=transposed.shape,
        )

    @staticmethod
    def _flat_batch_index(batch_indices, batch_shape):
        """Map per-entry batch coordinates to row-major flat batch indices.

        Strides and the batch count are computed with plain Python integers
        from the static *batch_shape*. Nothing is converted from a traced
        array, so callers stay usable under ``jax.jit``.

        Args:
            batch_indices: Integer array of shape ``(nse, len(batch_shape))``.
            batch_shape: Sequence of batch dimension sizes.

        Returns:
            Tuple ``(flat_idx, num_batches)``. ``flat_idx`` has shape ``(nse,)``
            and ``num_batches`` is the product of *batch_shape*.
        """
        strides = []
        num_batches = 1
        for size in reversed(batch_shape):
            strides.append(num_batches)
            num_batches *= size
        strides.reverse()
        stride_arr = jnp.array(strides, dtype=batch_indices.dtype)
        flat_idx = jnp.sum(batch_indices * stride_arr, axis=-1)
        return flat_idx, num_batches

    def trace(self) -> Array:
        """Compute the trace of the last two matrix dimensions without densifying.

        Only stored entries whose last two coordinates coincide contribute.
        With leading batch dimensions, per-batch traces are scatter-added into
        a flat buffer and reshaped to the batch shape.

        Returns:
            A scalar for a 2-D array, otherwise an array of shape
            ``shape[:-2]`` with one trace per batch element.
        """
        indices = self._data.indices
        values = self._data.data
        is_diag = indices[:, -2] == indices[:, -1]
        diag_values = values * is_diag

        if indices.shape[-1] == 2:
            return jnp.sum(diag_values)

        batch_shape = self._data.shape[:-2]
        flat_idx, num_batches = self._flat_batch_index(indices[:, :-2], batch_shape)
        result = jnp.zeros(num_batches, dtype=values.dtype).at[flat_idx].add(diag_values)
        return result.reshape(batch_shape)

    def keep_only_diag(self) -> "SparseBCOOImpl":
        """Zero out off-diagonal stored entries without densifying.

        The sparsity pattern (``indices``) is kept. Off-diagonal stored values
        become explicit zeros.

        Returns:
            A ``SparseBCOOImpl`` with only diagonal entries non-zero.
        """
        indices = self._data.indices
        values = self._data.data
        is_diag = indices[:, -2] == indices[:, -1]
        new_values = values * is_diag
        return SparseBCOOImpl(sparse.BCOO((new_values, indices), shape=self._data.shape))

    def l2_norm_batched(self, bdims: tuple) -> Array:
        """Compute the L2 norm per batch element without densifying.

        Args:
            bdims: Tuple of batch dimension sizes (the leading dims of ``shape``).

        Returns:
            A scalar when *bdims* is empty, otherwise an array of shape *bdims*.
            The result dtype follows the squared magnitudes of the stored values
            (no forced ``float64``).
        """
        sq = jnp.abs(self._data.data) ** 2

        if len(bdims) == 0:
            return jnp.sqrt(jnp.sum(sq))

        flat_idx, num_batches = self._flat_batch_index(
            self._data.indices[:, : len(bdims)], bdims
        )
        sum_sq = jnp.zeros(num_batches, dtype=sq.dtype).at[flat_idx].add(sq)
        return jnp.sqrt(sum_sq).reshape(bdims)

    def __deepcopy__(self, memo=None):
        """Return a new ``SparseBCOOImpl`` wrapping a deep copy of ``_data``."""
        return SparseBCOOImpl(_data=deepcopy(self._data, memo))

    def tidy_up(self, atol) -> "SparseBCOOImpl":
        """Zero out stored values whose real or imaginary magnitude is below *atol*.

        Real and imaginary parts are thresholded independently, and the result
        is rebuilt as ``re + 1j * im``.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``SparseBCOOImpl`` with small values zeroed.
        """
        values = self._data.data
        re = jnp.real(values)
        im = jnp.imag(values)
        new_values = re * (jnp.abs(re) > atol) + 1j * im * (jnp.abs(im) > atol)
        return SparseBCOOImpl(
            sparse.BCOO((new_values, self._data.indices), shape=self._data.shape)
        )

    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker product using ``sparsify(jnp.kron)``, which stays sparse.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` containing the Kronecker product when both
            operands are sparse; a ``DenseImpl`` when types differ.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.kron(b)
        sparse_kron = sparse.sparsify(jnp.kron)
        return SparseBCOOImpl(sparse_kron(self._data, b._data))

abs()

Element-wise absolute value.

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl containing the absolute values of stored entries.

Source code in jaxquantum/core/sparse_bcoo.py
233
234
235
236
237
238
239
def abs(self) -> QarrayImpl:
    """Return the element-wise magnitude of this sparse array.

    Returns:
        A ``SparseBCOOImpl`` whose stored entries are the absolute values of
        the original stored entries.
    """
    magnitudes = SparseBCOOImpl._abs(self._data)
    return SparseBCOOImpl(magnitudes)

add(other)

Element-wise addition self + other, coercing types as needed.

Parameters:

Name Type Description Default
other QarrayImpl

Right-hand operand.

required

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl (both sparse) or DenseImpl (mixed) sum.

Source code in jaxquantum/core/sparse_bcoo.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def add(self, other: QarrayImpl) -> QarrayImpl:
    """Add *other* to this array element-wise, coercing backends first.

    Args:
        other: Right-hand operand.

    Returns:
        A ``SparseBCOOImpl`` when both sides are sparse, otherwise whatever
        the coerced operand's ``add`` produces (``DenseImpl`` for mixed input).
    """
    left, right = self._coerce(other)
    if left is self:
        lhs = self._data
        rhs = right._data
        wanted = lhs.indices.dtype
        if rhs.indices.dtype != wanted:
            # Give the right operand the same index dtype as the left one.
            rhs = sparse.BCOO((rhs.data, rhs.indices.astype(wanted)), shape=rhs.shape)
        return SparseBCOOImpl(lhs + rhs)
    return left.add(right)

can_handle_data(arr) classmethod

Return True when arr is a sparse.BCOO array.

Parameters:

Name Type Description Default
arr

Raw array.

required

Returns:

Type Description
bool

True if arr is a sparse.BCOO instance.

Source code in jaxquantum/core/sparse_bcoo.py
254
255
256
257
258
259
260
261
262
263
264
@classmethod
def can_handle_data(cls, arr) -> bool:
    """Report whether this backend can wrap *arr* without conversion.

    Args:
        arr: Raw array of any kind.

    Returns:
        ``True`` exactly when *arr* is a ``sparse.BCOO`` instance.
    """
    is_bcoo = isinstance(arr, sparse.BCOO)
    return is_bcoo

conj()

Element-wise complex conjugate.

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl containing the complex-conjugated stored values.

Source code in jaxquantum/core/sparse_bcoo.py
220
221
222
223
224
225
226
def conj(self) -> QarrayImpl:
    """Complex-conjugate every stored value, keeping the sparsity pattern.

    Returns:
        A ``SparseBCOOImpl`` holding the conjugated stored values.
    """
    conjugated = SparseBCOOImpl._conj(self._data)
    return SparseBCOOImpl(conjugated)

dag()

Conjugate transpose without densifying.

Transposes the last two dimensions of the BCOO array and conjugates the stored values.

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl containing the conjugate transpose.

Source code in jaxquantum/core/sparse_bcoo.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def dag(self) -> QarrayImpl:
    """Hermitian adjoint computed on the sparse representation.

    Arrays of rank two or more have their final two axes swapped with
    ``sparse.bcoo_transpose``. Lower-rank arrays are left untransposed. In
    both cases the stored values are then complex-conjugated.

    Returns:
        A ``SparseBCOOImpl`` holding the conjugate transpose.
    """
    source = self._data
    rank = source.ndim
    if rank < 2:
        flipped = source
    else:
        leading_axes = tuple(range(rank - 2))
        flipped = sparse.bcoo_transpose(
            source, permutation=leading_axes + (rank - 1, rank - 2)
        )
    return SparseBCOOImpl(
        sparse.BCOO((jnp.conj(flipped.data), flipped.indices), shape=flipped.shape)
    )

dag_data(arr) classmethod

Conjugate transpose for BCOO sparse arrays without densifying.

Parameters:

Name Type Description Default
arr BCOO

A sparse.BCOO array with ndim >= 2.

required

Returns:

Type Description
BCOO

A sparse.BCOO containing the conjugate transpose.

Source code in jaxquantum/core/sparse_bcoo.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
@classmethod
def dag_data(cls, arr: sparse.BCOO) -> sparse.BCOO:
    """Return the conjugate transpose of a raw BCOO array, staying sparse.

    Args:
        arr: A ``sparse.BCOO`` array of rank two or more.

    Returns:
        A ``sparse.BCOO`` whose final two axes are swapped and whose stored
        values are complex-conjugated.
    """
    rank = arr.ndim
    axes = (*range(rank - 2), rank - 1, rank - 2)
    swapped = sparse.bcoo_transpose(arr, permutation=axes)
    conj_values = jnp.conj(swapped.data)
    return sparse.BCOO((conj_values, swapped.indices), shape=swapped.shape)

dtype()

Data type of the underlying BCOO array.

Returns:

Type Description

The dtype of _data.

Source code in jaxquantum/core/sparse_bcoo.py
173
174
175
176
177
178
179
def dtype(self):
    """Report the element type of the wrapped BCOO array.

    Returns:
        ``self._data.dtype``.
    """
    wrapped = self._data
    return wrapped.dtype

frobenius_norm()

Compute the Frobenius norm directly from stored values.

Returns:

Type Description
float

The Frobenius norm as a scalar.

Source code in jaxquantum/core/sparse_bcoo.py
181
182
183
184
185
186
187
def frobenius_norm(self) -> float:
    """Frobenius norm from the stored values alone (no densification).

    Returns:
        ``sqrt(sum(|v|**2))`` over every stored value ``v``.
    """
    stored = self._data.data
    squared_magnitudes = jnp.abs(stored) ** 2
    return jnp.sqrt(squared_magnitudes.sum())

from_data(data) classmethod

Wrap data in a new SparseBCOOImpl, converting to BCOO if needed.

Parameters:

Name Type Description Default
data

A sparse.BCOO or array-like input.

required

Returns:

Type Description
'SparseBCOOImpl'

A SparseBCOOImpl wrapping a BCOO representation of data.

Source code in jaxquantum/core/sparse_bcoo.py
31
32
33
34
35
36
37
38
39
40
41
@classmethod
def from_data(cls, data) -> "SparseBCOOImpl":
    """Alternate constructor that guarantees BCOO storage.

    Args:
        data: Either a ``sparse.BCOO`` or something array-like.

    Returns:
        A new instance of *cls* wrapping ``cls._to_sparse(data)``.
    """
    bcoo = cls._to_sparse(data)
    return cls(_data=bcoo)

get_data()

Return the underlying BCOO sparse array.

Source code in jaxquantum/core/sparse_bcoo.py
43
44
45
def get_data(self) -> Array:
    """Hand back the wrapped ``sparse.BCOO`` object itself (no copy, no densify)."""
    wrapped = self._data
    return wrapped

imag()

Element-wise imaginary part.

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl containing the imaginary parts of stored values.

Source code in jaxquantum/core/sparse_bcoo.py
207
208
209
210
211
212
213
def imag(self) -> QarrayImpl:
    """Take the imaginary part of every stored value, keeping the pattern.

    Returns:
        A ``SparseBCOOImpl`` holding the imaginary parts of the stored values.
    """
    imaginary = SparseBCOOImpl._imag(self._data)
    return SparseBCOOImpl(imaginary)

keep_only_diag()

Zero out off-diagonal stored entries without densifying.

Returns:

Type Description
'SparseBCOOImpl'

A SparseBCOOImpl with only diagonal entries non-zero.

Source code in jaxquantum/core/sparse_bcoo.py
311
312
313
314
315
316
317
318
319
320
321
def keep_only_diag(self) -> "SparseBCOOImpl":
    """Mask away off-diagonal stored values while keeping the index pattern.

    An entry counts as diagonal when its last two coordinates are equal.
    Off-diagonal entries stay stored but become zero.

    Returns:
        A ``SparseBCOOImpl`` whose only non-zero values lie on the diagonal.
    """
    bcoo = self._data
    idx = bcoo.indices
    on_diagonal = idx[:, -2] == idx[:, -1]
    masked = bcoo.data * on_diagonal
    return SparseBCOOImpl(sparse.BCOO((masked, idx), shape=bcoo.shape))

kron(other)

Kronecker product using sparsify(jnp.kron) — stays sparse.

Parameters:

Name Type Description Default
other 'QarrayImpl'

Right-hand operand.

required

Returns:

Type Description
'QarrayImpl'

A SparseBCOOImpl containing the Kronecker product when both operands are sparse; a DenseImpl when types differ.

Source code in jaxquantum/core/sparse_bcoo.py
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
def kron(self, other: "QarrayImpl") -> "QarrayImpl":
    """Kronecker product that stays sparse via ``sparse.sparsify(jnp.kron)``.

    Args:
        other: Right-hand operand.

    Returns:
        A ``SparseBCOOImpl`` when both operands are sparse; otherwise the
        result of the coerced operand's ``kron`` (a ``DenseImpl`` for mixed
        backends).
    """
    lhs, rhs = self._coerce(other)
    if lhs is not self:
        # Backends differ after coercion: let the coerced operand compute it.
        return lhs.kron(rhs)
    kron_sparse = sparse.sparsify(jnp.kron)
    product = kron_sparse(self._data, rhs._data)
    return SparseBCOOImpl(product)

l2_norm_batched(bdims)

Compute the L2 norm per batch element without densifying.

Parameters:

Name Type Description Default
bdims tuple

Tuple of batch dimension sizes.

required

Returns:

Type Description
Array

Scalar or array of L2 norms.

Source code in jaxquantum/core/sparse_bcoo.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
def l2_norm_batched(self, bdims: tuple) -> Array:
    """Compute the L2 norm per batch element without densifying.

    Squared magnitudes of the stored values are scatter-added per batch
    element and then square-rooted. The batch count and row-major strides
    come from the static *bdims* as plain Python integers. Forcing a ``jnp``
    reduction to ``int`` would fail on a tracer, so this keeps the function
    traceable under ``jax.jit``. The accumulator uses the dtype of the
    squared magnitudes, like the unbatched branch. It does not request
    ``float64``, which is unavailable when x64 is disabled.

    Assumes ``indices`` has shape ``(nse, ndim)`` with the leading
    ``len(bdims)`` columns being the batch coordinates (TODO confirm for all
    callers).

    Args:
        bdims: Tuple of batch dimension sizes.

    Returns:
        A scalar when *bdims* is empty, otherwise an array of shape *bdims*.
    """
    values = self._data.data
    indices = self._data.indices
    n_batch_dims = len(bdims)
    sq = jnp.abs(values) ** 2

    if n_batch_dims == 0:
        return jnp.sqrt(jnp.sum(sq))

    # Row-major strides over the batch dims, built from static Python ints.
    strides = []
    num_batches = 1
    for size in reversed(bdims):
        strides.append(num_batches)
        num_batches *= size
    strides.reverse()
    stride_arr = jnp.array(strides, dtype=indices.dtype)
    flat_batch_idx = jnp.sum(indices[:, :n_batch_dims] * stride_arr, axis=-1)
    sum_sq = jnp.zeros(num_batches, dtype=sq.dtype).at[flat_batch_idx].add(sq)
    return jnp.sqrt(sum_sq).reshape(bdims)

matmul(other)

Matrix multiply self @ other.

When other is a DenseImpl, JAX's native BCOO @ dense path is used (no self-densification). When other is also a SparseBCOOImpl, a sparse @ sparse product is performed.

Parameters:

Name Type Description Default
other QarrayImpl

Right-hand operand.

required

Returns:

Type Description
QarrayImpl

A DenseImpl (sparse @ dense) or SparseBCOOImpl (sparse @ sparse) containing the matrix product.

Source code in jaxquantum/core/sparse_bcoo.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def matmul(self, other: QarrayImpl) -> QarrayImpl:
    """Compute ``self @ other`` without densifying ``self``.

    A ``DenseImpl`` right operand goes straight through JAX's BCOO @ dense
    product. Otherwise both sides are coerced. If ``self`` survives coercion,
    a sparse @ sparse product is taken; if not, the coerced operand's
    ``matmul`` is used.

    Args:
        other: Right-hand operand.

    Returns:
        A ``DenseImpl`` for sparse @ dense, or a ``SparseBCOOImpl`` for
        sparse @ sparse.
    """
    if isinstance(other, DenseImpl):
        return DenseImpl(self._data @ other._data)
    lhs, rhs = self._coerce(other)
    if lhs is self:
        return SparseBCOOImpl(self._data @ rhs._data)
    return lhs.matmul(rhs)

mul(scalar)

Scalar multiplication.

Parameters:

Name Type Description Default
scalar

Scalar value.

required

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl with each stored value multiplied by scalar.

Source code in jaxquantum/core/sparse_bcoo.py
102
103
104
105
106
107
108
109
110
111
def mul(self, scalar) -> QarrayImpl:
    """Scale every stored value by *scalar*.

    Args:
        scalar: Multiplicative factor (applied as ``scalar * data``).

    Returns:
        A ``SparseBCOOImpl`` holding the scaled array.
    """
    scaled = scalar * self._data
    return SparseBCOOImpl(scaled)

real()

Element-wise real part.

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl containing the real parts of stored values.

Source code in jaxquantum/core/sparse_bcoo.py
194
195
196
197
198
199
200
def real(self) -> QarrayImpl:
    """Take the real part of every stored value, keeping the pattern.

    Returns:
        A ``SparseBCOOImpl`` holding the real parts of the stored values.
    """
    real_part = SparseBCOOImpl._real(self._data)
    return SparseBCOOImpl(real_part)

shape()

Shape of the underlying BCOO array.

Returns:

Type Description
tuple

Tuple of dimension sizes.

Source code in jaxquantum/core/sparse_bcoo.py
165
166
167
168
169
170
171
def shape(self) -> tuple:
    """Report the dimensions of the wrapped BCOO array.

    Returns:
        ``self._data.shape``.
    """
    wrapped = self._data
    return wrapped.shape

sub(other)

Element-wise subtraction self - other, coercing types as needed.

Parameters:

Name Type Description Default
other QarrayImpl

Right-hand operand.

required

Returns:

Type Description
QarrayImpl

A SparseBCOOImpl (both sparse) or DenseImpl (mixed) difference.

Source code in jaxquantum/core/sparse_bcoo.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def sub(self, other: QarrayImpl) -> QarrayImpl:
    """Subtract *other* from this array element-wise, coercing backends first.

    Args:
        other: Right-hand operand.

    Returns:
        A ``SparseBCOOImpl`` when both sides are sparse, otherwise whatever
        the coerced operand's ``sub`` produces (``DenseImpl`` for mixed input).
    """
    left, right = self._coerce(other)
    if left is self:
        lhs = self._data
        rhs = right._data
        wanted = lhs.indices.dtype
        if rhs.indices.dtype != wanted:
            # Give the right operand the same index dtype as the left one.
            rhs = sparse.BCOO((rhs.data, rhs.indices.astype(wanted)), shape=rhs.shape)
        return SparseBCOOImpl(lhs - rhs)
    return left.sub(right)

tidy_up(atol)

Zero out stored values whose real or imaginary magnitude is below atol.

Parameters:

Name Type Description Default
atol

Absolute tolerance threshold.

required

Returns:

Type Description
'SparseBCOOImpl'

A new SparseBCOOImpl with small values zeroed.

Source code in jaxquantum/core/sparse_bcoo.py
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
def tidy_up(self, atol) -> "SparseBCOOImpl":
    """Drop numerical noise from the stored values.

    The real and imaginary parts are thresholded independently: any part
    whose magnitude is not above *atol* becomes zero. The result is rebuilt
    as ``re + 1j * im`` on the original index pattern.

    Args:
        atol: Absolute tolerance threshold.

    Returns:
        A new ``SparseBCOOImpl`` with the small parts zeroed.
    """
    stored = self._data.data
    real_part, imag_part = jnp.real(stored), jnp.imag(stored)
    keep_real = jnp.abs(real_part) > atol
    keep_imag = jnp.abs(imag_part) > atol
    cleaned = real_part * keep_real + 1j * imag_part * keep_imag
    rebuilt = sparse.BCOO((cleaned, self._data.indices), shape=self._data.shape)
    return SparseBCOOImpl(rebuilt)

to_dense()

Convert to a DenseImpl via todense().

Returns:

Type Description
'DenseImpl'

A DenseImpl with the same values as this sparse array.

Source code in jaxquantum/core/sparse_bcoo.py
135
136
137
138
139
140
141
def to_dense(self) -> "DenseImpl":
    """Materialize this sparse array as a ``DenseImpl``.

    Returns:
        A ``DenseImpl`` wrapping ``self._data.todense()``.
    """
    dense_array = self._data.todense()
    return DenseImpl(dense_array)

to_sparse_bcoo()

Return self (already sparse BCOO).

Returns:

Type Description
'SparseBCOOImpl'

This SparseBCOOImpl instance unchanged.

Source code in jaxquantum/core/sparse_bcoo.py
157
158
159
160
161
162
163
def to_sparse_bcoo(self) -> "SparseBCOOImpl":
    """Identity conversion, because this backend already stores BCOO.

    Returns:
        The very same instance, not a copy.
    """
    already_bcoo = self
    return already_bcoo

trace()

Compute the trace of the last two matrix dimensions without densifying.

Returns:

Type Description
Array

Trace value(s).

Source code in jaxquantum/core/sparse_bcoo.py
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
def trace(self) -> Array:
    """Compute the trace of the last two matrix dimensions without densifying.

    Only stored entries whose last two coordinates coincide contribute. With
    leading batch dimensions, per-batch traces are scatter-added into a flat
    buffer and reshaped to the batch shape. The batch count and row-major
    strides come from the static ``shape`` as plain Python integers.
    Converting a ``jnp`` reduction with ``int()`` fails on a tracer, so this
    keeps the function traceable under ``jax.jit``.

    Assumes ``indices`` has shape ``(nse, ndim)``, as the ``indices[:, k]``
    indexing requires (TODO confirm no batched-BCOO layouts reach here).

    Returns:
        A scalar for a 2-D array, otherwise an array of shape ``shape[:-2]``
        with one trace per batch element.
    """
    indices = self._data.indices
    values = self._data.data
    # Mask of stored entries lying on the main diagonal of the matrix dims.
    is_diag = indices[:, -2] == indices[:, -1]
    diag_values = values * is_diag

    if indices.shape[-1] == 2:
        return jnp.sum(diag_values)

    batch_shape = self._data.shape[:-2]
    # Row-major strides over the batch dims, built from static Python ints.
    strides = []
    num_batches = 1
    for size in reversed(batch_shape):
        strides.append(num_batches)
        num_batches *= size
    strides.reverse()
    stride_arr = jnp.array(strides, dtype=indices.dtype)
    flat_batch_idx = jnp.sum(indices[:, :-2] * stride_arr, axis=-1)
    result = jnp.zeros(num_batches, dtype=values.dtype).at[flat_batch_idx].add(
        diag_values
    )
    return result.reshape(batch_shape)