Coverage for jaxquantum / core / sparse_bcoo.py: 92%

149 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 22:49 +0000

1"""Sparse BCOO backend for Qarray. 

2 

3Implements the JAX experimental BCOO sparse format as a Qarray storage backend. 

4""" 

5 

6from __future__ import annotations 

7 

8from flax import struct 

9from jax import Array 

10from copy import deepcopy 

11import jax.numpy as jnp 

12from jax.experimental import sparse 

13 

14# QarrayImpl and friends are imported below (after qarray.py is fully loaded) 

15# to match the pattern used by sparse_dia.py and avoid circular imports. 

16from jaxquantum.core.qarray import QarrayImpl, DenseImpl, QarrayImplType # noqa: E402 

17 

18 

@struct.dataclass
class SparseBCOOImpl(QarrayImpl):
    """Sparse implementation using JAX experimental BCOO sparse arrays.

    The methods of this class operate on the stored ``data`` / ``indices`` of
    the BCOO array; :meth:`to_dense` is the only one that calls ``todense()``
    itself.

    ``trace``, ``keep_only_diag`` and ``l2_norm_batched`` index ``indices`` as
    a two-dimensional ``(nse, n_sparse)`` array, i.e. they expect a BCOO
    without BCOO-level batch dimensions.

    Attributes:
        _data: The underlying ``sparse.BCOO`` array.
    """

    _data: sparse.BCOO

    PROMOTION_ORDER = 1  # noqa: RUF012 — not a struct field; no annotation intentional

    # ------------------------------------------------------------------
    # Construction / raw access
    # ------------------------------------------------------------------

    @classmethod
    def from_data(cls, data) -> "SparseBCOOImpl":
        """Wrap *data* in a new ``SparseBCOOImpl``, converting to BCOO if needed.

        Args:
            data: A ``sparse.BCOO`` or array-like input.

        Returns:
            A ``SparseBCOOImpl`` wrapping a BCOO representation of *data*.
        """
        return cls(_data=cls._to_sparse(data))

    def get_data(self) -> sparse.BCOO:
        """Return the underlying BCOO sparse array."""
        return self._data

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _match_index_dtype(reference: sparse.BCOO, other: sparse.BCOO) -> sparse.BCOO:
        """Return *other* with its index dtype aligned to *reference*.

        Used by ``add`` / ``sub`` before combining two BCOO arrays
        (presumably the sparse sum needs both index arrays in one dtype).

        Args:
            reference: BCOO whose ``indices.dtype`` is the target dtype.
            other: BCOO to align.

        Returns:
            *other* unchanged when the dtypes already agree, otherwise a new
            BCOO sharing ``other.data`` with re-cast indices.
        """
        if reference.indices.dtype == other.indices.dtype:
            return other
        return sparse.BCOO(
            (other.data, other.indices.astype(reference.indices.dtype)),
            shape=other.shape,
        )

    @staticmethod
    def _flat_batch_index(batch_indices, batch_shape):
        """Flatten per-entry batch coordinates into one row-major index.

        The strides and the total batch count are computed with plain Python
        integers. Going through ``jnp.prod`` and ``int(...)`` instead would
        turn the count into a traced value under ``jax.jit`` and fail to
        convert, and would cost a device round-trip otherwise.

        Args:
            batch_indices: Integer array of shape ``(nse, len(batch_shape))``
                holding the batch coordinates of every stored entry.
            batch_shape: Sizes of the batch dimensions (static integers).

        Returns:
            A ``(flat_index, n_batches)`` pair: the flattened batch index of
            every stored entry and the product of ``batch_shape`` as an
            ``int``.
        """
        strides = [1]
        for size in reversed(batch_shape[1:]):
            strides.insert(0, strides[0] * int(size))
        # strides[0] is the product of all sizes but the first one.
        n_batches = strides[0] * int(batch_shape[0]) if len(batch_shape) else 1
        stride_arr = jnp.array(strides, dtype=jnp.int32)
        flat_index = jnp.sum(batch_indices * stride_arr, axis=-1)
        return flat_index, n_batches

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def matmul(self, other: QarrayImpl) -> QarrayImpl:
        """Matrix multiply ``self @ other``.

        When *other* is a ``DenseImpl``, JAX's native BCOO @ dense path is
        used (no self-densification). When *other* is also a
        ``SparseBCOOImpl``, a sparse @ sparse product is performed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``DenseImpl`` (sparse @ dense) or ``SparseBCOOImpl`` (sparse @
            sparse) containing the matrix product.
        """
        if isinstance(other, DenseImpl):
            return DenseImpl(self._data @ other._data)
        a, b = self._coerce(other)
        if a is not self:
            # Coercion replaced the left operand; let its implementation
            # carry out the product.
            return a.matmul(b)
        return SparseBCOOImpl(self._data @ b._data)

    def add(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise addition ``self + other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed) sum.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.add(b)
        x = self._data
        y = SparseBCOOImpl._match_index_dtype(x, b._data)
        return SparseBCOOImpl(x + y)

    def sub(self, other: QarrayImpl) -> QarrayImpl:
        """Element-wise subtraction ``self - other``, coercing types as needed.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` (both sparse) or ``DenseImpl`` (mixed) difference.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.sub(b)
        x = self._data
        y = SparseBCOOImpl._match_index_dtype(x, b._data)
        return SparseBCOOImpl(x - y)

    def mul(self, scalar) -> QarrayImpl:
        """Scalar multiplication.

        Args:
            scalar: Scalar value.

        Returns:
            A ``SparseBCOOImpl`` holding ``scalar * self``.
        """
        return SparseBCOOImpl(scalar * self._data)

    def dag(self) -> QarrayImpl:
        """Conjugate transpose without densifying.

        For arrays with at least two dimensions the last two axes are swapped
        and the stored values conjugated (see :meth:`dag_data`). Arrays with
        fewer dimensions are only conjugated.

        Returns:
            A ``SparseBCOOImpl`` containing the conjugate transpose.
        """
        if self._data.ndim >= 2:
            return SparseBCOOImpl(SparseBCOOImpl.dag_data(self._data))
        return SparseBCOOImpl(SparseBCOOImpl._conj(self._data))

    def kron(self, other: "QarrayImpl") -> "QarrayImpl":
        """Kronecker product using ``sparsify(jnp.kron)`` — stays sparse.

        Args:
            other: Right-hand operand.

        Returns:
            A ``SparseBCOOImpl`` containing the Kronecker product when both
            operands are sparse; otherwise whatever the coerced left operand's
            ``kron`` returns.
        """
        a, b = self._coerce(other)
        if a is not self:
            return a.kron(b)
        sparse_kron = sparse.sparsify(jnp.kron)
        return SparseBCOOImpl(sparse_kron(self._data, b._data))

    # ------------------------------------------------------------------
    # Conversions
    # ------------------------------------------------------------------

    def to_dense(self) -> "DenseImpl":
        """Convert to a ``DenseImpl`` via ``todense()``.

        Returns:
            A ``DenseImpl`` with the same values as this sparse array.
        """
        return DenseImpl(self._data.todense())

    @classmethod
    def _to_sparse(cls, data) -> sparse.BCOO:
        """Convert *data* to a ``sparse.BCOO``, returning it unchanged if already sparse.

        Args:
            data: A ``sparse.BCOO`` or array-like.

        Returns:
            A ``sparse.BCOO`` representation of *data*.
        """
        if isinstance(data, sparse.BCOO):
            return data
        return sparse.BCOO.fromdense(data)

    def to_sparse_bcoo(self) -> "SparseBCOOImpl":
        """Return self (already sparse BCOO).

        Returns:
            This ``SparseBCOOImpl`` instance unchanged.
        """
        return self

    # ------------------------------------------------------------------
    # Metadata
    # ------------------------------------------------------------------

    def shape(self) -> tuple:
        """Shape of the underlying BCOO array.

        Returns:
            Tuple of dimension sizes.
        """
        return self._data.shape

    def dtype(self):
        """Data type of the underlying BCOO array.

        Returns:
            The dtype of ``_data``.
        """
        return self._data.dtype

    # ------------------------------------------------------------------
    # Element-wise maps on the stored values
    # ------------------------------------------------------------------

    @classmethod
    def _real(cls, data):
        """Return a BCOO array with only the real parts of the stored values."""
        return sparse.BCOO((jnp.real(data.data), data.indices), shape=data.shape)

    def real(self) -> QarrayImpl:
        """Element-wise real part.

        Returns:
            A ``SparseBCOOImpl`` containing the real parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._real(self._data))

    @classmethod
    def _imag(cls, data):
        """Return a BCOO array with only the imaginary parts of the stored values."""
        return sparse.BCOO((jnp.imag(data.data), data.indices), shape=data.shape)

    def imag(self) -> QarrayImpl:
        """Element-wise imaginary part.

        Returns:
            A ``SparseBCOOImpl`` containing the imaginary parts of stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._imag(self._data))

    @classmethod
    def _conj(cls, data):
        """Return a BCOO array with complex-conjugated stored values."""
        return sparse.BCOO((jnp.conj(data.data), data.indices), shape=data.shape)

    def conj(self) -> QarrayImpl:
        """Element-wise complex conjugate.

        Returns:
            A ``SparseBCOOImpl`` containing the complex-conjugated stored values.
        """
        return SparseBCOOImpl(SparseBCOOImpl._conj(self._data))

    @classmethod
    def _abs(cls, data):
        """Return a BCOO array with absolute values of stored entries."""
        return sparse.sparsify(jnp.abs)(data)

    def abs(self) -> QarrayImpl:
        """Element-wise absolute value.

        Returns:
            A ``SparseBCOOImpl`` containing the absolute values of stored entries.
        """
        return SparseBCOOImpl(SparseBCOOImpl._abs(self._data))

    def tidy_up(self, atol) -> "SparseBCOOImpl":
        """Zero out stored values whose real or imaginary magnitude is below *atol*.

        Real and imaginary parts are thresholded independently. The result is
        assembled as ``re + 1j * im``, so its values are complex-typed.

        Args:
            atol: Absolute tolerance threshold.

        Returns:
            A new ``SparseBCOOImpl`` with small values zeroed.
        """
        values = self._data.data
        re = jnp.real(values)
        im = jnp.imag(values)
        new_values = re * (jnp.abs(re) > atol) + 1j * im * (jnp.abs(im) > atol)
        return SparseBCOOImpl(
            sparse.BCOO((new_values, self._data.indices), shape=self._data.shape)
        )

    # ------------------------------------------------------------------
    # Raw-data class helpers
    # ------------------------------------------------------------------

    @classmethod
    def _eye_data(cls, n: int, dtype=None):
        """Create an ``n x n`` identity matrix as a sparse BCOO with O(n) memory.

        Args:
            n: Matrix size.
            dtype: Optional data type.

        Returns:
            A ``sparse.BCOO`` identity matrix of shape ``(n, n)``.
        """
        return sparse.eye(n, dtype=dtype)

    @classmethod
    def can_handle_data(cls, arr) -> bool:
        """Return True when *arr* is a ``sparse.BCOO`` array.

        Args:
            arr: Raw array.

        Returns:
            True if *arr* is a ``sparse.BCOO`` instance.
        """
        return isinstance(arr, sparse.BCOO)

    @classmethod
    def dag_data(cls, arr: sparse.BCOO) -> sparse.BCOO:
        """Conjugate transpose for BCOO sparse arrays without densifying.

        Args:
            arr: A ``sparse.BCOO`` array with ``ndim >= 2``.

        Returns:
            A ``sparse.BCOO`` containing the conjugate transpose.
        """
        ndim = arr.ndim
        # Keep leading axes in place and swap the final two.
        permutation = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
        transposed = sparse.bcoo_transpose(arr, permutation=permutation)
        return sparse.BCOO(
            (jnp.conj(transposed.data), transposed.indices),
            shape=transposed.shape,
        )

    # ------------------------------------------------------------------
    # Reductions
    # ------------------------------------------------------------------

    def frobenius_norm(self) -> Array:
        """Compute the Frobenius norm directly from stored values.

        Returns:
            The Frobenius norm as a scalar array.
        """
        return jnp.sqrt(jnp.sum(jnp.abs(self._data.data) ** 2))

    def trace(self) -> Array:
        """Compute the trace of the last two matrix dimensions without densifying.

        Stored entries whose last two indices coincide are summed; for batched
        arrays the sum is scattered into one slot per batch element.

        Returns:
            A scalar for a two-dimensional array, otherwise an array shaped
            like the leading (batch) dimensions.
        """
        indices = self._data.indices
        values = self._data.data

        # Mask that keeps only entries sitting on the matrix diagonal.
        on_diag = indices[:, -2] == indices[:, -1]
        diag_values = values * on_diag

        if indices.shape[-1] == 2:
            return jnp.sum(diag_values)

        batch_shape = self._data.shape[:-2]
        flat_idx, n_batches = SparseBCOOImpl._flat_batch_index(
            indices[:, :-2], batch_shape
        )
        totals = jnp.zeros(n_batches, dtype=values.dtype).at[flat_idx].add(diag_values)
        return totals.reshape(batch_shape)

    def keep_only_diag(self) -> "SparseBCOOImpl":
        """Zero out off-diagonal stored entries without densifying.

        The index array is reused as-is; off-diagonal entries stay stored but
        with value zero.

        Returns:
            A ``SparseBCOOImpl`` with only diagonal entries non-zero.
        """
        indices = self._data.indices
        values = self._data.data
        is_diag = indices[:, -2] == indices[:, -1]
        new_values = values * is_diag
        return SparseBCOOImpl(sparse.BCOO((new_values, indices), shape=self._data.shape))

    def l2_norm_batched(self, bdims: tuple) -> Array:
        """Compute the L2 norm per batch element without densifying.

        Args:
            bdims: Tuple of batch dimension sizes.

        Returns:
            Scalar or array of L2 norms.
        """
        values = self._data.data
        indices = self._data.indices
        n_batch_dims = len(bdims)
        sq = jnp.abs(values) ** 2

        if n_batch_dims == 0:
            return jnp.sqrt(jnp.sum(sq))

        flat_idx, n_batches = SparseBCOOImpl._flat_batch_index(
            indices[:, :n_batch_dims], bdims
        )
        # Accumulator is requested as float64; JAX downgrades it to float32
        # when 64-bit mode is disabled.
        sum_sq = jnp.zeros(n_batches, dtype=jnp.float64).at[flat_idx].add(sq)
        return jnp.sqrt(sum_sq).reshape(bdims)

    # ------------------------------------------------------------------
    # Copy support
    # ------------------------------------------------------------------

    def __deepcopy__(self, memo=None):
        """Return a new ``SparseBCOOImpl`` wrapping a deep copy of ``_data``."""
        return SparseBCOOImpl(_data=deepcopy(self._data, memo))

388 

389 

# Make the BCOO backend discoverable through the implementation-type registry.
QarrayImplType.register(
    SparseBCOOImpl,
    QarrayImplType.SPARSE_BCOO,
)

# Names exported by ``from ... import *``.
__all__ = ["SparseBCOOImpl"]