sparse_dia
Sparse diagonal (SparseDIA) backend for Qarray.
Stores only the values on a matrix's non-zero diagonals, making quantum operators with few non-zero diagonals (annihilation, creation, number, Kerr…) far cheaper than Dense or BCOO:
- Memory: O(d * n) where d = number of stored diagonals, n = matrix size
- No index arrays (unlike BCOO which stores (row, col) per non-zero)
_offsets is static Python metadata (pytree_node=False), so Python loops over the diagonals are unrolled at trace time and the compiled program uses only static slices, with no dynamic indexing or scatter/gather.
Padding convention (Convention A), where diag_idx is the position of offset k in offsets:
For offset k ≥ 0: diags[..., diag_idx, j] = A[j-k, j] for j ∈ [k, n-1]; zeros at [0:k]
For offset k < 0: diags[..., diag_idx, j] = A[j-k, j] for j ∈ [0, n+k-1]; zeros at [n+k:]
Unified access formula (holds for any k, out-of-range slots are zero): A[i, i+k] = diags[..., diag_idx, i+k]
This lets most matrix operations be written as sets of aligned slice multiplications.
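To make the convention concrete, here is a minimal NumPy sketch of packing a dense matrix into Convention-A rows. The helper name pack_dia is hypothetical and not part of jaxquantum; it relies only on np.diagonal, which lists the entries of offset k in column order, so each diagonal lands in its valid column range [max(0, k), min(n, n+k)).

```python
import numpy as np


def pack_dia(a, offsets):
    """Pack a dense (n, n) matrix into Convention-A padded diagonals.

    Hypothetical helper for illustration: row diag_idx holds offset
    k = offsets[diag_idx] so that a[i, i+k] == diags[diag_idx, i+k];
    slots with no matrix entry stay zero.
    """
    n = a.shape[-1]
    diags = np.zeros((len(offsets), n), dtype=a.dtype)
    for diag_idx, k in enumerate(offsets):
        # Columns j that hold an entry of offset k: 0 <= j - k < n.
        lo, hi = max(0, k), min(n, n + k)
        # np.diagonal(a, k) walks exactly those columns, left to right.
        diags[diag_idx, lo:hi] = np.diagonal(a, k)
    return diags


a = np.array([[1.0, 2.0, 0.0],
              [3.0, 4.0, 5.0],
              [0.0, 6.0, 7.0]])
diags = pack_dia(a, (-1, 0, 1))
# diags[0] == [3, 6, 0]  offset -1: valid at [0:n-1], padding at the end
# diags[1] == [1, 4, 7]  offset  0: no padding
# diags[2] == [0, 2, 5]  offset +1: padding at [0:1]
```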
Some improvements (_dia_slice helper, integer matrix power, diagonal-range pruning, offset detection) were identified by studying the dynamiqs library (https://github.com/dynamiqs/dynamiqs).
SparseDiaData
Lightweight pytree-compatible container for sparse-diagonal raw data.
Returned by SparseDiaImpl.get_data() and consumed by
SparseDiaImpl.from_data(). Registered as a JAX pytree via Flax's
@struct.dataclass; offsets is not a pytree leaf (it is static
compile-time metadata).
Attributes:
| Name | Type | Description |
|---|---|---|
| offsets | tuple | Static tuple of diagonal offsets (pytree_node=False). |
| diags | Array | JAX array of shape (*batch, n_diags, n) containing the padded diagonal values. |
Source code in jaxquantum/core/sparse_dia.py, lines 65–132
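The static/traced split can be sketched without Flax. The snippet below swaps in jax.tree_util.register_pytree_node_class in place of Flax's @struct.dataclass, and the class DiaData is a hypothetical stand-in for SparseDiaData: diags is the only leaf, while offsets travels as auxiliary data, which jit treats as part of its static cache key.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class DiaData:
    """Hypothetical stand-in for SparseDiaData (not the jaxquantum class)."""

    def __init__(self, offsets, diags):
        self.offsets = tuple(offsets)  # static metadata, never traced
        self.diags = diags             # traced leaf, shape (*batch, n_diags, n)

    def tree_flatten(self):
        # (children, aux_data): only diags is a leaf; offsets rides along as aux data.
        return (self.diags,), self.offsets

    @classmethod
    def tree_unflatten(cls, offsets, children):
        return cls(offsets, children[0])


@jax.jit
def scale(data, s):
    # data.offsets is a plain Python tuple here, so loops over it unroll at trace time.
    return DiaData(data.offsets, data.diags * s)


x = DiaData((0, 1), jnp.ones((2, 4)))
y = scale(x, 2.0)
print(y.offsets, jax.tree_util.tree_leaves(x))
```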
dtype
property
Dtype of the stored diagonal values.
shape
property
Shape of the represented square matrix (*batch, n, n).
__getitem__(index)
Index into the batch dimension(s), preserving offsets.
Source code in jaxquantum/core/sparse_dia.py, lines 105–107
__len__()
Number of elements along the leading batch dimension.
Source code in jaxquantum/core/sparse_dia.py, lines 109–111
__matmul__(other)
SparseDIA @ dense → dense (used by mesolve ODE RHS).
Source code in jaxquantum/core/sparse_dia.py, lines 124–128
__rmatmul__(other)
dense @ SparseDIA → dense (used by mesolve ODE RHS).
Source code in jaxquantum/core/sparse_dia.py, lines 130–132
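Both mixed products can be written with the unified formula alone. In this minimal NumPy sketch (helper names are hypothetical), column j of A for offset k lies in [max(0, k), min(n, n+k)) and pairs with row i = j - k, so each diagonal contributes one aligned slice product. Porting it to jax.numpy would require replacing the in-place += with a functional update.

```python
import numpy as np


def dia_matmul_dense(offsets, diags, x):
    """A @ X for A in Convention-A form and dense X of shape (n, m)."""
    n = diags.shape[-1]
    out = np.zeros((n, x.shape[-1]), dtype=np.result_type(diags, x))
    for d, k in enumerate(offsets):
        lo, hi = max(0, k), min(n, n + k)  # valid columns j of A for offset k
        # Row i = j - k picks up A[i, j] * X[j, :] = diags[d, j] * X[j, :].
        out[lo - k:hi - k] += diags[d, lo:hi, None] * x[lo:hi]
    return out


def dense_matmul_dia(x, offsets, diags):
    """X @ A for dense X of shape (m, n) and A in Convention-A form."""
    n = diags.shape[-1]
    out = np.zeros((x.shape[0], n), dtype=np.result_type(diags, x))
    for d, k in enumerate(offsets):
        lo, hi = max(0, k), min(n, n + k)
        # Column j picks up X[:, j - k] * A[j - k, j] = X[:, j - k] * diags[d, j].
        out[:, lo:hi] += x[:, lo - k:hi - k] * diags[d, lo:hi]
    return out


a = np.array([[1.0, 2.0, 0.0], [3.0, 4.0, 5.0], [0.0, 6.0, 7.0]])
offsets = (-1, 0, 1)
diags = np.array([[3.0, 6.0, 0.0], [1.0, 4.0, 7.0], [0.0, 2.0, 5.0]])  # Convention-A rows of a
rho = np.arange(9.0).reshape(3, 3)
# Commutator-style term H rho - rho H, as in a master-equation right-hand side.
comm = dia_matmul_dense(offsets, diags, rho) - dense_matmul_dia(rho, offsets, diags)
print(np.allclose(comm, a @ rho - rho @ a))  # True
```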
reshape(*new_shape)
Reshape batch dimensions while preserving diagonal structure.
new_shape must end with (N, N) (the matrix dims are unchanged).
Only the leading batch dims are reshaped.
Source code in jaxquantum/core/sparse_dia.py, lines 113–122
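Because only the batch prefix changes, reshaping amounts to swapping the trailing (n, n) of the requested shape for the stored (n_diags, n). A sketch with a hypothetical reshape_dia helper on a NumPy array of padded rows:

```python
import numpy as np


def reshape_dia(diags, *new_shape):
    """Hypothetical helper: reshape the batch dims of (*batch, n_diags, n) rows.

    new_shape uses the dense convention (*new_batch, n, n); the trailing (n, n)
    is replaced by the stored (n_diags, n) and only the batch prefix changes.
    """
    n_diags, n = diags.shape[-2:]
    if tuple(new_shape[-2:]) != (n, n):
        raise ValueError(f"new_shape must end with ({n}, {n})")
    return diags.reshape(*new_shape[:-2], n_diags, n)


batch = np.zeros((6, 3, 5))  # 6 operators, 3 stored diagonals, n = 5
print(reshape_dia(batch, 2, 3, 5, 5).shape)  # (2, 3, 3, 5)
```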
SparseDiaImpl
Bases: QarrayImpl
Sparse-diagonal backend storing only diagonal values.
Data layout:
_offsets : tuple[int, ...], static (pytree_node=False)
_diags : Array[*batch, n_diags, n], JAX-traced values
For offset k stored at position i of _offsets, the convention is:
* k ≥ 0 : valid data at _diags[..., i, k:], zeros at [0:k]
* k < 0 : valid data at _diags[..., i, :n+k], zeros at [n+k:]
In both cases: A[row, row+k] = _diags[..., i, row+k]
Source code in jaxquantum/core/sparse_dia.py, lines 182–478
add(other)
Element-wise addition.
SparseDIA + SparseDIA stays SparseDIA (union of offsets, static). Otherwise coerces to the higher-order type.
Source code in jaxquantum/core/sparse_dia.py, lines 280–291
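Addition stays cheap because two rows with the same offset share the same padding layout, so they can be summed slot by slot. A minimal NumPy sketch for unbatched operands (the dia_add name and the unbatched restriction are assumptions of this example):

```python
import numpy as np


def dia_add(offsets_a, diags_a, offsets_b, diags_b):
    """A + B for two unbatched Convention-A matrices of the same size n.

    Output offsets are the sorted union of both inputs (static metadata);
    rows that share an offset are added directly.
    """
    out_offsets = tuple(sorted(set(offsets_a) | set(offsets_b)))
    pos = {k: i for i, k in enumerate(out_offsets)}
    n = diags_a.shape[-1]
    out = np.zeros((len(out_offsets), n), dtype=np.result_type(diags_a, diags_b))
    for k, row in zip(offsets_a, diags_a):
        out[pos[k]] += row
    for k, row in zip(offsets_b, diags_b):
        out[pos[k]] += row
    return out_offsets, out


number = ((0,), np.array([[0.0, 1.0, 2.0]]))
destroy = ((1,), np.sqrt(np.arange(3.0))[None, :])
offsets, diags = dia_add(*number, *destroy)
print(offsets)  # (0, 1)
```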
can_handle_data(arr)
classmethod
Return True only for SparseDiaData objects.
Source code in jaxquantum/core/sparse_dia.py, lines 404–407
conj()
Element-wise complex conjugate of stored values.
Source code in jaxquantum/core/sparse_dia.py, lines 449–451
dag()
Conjugate transpose without densification.
Negates every offset and rearranges the stored values so that the padding convention remains consistent.
Source code in jaxquantum/core/sparse_dia.py, lines 324–336
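For the adjoint, B = A^† satisfies B[c+k, c] = conj(A[c, c+k]), so under the unified formula the row for offset -k must hold conj(diags[..., c+k]) at slot c: a conjugated shift by k. The hypothetical dia_dag helper below is a NumPy sketch of that rearrangement; it keeps rows in input order, so the returned offsets are not re-sorted.

```python
import numpy as np


def dia_dag(offsets, diags):
    """Conjugate transpose of unbatched Convention-A rows without densifying."""
    n = diags.shape[-1]
    new_offsets = tuple(-k for k in offsets)
    new = np.zeros_like(diags)
    for d, k in enumerate(offsets):
        # Valid columns c for the new offset -k; they read old slots c + k.
        lo, hi = max(0, -k), min(n, n - k)
        new[d, lo:hi] = np.conj(diags[d, lo + k:hi + k])
    return new_offsets, new


diags = np.array([[3, 6, 0], [1, 4, 7], [0, 2j, 5]])  # offsets (-1, 0, 1)
offsets_dag, diags_dag = dia_dag((-1, 0, 1), diags)
# returns offsets (1, 0, -1) with rows [0, 3, 6], [1, 4, 7], [-2j, 5, 0]
```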
dag_data(arr)
classmethod
Conjugate transpose of raw SparseDiaData without densification.
Source code in jaxquantum/core/sparse_dia.py, lines 409–414
dtype()
Dtype of the stored diagonal values.
Source code in jaxquantum/core/sparse_dia.py, lines 258–260
frobenius_norm()
Frobenius norm computed directly from stored diagonal values.
Source code in jaxquantum/core/sparse_dia.py, lines 431–433
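Each matrix entry occupies exactly one stored slot and padding slots are zero, so the Frobenius norm is simply the 2-norm of the stored values. A one-function NumPy sketch (the dia_frobenius name is hypothetical; it assumes the offsets are distinct):

```python
import numpy as np


def dia_frobenius(diags):
    """Frobenius norm from Convention-A rows of shape (*batch, n_diags, n)."""
    return np.sqrt(np.sum(np.abs(diags) ** 2, axis=(-2, -1)))


a = np.array([[1.0, 2.0, 0.0], [3.0, 4.0, 5.0], [0.0, 6.0, 7.0]])
diags = np.array([[3.0, 6.0, 0.0], [1.0, 4.0, 7.0], [0.0, 2.0, 5.0]])  # offsets (-1, 0, 1)
print(np.isclose(dia_frobenius(diags), np.linalg.norm(a)))  # True
```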
from_data(data)
classmethod
Wrap data in a new SparseDiaImpl.
Accepts either a SparseDiaData container (direct wrap) or
a dense array-like (auto-detect non-zero diagonals via numpy, safe
to call before JIT).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | | A SparseDiaData container or a dense array-like. | required |
Returns:
| Type | Description |
|---|---|
| 'SparseDiaImpl' | A new SparseDiaImpl instance. |
Source code in jaxquantum/core/sparse_dia.py, lines 207–225
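Auto-detection has to yield offsets as Python integers, since they become static metadata, which is presumably why this path runs in NumPy on concrete values rather than under jit. A sketch of the detection step, using a hypothetical detect_offsets helper with an atol threshold added for illustration:

```python
import numpy as np


def detect_offsets(a, atol=0.0):
    """Sorted offsets k whose diagonal has any |entry| > atol.

    Works on concrete (non-traced) arrays; batch dims are reduced with any().
    """
    a = np.asarray(a)
    n = a.shape[-1]
    found = []
    for k in range(-(n - 1), n):
        diag = np.diagonal(a, offset=k, axis1=-2, axis2=-1)
        if np.any(np.abs(diag) > atol):
            found.append(k)
    return tuple(found)


a = np.array([[1.0, 2.0, 0.0], [3.0, 4.0, 5.0], [0.0, 6.0, 7.0]])
print(detect_offsets(a))  # (-1, 0, 1)
```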
from_diags(offsets, diags)
classmethod
Directly construct from offsets and a padded diagonal array.
This is the preferred factory when diagonal structure is known in
advance (e.g., when building destroy or create operators).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| offsets | tuple | Tuple of integer diagonal offsets (need not be sorted; will be sorted internally). | required |
| diags | Array | JAX array of shape (*batch, n_diags, n) with padded diagonal values matching offsets. | required |
Returns:
| Type | Description |
|---|---|
| 'SparseDiaImpl' | A new SparseDiaImpl instance. |
Source code in jaxquantum/core/sparse_dia.py, lines 227–243
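As an illustration of from_diags-style input, the padded rows for the truncated annihilation and creation operators follow directly from Convention A. These NumPy helpers are hypothetical sketches that return (offsets, diags) pairs rather than calling the jaxquantum API; sort_offsets shows that sorting offsets must permute the diagonal rows in the same way.

```python
import numpy as np


def destroy_dia(n):
    # a[i, i+1] = sqrt(i+1). Offset +1 stores a[j-1, j] = sqrt(j) at slot j;
    # slot 0 is padding, and sqrt(0) = 0 fills it for free.
    return (1,), np.sqrt(np.arange(n, dtype=float))[None, :]


def create_dia(n):
    # adag[i+1, i] = sqrt(i+1). Offset -1 stores adag[j+1, j] = sqrt(j+1) at
    # slot j for j < n-1; the last slot is padding and is zeroed explicitly.
    row = np.sqrt(np.arange(1, n + 1, dtype=float))
    row[-1] = 0.0
    return (-1,), row[None, :]


def sort_offsets(offsets, diags):
    # Sorting the offsets must permute the diagonal rows in the same way.
    order = np.argsort(offsets)
    return tuple(int(offsets[i]) for i in order), diags[..., order, :]


_, a_rows = destroy_dia(4)
_, adag_rows = create_dia(4)
offsets, diags = sort_offsets((1, -1), np.concatenate([a_rows, adag_rows]))
print(offsets)  # (-1, 1)
```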
get_data()
Return a SparseDiaData container with the raw diagonal data.
Source code in jaxquantum/core/sparse_dia.py, lines 249–251
imag()
Element-wise imaginary part of stored values.
Source code in jaxquantum/core/sparse_dia.py, lines 442–447
kron(other)
Kronecker product.
SparseDIA ⊗ SparseDIA stays SparseDIA: output offset for pair
(kA, kB) is kA * m + kB where m = dim(B). Fully vectorised —
no loops at JAX level.
Source code in jaxquantum/core/sparse_dia.py, lines 338–350
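The kron offset rule can be checked with a short NumPy sketch. For the pair (kA, kB), the padded output row is np.kron(rowA, rowB): output column c = jA*m + jB receives rowA[jA] * rowB[jB], and padding zeros in either factor zero every invalid output slot. Distinct pairs can map to the same output offset (for example (1, -1) and (0, 1) when m = 2), so this hypothetical dia_kron sums such rows; the Python dict merge is an assumption of the sketch, not a description of how jaxquantum vectorises it.

```python
import numpy as np


def dia_kron(offsets_a, diags_a, offsets_b, diags_b):
    """A ⊗ B for unbatched Convention-A rows; A is n x n, B is m x m."""
    m = diags_b.shape[-1]
    # Every row pair at once: (dA, dB, n, m) -> (dA * dB, n * m), i.e. np.kron per pair.
    rows = (diags_a[:, None, :, None] * diags_b[None, :, None, :]).reshape(
        len(offsets_a) * len(offsets_b), -1
    )
    pair_offsets = [ka * m + kb for ka in offsets_a for kb in offsets_b]
    # Rows that land on the same output offset share one layout, so sum them.
    merged = {}
    for k, row in zip(pair_offsets, rows):
        merged[k] = merged.get(k, 0) + row
    out_offsets = tuple(sorted(merged))
    return out_offsets, np.stack([merged[k] for k in out_offsets])


# destroy on 2 levels (offset 1) ⊗ number on 3 levels (offset 0)
offsets, diags = dia_kron((1,), np.array([[0.0, 1.0]]), (0,), np.array([[0.0, 1.0, 2.0]]))
print(offsets)  # (3,)
# diags[0] == [0, 0, 0, 0, 1, 2]
```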
matmul(other)
Matrix multiplication.
- SparseDIA @ SparseDIA → SparseDIA (O(d₁·d₂·n))
- SparseDIA @ Dense → Dense (O(d·n²), no densification of self)
- Others → coerce then delegate
Source code in jaxquantum/core/sparse_dia.py, lines 302–322
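For the SparseDIA @ SparseDIA case, C[i, i+kA+kB] picks up A[i, i+kA] * B[i+kA, i+kA+kB]. With c = i+kA+kB as the output column, the unified formula turns this into diagsA[c - kB] * diagsB[c], that is, A's row shifted by kB times B's row, which gives the O(d₁·d₂·n) cost listed above. A NumPy sketch with hypothetical helper names; it drops product offsets with |k| ≥ n but does no further pruning.

```python
import numpy as np


def _shift(row, s):
    """Return r with r[c] = row[c - s] where 0 <= c - s < n, else 0."""
    n = row.shape[-1]
    out = np.zeros_like(row)
    if abs(s) >= n:
        return out
    if s >= 0:
        out[s:] = row[:n - s]
    else:
        out[:n + s] = row[-s:]
    return out


def dia_matmul(offsets_a, diags_a, offsets_b, diags_b):
    """C = A @ B with both operands as unbatched Convention-A rows."""
    n = diags_a.shape[-1]
    merged = {}
    for ka, ra in zip(offsets_a, diags_a):
        for kb, rb in zip(offsets_b, diags_b):
            k = ka + kb
            if abs(k) >= n:  # this diagonal lies entirely outside the matrix
                continue
            merged[k] = merged.get(k, 0) + _shift(ra, kb) * rb
    if not merged:
        return (), np.zeros((0, n), dtype=np.result_type(diags_a, diags_b))
    out_offsets = tuple(sorted(merged))
    return out_offsets, np.stack([merged[k] for k in out_offsets])


a_rows = np.sqrt(np.arange(4.0))[None, :]                # destroy, offset +1
adag_rows = np.array([[1.0, np.sqrt(2.0), np.sqrt(3.0), 0.0]])  # create, offset -1
offsets, diags = dia_matmul((-1,), adag_rows, (1,), a_rows)    # adag @ a = number
print(offsets, np.allclose(diags, [[0.0, 1.0, 2.0, 3.0]]))    # (0,) True
```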
mul(scalar)
Scalar multiplication — scales all diagonal values.
Source code in jaxquantum/core/sparse_dia.py, lines 272–274
neg()
Negation.
Source code in jaxquantum/core/sparse_dia.py, lines 276–278
powm(n)
Integer matrix power that stays in SparseDIA form, computed by binary exponentiation.
Uses O(log n) SparseDIA @ SparseDIA multiplications rather than densifying. A^0 returns the identity operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Non-negative integer exponent. | required |
Returns:
| Type | Description |
|---|---|
| 'SparseDiaImpl' | A SparseDiaImpl holding A^n. |
Raises:
| Type | Description |
|---|---|
| ValueError | If n is negative. |
Source code in jaxquantum/core/sparse_dia.py, lines 453–478
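The binary-exponentiation loop itself does not depend on the storage format. The sketch below takes the product and the identity as arguments (a hypothetical signature chosen for illustration) and makes O(log n) calls to matmul; passing a SparseDIA product would keep every intermediate sparse, and n = 0 returns the supplied identity.

```python
import numpy as np


def powm_binary(a, n, matmul, identity):
    """Compute a**n with O(log n) calls to matmul (binary exponentiation)."""
    if n < 0:
        raise ValueError("n must be non-negative")
    result = identity  # a**0
    base = a
    while n > 0:
        if n & 1:      # lowest remaining bit set: fold the current power in
            result = matmul(result, base)
        n >>= 1
        if n:          # square only while more bits remain
            base = matmul(base, base)
    return result


a = np.array([[1.0, 1.0], [0.0, 1.0]])
print(powm_binary(a, 5, np.matmul, np.eye(2)))  # equals [[1, 5], [0, 1]]
```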
real()
Element-wise real part of stored values.
Source code in jaxquantum/core/sparse_dia.py, lines 435–440
shape()
Shape of the represented square matrix (including batch dims).
Source code in jaxquantum/core/sparse_dia.py, lines 253–256
sub(other)
Element-wise subtraction.
Source code in jaxquantum/core/sparse_dia.py, lines 293–300
tidy_up(atol)
Zero diagonal values whose magnitude is below atol.
Source code in jaxquantum/core/sparse_dia.py, lines 352–361
to_dense()
Convert to a DenseImpl by summing diagonal contributions.
Source code in jaxquantum/core/sparse_dia.py, lines 367–381
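Going the other way, the dense matrix is rebuilt by writing each valid slot diags[..., d, j] back to position (j - k, j) and summing over diagonals. A NumPy sketch with a hypothetical dia_to_dense helper that also handles batch dimensions:

```python
import numpy as np


def dia_to_dense(offsets, diags):
    """Rebuild the dense (*batch, n, n) matrix from Convention-A rows."""
    n = diags.shape[-1]
    out = np.zeros(diags.shape[:-2] + (n, n), dtype=diags.dtype)
    for d, k in enumerate(offsets):
        lo, hi = max(0, k), min(n, n + k)
        cols = np.arange(lo, hi)
        # A[j - k, j] = diags[..., d, j] for every valid column j.
        out[..., cols - k, cols] += diags[..., d, lo:hi]
    return out


diags = np.array([[3.0, 6.0, 0.0], [1.0, 4.0, 7.0], [0.0, 2.0, 5.0]])
dense = dia_to_dense((-1, 0, 1), diags)  # [[1, 2, 0], [3, 4, 5], [0, 6, 7]]
```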
to_sparse_bcoo()
Convert to a SparseBCOOImpl (BCOO) via dense.
Source code in jaxquantum/core/sparse_dia.py, lines 383–385
to_sparse_dia()
Return self (already SparseDIA).
Source code in jaxquantum/core/sparse_dia.py, lines 387–389
trace()
Compute trace directly from the main diagonal (offset 0).
Returns:
| Type | Description |
|---|---|
| | Scalar trace (sum of main diagonal values). |
Source code in jaxquantum/core/sparse_dia.py, lines 420–429
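Only offset 0 contributes to the trace, and its row carries no padding, so the trace is the sum of that row (and zero when offset 0 is not stored). A NumPy sketch with a hypothetical dia_trace helper:

```python
import numpy as np


def dia_trace(offsets, diags):
    """Trace from Convention-A rows of shape (*batch, n_diags, n)."""
    if 0 not in offsets:
        # No stored main diagonal means every A[i, i] is zero.
        return np.zeros(diags.shape[:-2], dtype=diags.dtype)
    return np.sum(diags[..., offsets.index(0), :], axis=-1)


diags = np.array([[3.0, 6.0, 0.0], [1.0, 4.0, 7.0], [0.0, 2.0, 5.0]])
print(dia_trace((-1, 0, 1), diags))  # 12.0
```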