Coverage for jaxquantum / core / sparse_dia.py: 92%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""Sparse diagonal (SparseDIA) backend for Qarray. 

2 

3Stores only the diagonal *values* of a matrix, making quantum operators with 

4small numbers of non-zero diagonals (annihilation, creation, number, Kerr…) 

5far cheaper than Dense or BCOO: 

6 

7 * Memory: O(d * n) where d = number of stored diagonals, n = matrix size 

8 * No index arrays (unlike BCOO which stores (row, col) per non-zero) 

9 * ``_offsets`` is *static* Python metadata (pytree_node=False), so JAX 

10 unrolls all loops over diagonals at compile time — only static slices, 

11 no dynamic indexing or scatter/gather. 

12 

13Padding convention (Convention A): 

14 

15 For diagonal at offset k (k ≥ 0): 

16 diags[..., i, j] = A[j-k, j] for j ∈ [k, n-1]; zeros at [0:k] 

17 For diagonal at offset k (k < 0): 

18 diags[..., i, j] = A[j-k, j] for j ∈ [0, n+k-1]; zeros at [n+k:] 

19 

20Unified access formula (holds for any k, out-of-range slots are zero): 

21 A[i, i+k] = diags[..., diag_idx, i+k] 

22 

23This makes every matrix operation a set of aligned slice multiplications. 

24 

25Some improvements (_dia_slice helper, integer matrix power, diagonal-range pruning, 

26offset detection) were identified by studying the dynamiqs library 

27(https://github.com/dynamiqs/dynamiqs). 

28""" 

29 

30from __future__ import annotations 

31 

32import numpy as np 

33from copy import deepcopy 

34from typing import TYPE_CHECKING 

35 

36import jax.numpy as jnp 

37from flax import struct 

38from jax import Array 

39 

40if TYPE_CHECKING: 

41 from jaxquantum.core.qarray import DenseImpl, QarrayImplType 

42 from jaxquantum.core.sparse_bcoo import SparseBCOOImpl 

43 

44 

45# --------------------------------------------------------------------------- 

46# Slice helper 

47# --------------------------------------------------------------------------- 

48 

49def _dia_slice(k: int) -> slice: 

50 """Slice selecting the valid data positions for diagonal offset k. 

51 

52 For k ≥ 0: valid data lives at column indices [k, n), so slice(k, None). 

53 For k < 0: valid data lives at column indices [0, n+k), so slice(None, k). 

54 

55 The complementary slice (for the result rows / left-operand columns) is 

56 always ``_dia_slice(-k)``. 

57 """ 

58 return slice(k, None) if k >= 0 else slice(None, k) 

59 

60 

61# --------------------------------------------------------------------------- 

62# Raw data container 

63# --------------------------------------------------------------------------- 

64 

@struct.dataclass
class SparseDiaData:
    """Lightweight pytree-compatible container for sparse-diagonal raw data.

    Returned by ``SparseDiaImpl.get_data()`` and consumed by
    ``SparseDiaImpl.from_data()``.  Registered as a JAX pytree via Flax's
    ``@struct.dataclass``; ``offsets`` is *not* a pytree leaf (it is static
    compile-time metadata).

    Attributes:
        offsets: Static tuple of diagonal offsets (pytree_node=False).
        diags: JAX array of shape (*batch, n_diags, n) containing the
            padded diagonal values.
    """

    offsets: tuple = struct.field(pytree_node=False)
    diags: Array

    # Class-level marker (not a dataclass field — no type annotation).
    # Used by DenseImpl.can_handle_data to exclude SparseDiaData without
    # a direct import (which would be circular).
    _is_sparse_dia = True

    @property
    def shape(self) -> tuple:
        """Shape of the represented square matrix (*batch, n, n)."""
        n = self.diags.shape[-1]
        return (*self.diags.shape[:-2], n, n)

    @property
    def dtype(self):
        """Dtype of the stored diagonal values."""
        return self.diags.dtype

    def __mul__(self, scalar):
        """Scale every stored diagonal value: ``data * scalar``."""
        return SparseDiaData(offsets=self.offsets, diags=self.diags * scalar)

    def __rmul__(self, scalar):
        """Scale every stored diagonal value: ``scalar * data``."""
        return SparseDiaData(offsets=self.offsets, diags=scalar * self.diags)

    def __getitem__(self, index):
        """Index into the batch dimension(s), preserving offsets."""
        return SparseDiaData(offsets=self.offsets, diags=self.diags[index])

    def __len__(self):
        """First entry of ``shape`` (the leading batch dimension when batched)."""
        return self.shape[0]

    def reshape(self, *new_shape):
        """Reshape batch dimensions while preserving diagonal structure.

        Accepts the target shape either as separate integers
        (``reshape(b1, b2, N, N)``) or as one tuple/list
        (``reshape((b1, b2, N, N))``), mirroring array ``reshape`` methods.

        ``new_shape`` must end with ``(N, N)`` (the matrix dims are unchanged).
        Only the leading batch dims are reshaped; the trailing two entries
        are dropped and replaced by ``(n_diags, n)``.

        Args:
            *new_shape: Target dense shape ``(*new_batch, N, N)``.

        Returns:
            A new ``SparseDiaData`` with the same offsets and reshaped batch
            dimensions.
        """
        # A single sequence argument is the shape itself; without unpacking
        # it, the slice below would discard the whole shape.
        if len(new_shape) == 1 and isinstance(new_shape[0], (tuple, list)):
            new_shape = tuple(new_shape[0])
        new_batch = new_shape[:-2]
        n = self.diags.shape[-1]
        new_diags = self.diags.reshape(*new_batch, len(self.offsets), n)
        return SparseDiaData(offsets=self.offsets, diags=new_diags)

    def __matmul__(self, other):
        """SparseDIA @ dense → dense (used by mesolve ODE RHS)."""
        # _sparsedia_matmul_dense is defined later in this module; Python
        # resolves the name at call time so forward reference is fine.
        return _sparsedia_matmul_dense(self.offsets, self.diags, other)

    def __rmatmul__(self, other):
        """dense @ SparseDIA → dense (used by mesolve ODE RHS)."""
        return _sparsedia_rmatmul_dense(other, self.offsets, self.diags)

133 

134 

135# --------------------------------------------------------------------------- 

136# Helper: dense → SparseDIA conversion 

137# --------------------------------------------------------------------------- 

138 

139def _dense_to_sparsedia(arr: np.ndarray) -> tuple[tuple, np.ndarray]: 

140 """Extract non-zero diagonal offsets and padded values from a dense array. 

141 

142 Uses the first batch element (if batched) to detect which diagonals are 

143 non-zero. Safe to call outside JIT because *arr* must be a concrete 

144 numpy / JAX array. 

145 

146 Args: 

147 arr: Dense array of shape (*batch, n, n). 

148 

149 Returns: 

150 Tuple of (offsets, diags) where: 

151 - offsets is a sorted tuple of integer offsets 

152 - diags is a numpy array of shape (*batch, n_diags, n) 

153 """ 

154 n = arr.shape[-1] 

155 batch_shape = arr.shape[:-2] 

156 

157 # Union non-zero diagonals across all batch elements via a single mask + nonzero call. 

158 arr_np = np.asarray(arr) 

159 flat_np = arr_np.reshape(-1, n, n) 

160 union_mask = np.any(flat_np != 0, axis=0) # (n, n): True where any batch elem is non-zero 

161 r, c = np.nonzero(union_mask) 

162 offsets = tuple(sorted(set((c - r).tolist()))) if len(r) > 0 else (0,) 

163 

164 diags = np.zeros((*batch_shape, len(offsets), n), dtype=arr_np.dtype) 

165 for i, k in enumerate(offsets): 

166 # np.diagonal returns shape (*batch, n-|k|) 

167 d = np.diagonal(arr_np, offset=k, axis1=-2, axis2=-1) 

168 lo = max(k, 0) 

169 hi = n - max(-k, 0) 

170 diags[..., i, lo:hi] = d 

171 

172 return offsets, diags 

173 

174 

175# --------------------------------------------------------------------------- 

176# SparseDiaImpl 

177# --------------------------------------------------------------------------- 

178 

179from jaxquantum.core.qarray import QarrayImpl, DenseImpl, QarrayImplType # noqa: E402 

180 

181 

@struct.dataclass
class SparseDiaImpl(QarrayImpl):
    """Sparse-diagonal backend storing only diagonal values.

    Data layout::

        _offsets : tuple[int, ...]            — static (pytree_node=False)
        _diags   : Array[*batch, n_diags, n]  — JAX-traced values

    For offset k, the convention is:
      * k ≥ 0 : valid data at ``_diags[..., i, k:]``,   zeros at ``[0:k]``
      * k < 0 : valid data at ``_diags[..., i, :n+k]``, zeros at ``[n+k:]``

    In both cases: ``A[row, row+k] = _diags[..., i, row+k]``
    """

    _offsets: tuple = struct.field(pytree_node=False)
    _diags: Array

    PROMOTION_ORDER = 0  # noqa: RUF012 — not a struct field

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_data(cls, data) -> "SparseDiaImpl":
        """Wrap *data* in a new ``SparseDiaImpl``.

        Accepts either a :class:`SparseDiaData` container (direct wrap) or
        a dense array-like (auto-detect non-zero diagonals via numpy, safe
        to call before JIT).

        Args:
            data: A :class:`SparseDiaData` or dense array of shape
                (*batch, n, n).

        Returns:
            A new ``SparseDiaImpl`` instance.
        """
        if isinstance(data, SparseDiaData):
            return cls(_offsets=data.offsets, _diags=data.diags)
        offsets, diags_np = _dense_to_sparsedia(np.asarray(data))
        return cls(_offsets=offsets, _diags=jnp.array(diags_np))

    @classmethod
    def from_diags(cls, offsets: tuple, diags: Array) -> "SparseDiaImpl":
        """Directly construct from offsets and a padded diagonal array.

        This is the preferred factory when diagonal structure is known in
        advance (e.g., when building ``destroy`` or ``create`` operators).

        The offsets are stored in ascending order.  When the caller passes
        them unsorted, the rows of *diags* are permuted with the same
        ordering so each row stays attached to its own offset.

        Args:
            offsets: Tuple of integer diagonal offsets (need not be sorted;
                will be sorted internally).
            diags: JAX array of shape (*batch, n_diags, n) with padded
                diagonal values; row ``i`` belongs to ``offsets[i]``.

        Returns:
            A new ``SparseDiaImpl`` instance.
        """
        offsets = tuple(offsets)
        order = sorted(range(len(offsets)), key=offsets.__getitem__)
        if order != list(range(len(offsets))):
            # Sorting the offsets alone would pair values with the wrong
            # diagonal; move the rows along with them (static gather).
            diags = diags[..., np.asarray(order), :]
        return cls(_offsets=tuple(offsets[i] for i in order), _diags=diags)

    # ------------------------------------------------------------------
    # QarrayImpl abstract methods
    # ------------------------------------------------------------------

    def get_data(self) -> SparseDiaData:
        """Return a :class:`SparseDiaData` container with the raw diagonal data."""
        return SparseDiaData(offsets=self._offsets, diags=self._diags)

    def shape(self) -> tuple:
        """Shape of the represented square matrix (including batch dims)."""
        n = self._diags.shape[-1]
        return (*self._diags.shape[:-2], n, n)

    def dtype(self):
        """Dtype of the stored diagonal values."""
        return self._diags.dtype

    def __deepcopy__(self, memo=None):
        """Copy the offsets tuple; the diagonal array object is reused as-is."""
        return SparseDiaImpl(
            _offsets=deepcopy(self._offsets),
            _diags=self._diags,
        )

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def mul(self, scalar) -> "SparseDiaImpl":
        """Scalar multiplication — scales all diagonal values."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=scalar * self._diags)

    def neg(self) -> "SparseDiaImpl":
        """Negation."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=-self._diags)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition.

        SparseDIA + SparseDIA stays SparseDIA (union of offsets, static).
        Otherwise both operands go through the inherited ``_coerce`` and the
        addition is delegated to the coerced left operand.
        """
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_add(self, other)
        a, b = self._coerce(other)
        return a.add(b)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction (same dispatch rules as :meth:`add`)."""
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_add(self, other, subtract=True)
        a, b = self._coerce(other)
        return a.sub(b)

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiplication.

        * SparseDIA @ SparseDIA → SparseDIA  (O(d₁·d₂·n))
        * SparseDIA @ Dense     → Dense  (O(d·n²), no densification of self)
        * Others                → coerce then delegate
        """
        if isinstance(other, DenseImpl):
            return DenseImpl(_sparsedia_matmul_dense(
                self._offsets, self._diags, other._data
            ))
        if isinstance(other, SparseDiaImpl):
            offsets, diags = _sparsedia_matmul_sparsedia(
                self._offsets, self._diags,
                other._offsets, other._diags,
            )
            return SparseDiaImpl(_offsets=offsets, _diags=diags)
        a, b = self._coerce(other)
        return a.matmul(b)

    def dag(self) -> "SparseDiaImpl":
        """Conjugate transpose without densification.

        Negates every offset and rearranges the stored values so that the
        padding convention remains consistent.
        """
        # NOTE(review): rows keep their positions, so ascending input
        # offsets come out in descending order — confirm no consumer
        # requires sorted offsets.
        new_offsets = tuple(-k for k in self._offsets)
        new_diags = jnp.zeros_like(self._diags)
        for i, k in enumerate(self._offsets):
            s = _dia_slice(k)    # valid data slice for offset k
            sm = _dia_slice(-k)  # valid data slice for offset -k (the new position)
            new_diags = new_diags.at[..., i, sm].set(jnp.conj(self._diags[..., i, s]))
        return SparseDiaImpl(_offsets=new_offsets, _diags=new_diags)

    def kron(self, other: QarrayImpl) -> QarrayImpl:
        """Kronecker product.

        SparseDIA ⊗ SparseDIA stays SparseDIA: output offset for pair
        (kA, kB) is ``kA * m + kB`` where m = dim(B).  Other operand types
        are coerced first and the product is delegated.
        """
        if isinstance(other, SparseDiaImpl):
            return _sparsedia_kron(self, other)
        a, b = self._coerce(other)
        return a.kron(b)

    def tidy_up(self, atol) -> "SparseDiaImpl":
        """Zero real/imaginary components whose magnitude is below *atol*.

        The real and imaginary parts are thresholded independently; the
        result keeps the original dtype.
        """
        diags = self._diags
        real_part = jnp.where(jnp.abs(jnp.real(diags)) < atol, 0.0, jnp.real(diags))
        if jnp.issubdtype(diags.dtype, jnp.complexfloating):
            imag_part = jnp.where(jnp.abs(jnp.imag(diags)) < atol, 0.0, jnp.imag(diags))
            new_diags = (real_part + 1j * imag_part).astype(diags.dtype)
        else:
            new_diags = real_part.astype(diags.dtype)
        return SparseDiaImpl(_offsets=self._offsets, _diags=new_diags)

    # ------------------------------------------------------------------
    # Conversions
    # ------------------------------------------------------------------

    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl`` by writing each stored diagonal into place."""
        n = self._diags.shape[-1]
        batch_shape = self._diags.shape[:-2]
        result = jnp.zeros((*batch_shape, n, n), dtype=self._diags.dtype)
        for i, k in enumerate(self._offsets):
            s = _dia_slice(k)
            length = n - abs(k)
            if length <= 0:
                # Offset lies entirely outside an n×n matrix.
                continue
            vals = self._diags[..., i, s]
            # First row on the diagonal is 0 for k >= 0 and |k| for k < 0.
            row_idx = jnp.arange(length) + max(-k, 0)
            col_idx = row_idx + k
            result = result.at[..., row_idx, col_idx].set(vals)
        return DenseImpl(result)

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Convert to a ``SparseBCOOImpl`` (BCOO) via dense."""
        return self.to_dense().to_sparse_bcoo()

    def to_sparse_dia(self) -> "SparseDiaImpl":
        """Return self (already SparseDIA)."""
        return self

    # ------------------------------------------------------------------
    # Class-method interface
    # ------------------------------------------------------------------

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Return an n×n identity as a dense JAX array.

        ``from_data`` will automatically convert it to SparseDIA format
        when the implementation type is ``SPARSE_DIA``.
        """
        return jnp.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True only for :class:`SparseDiaData` objects."""
        return isinstance(arr, SparseDiaData)

    @classmethod
    def dag_data(cls, arr: SparseDiaData) -> SparseDiaData:
        """Conjugate transpose of raw :class:`SparseDiaData` without densification."""
        impl = SparseDiaImpl(_offsets=arr.offsets, _diags=arr.diags)
        result = impl.dag()
        return result.get_data()

    # ------------------------------------------------------------------
    # Extra sparse-native methods (no densification)
    # ------------------------------------------------------------------

    def trace(self):
        """Compute trace directly from the main diagonal (offset 0).

        Returns:
            Sum of the main-diagonal values with shape ``batch``; zeros of
            that shape when no offset-0 diagonal is stored.
        """
        if 0 in self._offsets:
            i = self._offsets.index(0)
            return jnp.sum(self._diags[..., i, :], axis=-1)
        return jnp.zeros(self._diags.shape[:-2], dtype=self._diags.dtype)

    def frobenius_norm(self):
        """Frobenius norm computed directly from stored diagonal values."""
        # NOTE(review): the sum runs over every axis, so a batched operator
        # yields a single scalar for the whole batch (unlike ``trace``,
        # which is per batch element) — confirm this is intended.
        return jnp.sqrt(jnp.sum(jnp.abs(self._diags) ** 2))

    def real(self) -> "SparseDiaImpl":
        """Element-wise real part of stored values (dtype preserved)."""
        return SparseDiaImpl(
            _offsets=self._offsets,
            _diags=jnp.real(self._diags).astype(self._diags.dtype),
        )

    def imag(self) -> "SparseDiaImpl":
        """Element-wise imaginary part of stored values (dtype preserved)."""
        return SparseDiaImpl(
            _offsets=self._offsets,
            _diags=jnp.imag(self._diags).astype(self._diags.dtype),
        )

    def conj(self) -> "SparseDiaImpl":
        """Element-wise complex conjugate of stored values."""
        return SparseDiaImpl(_offsets=self._offsets, _diags=jnp.conj(self._diags))

    def powm(self, n: int) -> "SparseDiaImpl":
        """Integer matrix power staying SparseDIA via binary exponentiation.

        Uses O(log n) SparseDIA @ SparseDIA multiplications rather than
        densifying.  A^0 returns the identity operator.

        Args:
            n: Non-negative integer exponent.

        Returns:
            A ``SparseDiaImpl`` equal to this matrix raised to the *n*-th power.

        Raises:
            ValueError: If *n* is negative.
        """
        if n < 0:
            raise ValueError("powm requires n >= 0")
        if n == 0:
            # Identity: a single main diagonal of ones, batch dims preserved.
            size = self._diags.shape[-1]
            eye_diags = jnp.ones((*self._diags.shape[:-2], 1, size), dtype=self._diags.dtype)
            return SparseDiaImpl(_offsets=(0,), _diags=eye_diags)
        if n == 1:
            return self
        half = self.powm(n // 2)
        squared = half.matmul(half)  # SparseDIA @ SparseDIA → SparseDIA
        # Odd exponents need one extra factor of self.
        return squared if n % 2 == 0 else self.matmul(squared)

479 

480 

481# --------------------------------------------------------------------------- 

482# Pure-function helpers (operate on raw arrays, no QarrayImpl wrapping) 

483# --------------------------------------------------------------------------- 

484 

def _sparsedia_add(
    a: SparseDiaImpl,
    b: SparseDiaImpl,
    subtract: bool = False,
) -> SparseDiaImpl:
    """Return ``a + b`` (or ``a - b``) in SparseDIA form.

    The output carries the sorted union of both operands' offsets, which is
    computed with plain Python sets because offsets are static metadata.
    Each output diagonal is the sum of whichever operands store that offset.

    Args:
        a: Left operand.
        b: Right operand.
        subtract: When True, ``b``'s diagonals are negated before summing.

    Returns:
        A ``SparseDiaImpl`` whose batch shape is the broadcast of both
        operands' batch shapes and whose dtype is their promoted dtype.
    """
    width = a._diags.shape[-1]
    batch = jnp.broadcast_shapes(a._diags.shape[:-2], b._diags.shape[:-2])
    merged = tuple(sorted({*a._offsets, *b._offsets}))
    dtype = jnp.result_type(a._diags.dtype, b._diags.dtype)

    # Offset → row position inside each operand's diagonal array.
    row_in_a = dict(zip(a._offsets, range(len(a._offsets))))
    row_in_b = dict(zip(b._offsets, range(len(b._offsets))))
    sign = -1 if subtract else 1

    result = jnp.zeros((*batch, len(merged), width), dtype=dtype)
    for slot, k in enumerate(merged):
        # Start from zeros so the row already has the broadcast batch shape.
        row = jnp.zeros((*batch, width), dtype=dtype)
        ia = row_in_a.get(k)
        if ia is not None:
            row = row + a._diags[..., ia, :]
        ib = row_in_b.get(k)
        if ib is not None:
            row = row + sign * b._diags[..., ib, :]
        result = result.at[..., slot, :].set(row)

    return SparseDiaImpl(_offsets=merged, _diags=result)

515 

516 

def _sparsedia_matmul_dense(
    offsets: tuple,
    diags: Array,
    B: Array,
) -> Array:
    """Multiply a SparseDIA operator by a dense matrix without densifying it.

    With ``A[r, r+k] = diags[i, r+k]``, every stored diagonal contributes

        out[c - k, :] += diags[i, c] * B[c, :]

    for each populated column ``c`` of that diagonal; this is one aligned
    slice multiply per diagonal.

    Complexity: O(d * n * m) for d diagonals, an n×n operator and n×m ``B``.

    Args:
        offsets: Static tuple of diagonal offsets for the LHS.
        diags: JAX array of shape (*batch, n_diags, n).
        B: Dense right-hand side of shape (*batch, n, m).

    Returns:
        Dense product of shape (*batch, n, m).
    """
    out_batch = jnp.broadcast_shapes(diags.shape[:-2], B.shape[:-2])
    n_rows = B.shape[-2]
    n_cols = B.shape[-1]
    out_dtype = jnp.result_type(diags.dtype, B.dtype)
    acc = jnp.zeros((*out_batch, n_rows, n_cols), dtype=out_dtype)

    for d, offset in enumerate(offsets):
        src = _dia_slice(offset)    # populated columns of A == rows of B read
        dst = _dia_slice(-offset)   # rows of the result written
        scaled_rows = diags[..., d, src, None] * B[..., src, :]
        acc = acc.at[..., dst, :].add(scaled_rows)
    return acc

548 

549 

def _sparsedia_rmatmul_dense(
    B: Array,
    offsets: tuple,
    diags: Array,
) -> Array:
    """Multiply a dense matrix by a SparseDIA operator: ``B @ A`` → dense.

    With ``A[c - k, c] = diags[i, c]``, every stored diagonal contributes

        out[:, c] += B[:, c - k] * diags[i, c]

    for each populated column ``c`` of that diagonal.  In slice form:

        k ≥ 0:  out[:, k:]    += B[:, :n-k]  * diags[i, k:]
        k < 0:  out[:, :n+k]  += B[:, -k:]   * diags[i, :n+k]

    Complexity: O(d * n * p) for d diagonals, an n×n operator and p×n ``B``.

    Args:
        B: Dense left-hand side of shape (*batch, p, n).
        offsets: Static tuple of diagonal offsets for the RHS.
        diags: JAX array of shape (*batch, n_diags, n).

    Returns:
        Dense product of shape (*batch, p, n).
    """
    out_batch = jnp.broadcast_shapes(diags.shape[:-2], B.shape[:-2])
    n_rows = B.shape[-2]
    n_cols = B.shape[-1]
    out_dtype = jnp.result_type(diags.dtype, B.dtype)
    acc = jnp.zeros((*out_batch, n_rows, n_cols), dtype=out_dtype)

    for d, offset in enumerate(offsets):
        dst = _dia_slice(offset)    # populated columns of A == result columns
        src = _dia_slice(-offset)   # columns of B that pair with them
        weights = diags[..., d, dst][..., None, :]
        acc = acc.at[..., :, dst].add(B[..., :, src] * weights)
    return acc

582 

583 

def _sparsedia_matmul_sparsedia(
    left_offsets: tuple,
    left_diags: Array,
    right_offsets: tuple,
    right_diags: Array,
) -> tuple[tuple, Array]:
    """Multiply two SparseDIA operators, returning SparseDIA data.

    From the access formula ``A[i, i+k] = diags[i+k]``, the diagonal pair
    (k1, k2) feeds the output diagonal ``kout = k1 + k2`` through

        out_diag[j + k2] += left_diag[j] * right_diag[j + k2]

    which is an aligned slice multiply selected by the sign of k2:

        k2 ≥ 0:  out[k2:]    += left[:n-k2] * right[k2:]
        k2 < 0:  out[:n+k2]  += left[-k2:]  * right[:n+k2]

    Pairs with ``|k1 + k2| >= n`` fall outside the matrix and are skipped.
    If nothing survives, a single zero main diagonal is returned.

    Complexity: O(d1 * d2 * n).

    Args:
        left_offsets: Static offsets for the LHS matrix.
        left_diags: JAX array (*batch, d1, n).
        right_offsets: Static offsets for the RHS matrix.
        right_diags: JAX array (*batch, d2, n).

    Returns:
        Tuple of (out_offsets, out_diags) with ``out_offsets`` ascending.
    """
    n = left_diags.shape[-1]
    batch = jnp.broadcast_shapes(left_diags.shape[:-2], right_diags.shape[:-2])

    # Enumerate the contributing pairs once (left-major order); the same
    # list drives both the offset discovery and the accumulation.
    pairs = [
        (li, ri, k1 + k2, k2)
        for li, k1 in enumerate(left_offsets)
        for ri, k2 in enumerate(right_offsets)
        if abs(k1 + k2) < n
    ]
    kept = sorted({kout for _, _, kout, _ in pairs}) or [0]
    slot_of = {k: i for i, k in enumerate(kept)}

    acc = jnp.zeros(
        (*batch, len(kept), n),
        dtype=jnp.result_type(left_diags.dtype, right_diags.dtype),
    )
    for li, ri, kout, k2 in pairs:
        right_cols = _dia_slice(k2)    # populated columns of the right diagonal
        left_cols = _dia_slice(-k2)    # aligned window of the left diagonal
        product = left_diags[..., li, left_cols] * right_diags[..., ri, right_cols]
        acc = acc.at[..., slot_of[kout], right_cols].add(product)

    return tuple(kept), acc

641 

642 

def _sparsedia_kron(a: SparseDiaImpl, b: SparseDiaImpl) -> SparseDiaImpl:
    """Kronecker product ``a ⊗ b`` kept in SparseDIA form.

    For A of size n_A × n_A and B of size m × m the result has dimension
    N = n_A * m, and the diagonal pair (kA, kB) lands on output offset
    ``kA * m + kB``.  Because storage is indexed by column, the full-length
    output diagonal is simply::

        out[s] = left_padded[s // m] * right_padded[s % m]

    i.e. ``jnp.repeat(left_padded, m) * jnp.tile(right_padded, n_A)``, so no
    scatter or dynamic indexing is needed.  Distinct pairs may map to the
    same output offset; their contributions are summed.

    Complexity: O(d_A * d_B * N).

    Args:
        a: Left SparseDiaImpl of shape (*batch, n_A, n_A).
        b: Right SparseDiaImpl of shape (*batch, m, m).

    Returns:
        SparseDiaImpl of shape (*batch, N, N) with ascending offsets.
    """
    n_A = a._diags.shape[-1]
    m = b._diags.shape[-1]

    # Each right diagonal tiled out to length N; reused for every left diagonal.
    tiled_right = [
        jnp.tile(b._diags[..., ri, :], n_A) for ri in range(len(b._offsets))
    ]

    by_offset: dict[int, Array] = {}
    for li, kA in enumerate(a._offsets):
        # Left diagonal stretched to length N: each entry repeated m times.
        stretched = jnp.repeat(a._diags[..., li, :], m, axis=-1)
        for kB, right in zip(b._offsets, tiled_right):
            kout = kA * m + kB
            term = stretched * right
            previous = by_offset.get(kout)
            by_offset[kout] = term if previous is None else previous + term

    out_offsets = tuple(sorted(by_offset))
    out_diags = jnp.stack([by_offset[k] for k in out_offsets], axis=-2)
    return SparseDiaImpl(_offsets=out_offsets, _diags=out_diags)

692 

693 

694# --------------------------------------------------------------------------- 

695# Register with the QarrayImplType enum 

696# --------------------------------------------------------------------------- 

697 

# Associate this backend class with the SPARSE_DIA member of the enum.
QarrayImplType.register(SparseDiaImpl, QarrayImplType.SPARSE_DIA)


# ---------------------------------------------------------------------------
# Public exports
# ---------------------------------------------------------------------------

__all__ = ["SparseDiaData", "SparseDiaImpl"]