Coverage for jaxquantum / core / sparse_dia.py: 92%
252 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 22:49 +0000
1"""Sparse diagonal (SparseDIA) backend for Qarray.
3Stores only the diagonal *values* of a matrix, making quantum operators with
4small numbers of non-zero diagonals (annihilation, creation, number, Kerr…)
5far cheaper than Dense or BCOO:
7 * Memory: O(d * n) where d = number of stored diagonals, n = matrix size
8 * No index arrays (unlike BCOO which stores (row, col) per non-zero)
9 * ``_offsets`` is *static* Python metadata (pytree_node=False), so JAX
10 unrolls all loops over diagonals at compile time — only static slices,
11 no dynamic indexing or scatter/gather.
13Padding convention (Convention A):
15 For diagonal at offset k (k ≥ 0):
16 diags[..., i, j] = A[j-k, j] for j ∈ [k, n-1]; zeros at [0:k]
17 For diagonal at offset k (k < 0):
18 diags[..., i, j] = A[j-k, j] for j ∈ [0, n+k-1]; zeros at [n+k:]
20Unified access formula (holds for any k, out-of-range slots are zero):
21 A[i, i+k] = diags[..., diag_idx, i+k]
23This makes every matrix operation a set of aligned slice multiplications.
25Some improvements (_dia_slice helper, integer matrix power, diagonal-range pruning,
26offset detection) were identified by studying the dynamiqs library
27(https://github.com/dynamiqs/dynamiqs).
28"""
30from __future__ import annotations
32import numpy as np
33from copy import deepcopy
34from typing import TYPE_CHECKING
36import jax.numpy as jnp
37from flax import struct
38from jax import Array
40if TYPE_CHECKING:
41 from jaxquantum.core.qarray import DenseImpl, QarrayImplType
42 from jaxquantum.core.sparse_bcoo import SparseBCOOImpl
45# ---------------------------------------------------------------------------
46# Slice helper
47# ---------------------------------------------------------------------------
49def _dia_slice(k: int) -> slice:
50 """Slice selecting the valid data positions for diagonal offset k.
52 For k ≥ 0: valid data lives at column indices [k, n), so slice(k, None).
53 For k < 0: valid data lives at column indices [0, n+k), so slice(None, k).
55 The complementary slice (for the result rows / left-operand columns) is
56 always ``_dia_slice(-k)``.
57 """
58 return slice(k, None) if k >= 0 else slice(None, k)
61# ---------------------------------------------------------------------------
62# Raw data container
63# ---------------------------------------------------------------------------
@struct.dataclass
class SparseDiaData:
    """Lightweight pytree-compatible container for sparse-diagonal raw data.

    Returned by ``SparseDiaImpl.get_data()`` and consumed by
    ``SparseDiaImpl.from_data()``.  Registered as a JAX pytree via Flax's
    ``@struct.dataclass``; ``offsets`` is *not* a pytree leaf (it is static
    compile-time metadata).

    Attributes:
        offsets: Static tuple of diagonal offsets (pytree_node=False).
        diags:   JAX array of shape (*batch, n_diags, n) containing the
                 padded diagonal values.
    """

    offsets: tuple = struct.field(pytree_node=False)
    diags: Array

    # Class-level marker (not a dataclass field — no type annotation).
    # Used by DenseImpl.can_handle_data to exclude SparseDiaData without
    # a direct import (which would be circular).
    _is_sparse_dia = True

    @property
    def shape(self) -> tuple:
        """Shape of the represented square matrix (*batch, n, n)."""
        n = self.diags.shape[-1]
        return (*self.diags.shape[:-2], n, n)

    @property
    def dtype(self):
        """Dtype of the stored diagonal values."""
        return self.diags.dtype

    def __mul__(self, scalar):
        """Return a new container with every stored value scaled (``self * scalar``)."""
        return SparseDiaData(offsets=self.offsets, diags=self.diags * scalar)

    def __rmul__(self, scalar):
        """Return a new container with every stored value scaled (``scalar * self``)."""
        return SparseDiaData(offsets=self.offsets, diags=scalar * self.diags)

    def __getitem__(self, index):
        """Index into the batch dimension(s), preserving offsets."""
        return SparseDiaData(offsets=self.offsets, diags=self.diags[index])

    def __len__(self):
        """Length of the leading axis of :attr:`shape`."""
        return self.shape[0]

    def reshape(self, *new_shape):
        """Reshape batch dimensions while preserving diagonal structure.

        ``new_shape`` must end with ``(N, N)`` (the matrix dims are unchanged);
        only the leading batch dims are reshaped.  The target shape may be
        given either as separate integers (``reshape(2, 3, N, N)``) or as a
        single tuple/list (``reshape((2, 3, N, N))``).

        Args:
            *new_shape: Target shape of the represented matrix stack.

        Returns:
            A new ``SparseDiaData`` with the same offsets and reshaped diags.
        """
        # Accept the numpy-style "single sequence" calling convention as well;
        # without this a lone tuple would be sliced to an empty batch shape.
        if len(new_shape) == 1 and isinstance(new_shape[0], (tuple, list)):
            new_shape = tuple(new_shape[0])
        batch_dims = new_shape[:-2]
        n = self.diags.shape[-1]
        reshaped = self.diags.reshape(*batch_dims, len(self.offsets), n)
        return SparseDiaData(offsets=self.offsets, diags=reshaped)

    def __matmul__(self, other):
        """SparseDIA @ dense → dense (used by mesolve ODE RHS)."""
        # _sparsedia_matmul_dense is defined later in this module; Python
        # resolves the name at call time so the forward reference is fine.
        return _sparsedia_matmul_dense(self.offsets, self.diags, other)

    def __rmatmul__(self, other):
        """dense @ SparseDIA → dense (used by mesolve ODE RHS)."""
        return _sparsedia_rmatmul_dense(other, self.offsets, self.diags)
135# ---------------------------------------------------------------------------
136# Helper: dense → SparseDIA conversion
137# ---------------------------------------------------------------------------
139def _dense_to_sparsedia(arr: np.ndarray) -> tuple[tuple, np.ndarray]:
140 """Extract non-zero diagonal offsets and padded values from a dense array.
142 Uses the first batch element (if batched) to detect which diagonals are
143 non-zero. Safe to call outside JIT because *arr* must be a concrete
144 numpy / JAX array.
146 Args:
147 arr: Dense array of shape (*batch, n, n).
149 Returns:
150 Tuple of (offsets, diags) where:
151 - offsets is a sorted tuple of integer offsets
152 - diags is a numpy array of shape (*batch, n_diags, n)
153 """
154 n = arr.shape[-1]
155 batch_shape = arr.shape[:-2]
157 # Union non-zero diagonals across all batch elements via a single mask + nonzero call.
158 arr_np = np.asarray(arr)
159 flat_np = arr_np.reshape(-1, n, n)
160 union_mask = np.any(flat_np != 0, axis=0) # (n, n): True where any batch elem is non-zero
161 r, c = np.nonzero(union_mask)
162 offsets = tuple(sorted(set((c - r).tolist()))) if len(r) > 0 else (0,)
164 diags = np.zeros((*batch_shape, len(offsets), n), dtype=arr_np.dtype)
165 for i, k in enumerate(offsets):
166 # np.diagonal returns shape (*batch, n-|k|)
167 d = np.diagonal(arr_np, offset=k, axis1=-2, axis2=-1)
168 lo = max(k, 0)
169 hi = n - max(-k, 0)
170 diags[..., i, lo:hi] = d
172 return offsets, diags
175# ---------------------------------------------------------------------------
176# SparseDiaImpl
177# ---------------------------------------------------------------------------
179from jaxquantum.core.qarray import QarrayImpl, DenseImpl, QarrayImplType # noqa: E402
@struct.dataclass
class SparseDiaImpl(QarrayImpl):
    """Sparse-diagonal backend storing only diagonal values.

    Data layout::

        _offsets : tuple[int, ...]            — static (pytree_node=False)
        _diags   : Array[*batch, n_diags, n]  — JAX-traced values

    For offset k, the convention is:

    * k ≥ 0 : valid data at ``_diags[..., i, k:]``,   zeros at ``[0:k]``
    * k < 0 : valid data at ``_diags[..., i, :n+k]``, zeros at ``[n+k:]``

    In both cases: ``A[row, row+k] = _diags[..., i, row+k]``
    """

    _offsets: tuple = struct.field(pytree_node=False)
    _diags: Array

    PROMOTION_ORDER = 0  # noqa: RUF012 — not a struct field

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_data(cls, data) -> "SparseDiaImpl":
        """Wrap *data* in a new ``SparseDiaImpl``.

        Accepts either a :class:`SparseDiaData` container (direct wrap) or
        a dense array-like (non-zero diagonals are auto-detected via numpy,
        so the input must be concrete, i.e. this path runs before JIT).

        Args:
            data: A :class:`SparseDiaData` or dense array of shape
                (*batch, n, n).

        Returns:
            A new ``SparseDiaImpl`` instance.
        """
        if isinstance(data, SparseDiaData):
            return cls(_offsets=data.offsets, _diags=data.diags)
        offsets, diags_np = _dense_to_sparsedia(np.asarray(data))
        return cls(_offsets=offsets, _diags=jnp.array(diags_np))

    @classmethod
    def from_diags(cls, offsets: tuple, diags: Array) -> "SparseDiaImpl":
        """Directly construct from offsets and a padded diagonal array.

        This is the preferred factory when diagonal structure is known in
        advance (e.g., when building ``destroy`` or ``create`` operators).

        Args:
            offsets: Tuple of integer diagonal offsets (need not be sorted;
                will be sorted internally).
            diags: Array of shape (*batch, n_diags, n) with padded diagonal
                values; row ``i`` along axis ``-2`` belongs to ``offsets[i]``.

        Returns:
            A new ``SparseDiaImpl`` instance whose offsets are sorted and
            whose diagonal rows have been reordered to match.
        """
        offsets = tuple(offsets)
        order = sorted(range(len(offsets)), key=offsets.__getitem__)
        if order != list(range(len(offsets))):
            # Sorting the offsets alone would pair each offset with the wrong
            # row of ``diags``; apply the same (static) permutation to the rows.
            diags = diags[..., np.asarray(order), :]
        return cls(_offsets=tuple(offsets[i] for i in order), _diags=diags)

    # ------------------------------------------------------------------
    # QarrayImpl abstract methods
    # ------------------------------------------------------------------

    def get_data(self) -> SparseDiaData:
        """Return a :class:`SparseDiaData` container with the raw diagonal data."""
        return SparseDiaData(offsets=self._offsets, diags=self._diags)

    def shape(self) -> tuple:
        """Shape of the represented square matrix (including batch dims)."""
        n = self._diags.shape[-1]
        return (*self._diags.shape[:-2], n, n)

    def dtype(self):
        """Dtype of the stored diagonal values."""
        return self._diags.dtype

    def __deepcopy__(self, memo=None):
        """Copy the offsets; the diagonal array object itself is reused."""
        return SparseDiaImpl(
            _offsets=deepcopy(self._offsets),
            _diags=self._diags,
        )

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def mul(self, scalar) -> "SparseDiaImpl":
        """Scalar multiplication — scales all diagonal values."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=scalar * self._diags)

    def neg(self) -> "SparseDiaImpl":
        """Negation."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=-self._diags)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition.

        SparseDIA + SparseDIA stays SparseDIA (union of offsets, static).
        Any other operand is handled by coercing both sides with
        ``self._coerce`` and delegating to the coerced left operand.
        """
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_add(self, other)
        a, b = self._coerce(other)
        return a.add(b)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction (same dispatch rules as :meth:`add`)."""
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_add(self, other, subtract=True)
        a, b = self._coerce(other)
        return a.sub(b)

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiplication.

        * SparseDIA @ SparseDIA → SparseDIA  (O(d₁·d₂·n))
        * SparseDIA @ Dense    → Dense  (O(d·n²), no densification of self)
        * Others               → coerce then delegate
        """
        if isinstance(other, DenseImpl):
            return DenseImpl(
                _sparsedia_matmul_dense(self._offsets, self._diags, other._data)
            )
        if isinstance(other, SparseDiaImpl):
            offsets, diags = _sparsedia_matmul_sparsedia(
                self._offsets, self._diags,
                other._offsets, other._diags,
            )
            return SparseDiaImpl(_offsets=offsets, _diags=diags)
        a, b = self._coerce(other)
        return a.matmul(b)

    def dag(self) -> "SparseDiaImpl":
        """Conjugate transpose without densification.

        Negates every offset and moves the conjugated values from the valid
        slice of offset ``k`` to the valid slice of offset ``-k`` so that the
        padding convention remains consistent.
        """
        new_offsets = tuple(-k for k in self._offsets)
        new_diags = jnp.zeros_like(self._diags)
        for i, k in enumerate(self._offsets):
            src = _dia_slice(k)    # valid data slice for offset k
            dst = _dia_slice(-k)   # valid data slice for offset -k (the new position)
            new_diags = new_diags.at[..., i, dst].set(
                jnp.conj(self._diags[..., i, src])
            )
        return SparseDiaImpl(_offsets=new_offsets, _diags=new_diags)

    def kron(self, other: QarrayImpl) -> QarrayImpl:
        """Kronecker product.

        SparseDIA ⊗ SparseDIA stays SparseDIA: output offset for pair
        (kA, kB) is ``kA * m + kB`` where m = dim(B).  Any other operand is
        coerced and delegated.
        """
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_kron(self, other)
        a, b = self._coerce(other)
        return a.kron(b)

    def tidy_up(self, atol) -> "SparseDiaImpl":
        """Zero real/imaginary components whose magnitude is below *atol*.

        Real and imaginary parts are thresholded independently; the result
        keeps the original dtype.
        """
        diags = self._diags
        real_part = jnp.where(jnp.abs(jnp.real(diags)) < atol, 0.0, jnp.real(diags))
        if jnp.issubdtype(diags.dtype, jnp.complexfloating):
            imag_part = jnp.where(
                jnp.abs(jnp.imag(diags)) < atol, 0.0, jnp.imag(diags)
            )
            new_diags = (real_part + 1j * imag_part).astype(diags.dtype)
        else:
            new_diags = real_part.astype(diags.dtype)
        return SparseDiaImpl(_offsets=self._offsets, _diags=new_diags)

    # ------------------------------------------------------------------
    # Conversions
    # ------------------------------------------------------------------

    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl`` by writing each stored diagonal into a zero matrix."""
        n = self._diags.shape[-1]
        batch_shape = self._diags.shape[:-2]
        result = jnp.zeros((*batch_shape, n, n), dtype=self._diags.dtype)
        for i, k in enumerate(self._offsets):
            length = n - abs(k)
            if length <= 0:
                # Offset lies entirely outside the matrix; nothing to write.
                continue
            vals = self._diags[..., i, _dia_slice(k)]
            row_idx = jnp.arange(length) + max(-k, 0)
            col_idx = row_idx + k
            result = result.at[..., row_idx, col_idx].set(vals)
        return DenseImpl(result)

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` (BCOO) via dense."""
        return self.to_dense().to_sparse_bcoo()

    def to_sparse_dia(self) -> "SparseDiaImpl":
        """Return self (already SparseDIA)."""
        return self

    # ------------------------------------------------------------------
    # Class-method interface
    # ------------------------------------------------------------------

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Return an n×n identity as a dense JAX array.

        ``from_data`` will automatically convert it to SparseDIA format
        when the implementation type is ``SPARSE_DIA``.
        """
        return jnp.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True only for :class:`SparseDiaData` objects."""
        return isinstance(arr, SparseDiaData)

    @classmethod
    def dag_data(cls, arr: SparseDiaData) -> SparseDiaData:
        """Conjugate transpose of raw :class:`SparseDiaData` without densification."""
        impl = SparseDiaImpl(_offsets=arr.offsets, _diags=arr.diags)
        return impl.dag().get_data()

    # ------------------------------------------------------------------
    # Extra sparse-native methods (no densification)
    # ------------------------------------------------------------------

    def trace(self):
        """Compute trace directly from the main diagonal (offset 0).

        Returns:
            Sum of the main-diagonal values over the last axis (one value per
            batch element); zeros of the batch shape if offset 0 is not stored.
        """
        if 0 in self._offsets:
            i = self._offsets.index(0)
            return jnp.sum(self._diags[..., i, :], axis=-1)
        return jnp.zeros(self._diags.shape[:-2], dtype=self._diags.dtype)

    def frobenius_norm(self):
        """Frobenius norm computed directly from stored diagonal values.

        The sum runs over *all* stored entries, batch dimensions included.
        """
        return jnp.sqrt(jnp.sum(jnp.abs(self._diags) ** 2))

    def real(self) -> "SparseDiaImpl":
        """Element-wise real part of stored values (cast back to the stored dtype)."""
        return SparseDiaImpl(
            _offsets=self._offsets,
            _diags=jnp.real(self._diags).astype(self._diags.dtype),
        )

    def imag(self) -> "SparseDiaImpl":
        """Element-wise imaginary part of stored values (cast back to the stored dtype)."""
        return SparseDiaImpl(
            _offsets=self._offsets,
            _diags=jnp.imag(self._diags).astype(self._diags.dtype),
        )

    def conj(self) -> "SparseDiaImpl":
        """Element-wise complex conjugate of stored values."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=jnp.conj(self._diags))

    def powm(self, n: int) -> "SparseDiaImpl":
        """Integer matrix power staying SparseDIA via binary exponentiation.

        Recursively squares ``self ** (n // 2)`` with SparseDIA @ SparseDIA
        products rather than densifying.  ``A^0`` returns the identity
        operator (a single main diagonal of ones).

        Args:
            n: Non-negative integer exponent.

        Returns:
            A ``SparseDiaImpl`` equal to this matrix raised to the *n*-th power.

        Raises:
            ValueError: If *n* is negative.
        """
        if n < 0:
            raise ValueError("powm requires n >= 0")
        if n == 0:
            size = self._diags.shape[-1]
            eye_diags = jnp.ones(
                (*self._diags.shape[:-2], 1, size), dtype=self._diags.dtype
            )
            return SparseDiaImpl(_offsets=(0,), _diags=eye_diags)
        if n == 1:
            return self
        half = self.powm(n // 2)
        squared = half.matmul(half)  # SparseDIA @ SparseDIA → SparseDIA
        return squared if n % 2 == 0 else self.matmul(squared)
481# ---------------------------------------------------------------------------
482# Pure-function helpers (operate on raw arrays, no QarrayImpl wrapping)
483# ---------------------------------------------------------------------------
def _sparsedia_add(
    a: SparseDiaImpl,
    b: SparseDiaImpl,
    subtract: bool = False,
) -> SparseDiaImpl:
    """Add (or subtract) two SparseDiaImpl matrices, preserving SparseDIA format.

    The output offsets are the sorted union of both operands' offsets,
    computed at Python level (static).  Each output diagonal is the sum of
    the matching diagonal from ``a`` (if present) and the matching diagonal
    from ``b`` (if present, negated when ``subtract`` is true).
    """
    size = a._diags.shape[-1]
    batch = jnp.broadcast_shapes(a._diags.shape[:-2], b._diags.shape[:-2])
    dtype = jnp.result_type(a._diags.dtype, b._diags.dtype)

    merged = tuple(sorted({*a._offsets, *b._offsets}))
    row_in_a = {off: row for row, off in enumerate(a._offsets)}
    row_in_b = {off: row for row, off in enumerate(b._offsets)}
    b_sign = -1 if subtract else 1

    summed = jnp.zeros((*batch, len(merged), size), dtype=dtype)
    for row, off in enumerate(merged):
        # Start from zeros of the broadcast batch shape so each term broadcasts.
        acc = jnp.zeros((*batch, size), dtype=dtype)
        ia = row_in_a.get(off)
        if ia is not None:
            acc = acc + a._diags[..., ia, :]
        ib = row_in_b.get(off)
        if ib is not None:
            acc = acc + b_sign * b._diags[..., ib, :]
        summed = summed.at[..., row, :].set(acc)

    return SparseDiaImpl(_offsets=merged, _diags=summed)
def _sparsedia_matmul_dense(
    offsets: tuple,
    diags: Array,
    B: Array,
) -> Array:
    """Compute (SparseDIA) @ (dense matrix) → dense, without densifying the LHS.

    Each stored diagonal at offset k contributes

        result[..., rows, :] += diags[..., i, cols, None] * B[..., cols, :]

    where ``cols = _dia_slice(k)`` and ``rows = _dia_slice(-k)``.

    Complexity: O(d * n * m) where n×n is the operator and n×m is B.

    Args:
        offsets: Static tuple of diagonal offsets for the LHS.
        diags:   JAX array of shape (*batch, n_diags, n).
        B:       Dense right-hand side of shape (*batch, n, m).

    Returns:
        Dense product of shape (*batch, n, m).
    """
    out_batch = jnp.broadcast_shapes(diags.shape[:-2], B.shape[:-2])
    out_dtype = jnp.result_type(diags.dtype, B.dtype)
    out = jnp.zeros((*out_batch, B.shape[-2], B.shape[-1]), dtype=out_dtype)

    for row, offset in enumerate(offsets):
        cols = _dia_slice(offset)    # column positions holding data for this diagonal
        rows = _dia_slice(-offset)   # result rows those columns feed into
        scaled = diags[..., row, cols, None] * B[..., cols, :]
        out = out.at[..., rows, :].add(scaled)
    return out
def _sparsedia_rmatmul_dense(
    B: Array,
    offsets: tuple,
    diags: Array,
) -> Array:
    """Compute (dense matrix) @ (SparseDIA) → dense.

    Each stored diagonal at offset k contributes

        C[..., :, cols] += B[..., :, src] * diags[..., i, cols][..., None, :]

    where ``cols = _dia_slice(k)`` selects the result columns and
    ``src = _dia_slice(-k)`` selects the matching columns of B.

    Complexity: O(d * n * p) where n×n is the operator and p×n is B.

    Args:
        B:       Dense left-hand side of shape (*batch, p, n).
        offsets: Static tuple of diagonal offsets for the RHS.
        diags:   JAX array of shape (*batch, n_diags, n).

    Returns:
        Dense product of shape (*batch, p, n).
    """
    out_batch = jnp.broadcast_shapes(diags.shape[:-2], B.shape[:-2])
    out_dtype = jnp.result_type(diags.dtype, B.dtype)
    out = jnp.zeros((*out_batch, B.shape[-2], B.shape[-1]), dtype=out_dtype)

    for row, offset in enumerate(offsets):
        cols = _dia_slice(offset)   # result columns / valid data for this diagonal
        src = _dia_slice(-offset)   # columns of B that multiply into those results
        weights = diags[..., row, cols][..., None, :]
        out = out.at[..., :, cols].add(B[..., :, src] * weights)
    return out
def _sparsedia_matmul_sparsedia(
    left_offsets: tuple,
    left_diags: Array,
    right_offsets: tuple,
    right_diags: Array,
) -> tuple[tuple, Array]:
    """Compute (SparseDIA) @ (SparseDIA) → SparseDIA.

    Derivation uses the unified access formula A[i, i+k] = diags[i+k].
    For the contribution of diagonal pair (k1, k2) to output at kout = k1+k2,
    with j the storage index of the left diagonal:

        out_diag[j + k2] += left_diag[j] * right_diag[j + k2]

    This is an aligned slice multiply whose slices depend only on k2:

        k2 ≥ 0:  out.at[k2:].add(   left[:n-k2] * right[k2:]   )
        k2 < 0:  out.at[:n+k2].add( left[-k2:]  * right[:n+k2] )

    The whole slice is multiplied, so the result's padding slots are only
    zero if the padding slots of both inputs are zero.

    Complexity: O(d1 * d2 * n).

    Args:
        left_offsets:  Static offsets for the LHS matrix.
        left_diags:    JAX array (*batch, d1, n).
        right_offsets: Static offsets for the RHS matrix.
        right_diags:   JAX array (*batch, d2, n).

    Returns:
        Tuple of (out_offsets, out_diags); ``out_offsets`` is sorted and falls
        back to ``(0,)`` (all-zero diagonal) when no pair lands inside the
        matrix.
    """
    n = left_diags.shape[-1]
    batch_shape = jnp.broadcast_shapes(
        left_diags.shape[:-2], right_diags.shape[:-2]
    )

    # Pre-filter output offsets: diagonal pairs where |k1+k2| >= n are zero.
    out_offsets = sorted(
        {k1 + k2 for k1 in left_offsets for k2 in right_offsets if abs(k1 + k2) < n}
    ) or [0]
    out_row = {k: i for i, k in enumerate(out_offsets)}
    out_diags = jnp.zeros(
        (*batch_shape, len(out_offsets), n),
        dtype=jnp.result_type(left_diags.dtype, right_diags.dtype),
    )

    # The slices depend only on k2, so build them once per right diagonal.
    right_slices = [(_dia_slice(k2), _dia_slice(-k2)) for k2 in right_offsets]

    for li, k1 in enumerate(left_offsets):
        for ri, k2 in enumerate(right_offsets):
            kout = k1 + k2
            if abs(kout) >= n:
                continue  # product diagonal lies outside the matrix
            s, sm = right_slices[ri]  # right/out slice, complementary left slice
            contribution = left_diags[..., li, sm] * right_diags[..., ri, s]
            out_diags = out_diags.at[..., out_row[kout], s].add(contribution)

    return tuple(out_offsets), out_diags
def _sparsedia_kron(a: SparseDiaImpl, b: SparseDiaImpl) -> SparseDiaImpl:
    """Kronecker product of two SparseDiaImpl matrices → SparseDiaImpl.

    For operands A (n_A × n_A) and B (m × m), the output has dimension
    (n_A*m) × (n_A*m).  Each diagonal pair (kA, kB) contributes to the
    output diagonal at offset ``kout = kA * m + kB``; pairs that map to the
    same ``kout`` are summed.

    The full output diagonal is built without scatter or dynamic indexing::

        out_diag_arr[s] = left_padded[ s // m ] * right_padded[ s % m ]

    which equals::

        jnp.repeat(left_padded, m, axis=-1) * jnp.tile(right_padded, n_A)

    Complexity: O(d_A * d_B * N) where N = n_A * m.

    Args:
        a: Left  SparseDiaImpl of shape (*batch, n_A, n_A).
        b: Right SparseDiaImpl of shape (*batch, m, m).

    Returns:
        SparseDiaImpl of shape (*batch, N, N).
    """
    n_A = a._diags.shape[-1]
    m = b._diags.shape[-1]

    # Contributions keyed by output offset.
    by_offset: dict[int, Array] = {}

    for li, kA in enumerate(a._offsets):
        # Each left value repeated m times: (*batch, N).  Independent of kB.
        left_rep = jnp.repeat(a._diags[..., li, :], m, axis=-1)
        for ri, kB in enumerate(b._offsets):
            kout = kA * m + kB
            right_tiled = jnp.tile(b._diags[..., ri, :], n_A)  # (*batch, N)
            term = left_rep * right_tiled
            previous = by_offset.get(kout)
            by_offset[kout] = term if previous is None else previous + term

    out_offsets = tuple(sorted(by_offset))
    out_diags = jnp.stack([by_offset[k] for k in out_offsets], axis=-2)
    return SparseDiaImpl(_offsets=out_offsets, _diags=out_diags)
# ---------------------------------------------------------------------------
# Register with the QarrayImplType enum
# ---------------------------------------------------------------------------

# Associate the SparseDiaImpl class with the SPARSE_DIA implementation type.
QarrayImplType.register(SparseDiaImpl, QarrayImplType.SPARSE_DIA)


# ---------------------------------------------------------------------------
# Public exports
# ---------------------------------------------------------------------------

__all__ = ["SparseDiaData", "SparseDiaImpl"]