Coverage for jaxquantum/core/qarray.py: 71%
437 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""QArray."""
3from __future__ import annotations
5from flax import struct
6from jax import Array, config
7from typing import List, Union
10from math import prod
11from copy import deepcopy
12from numpy import ndarray
13import jax.numpy as jnp
14import jax.scipy as jsp
16from jaxquantum.core.settings import SETTINGS
17from jaxquantum.utils.utils import robust_isscalar
18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims
# Turn on JAX 64-bit mode process-wide, as a side effect of importing this module.
config.update("jax_enable_x64", True)
def tidy_up(data, atol):
    """Zero out tiny real and imaginary components.

    The real and imaginary parts are filtered independently: each component
    whose magnitude is not strictly greater than ``atol`` is replaced by zero.
    The result is always complex (``re + 1j * im``), even for real input.

    Args:
        data: array-like input.
        atol: absolute threshold.

    Returns:
        The cleaned-up complex array.
    """
    real, imag = jnp.real(data), jnp.imag(data)
    keep_real = jnp.abs(real) > atol
    keep_imag = jnp.abs(imag) > atol
    # Multiplying by the boolean masks (rather than jnp.where) keeps the exact
    # arithmetic of masking, including how NaN entries propagate.
    return real * keep_real + 1j * imag * keep_imag
@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Quantum array: JAX data plus quantum (``dims``) and batch (``bdims``) metadata.

    ``_data`` is the pytree leaf. ``_qdims`` and ``_bdims`` are declared with
    ``pytree_node=False``, so they travel as static metadata; this is what
    lets a Qarray be passed into and returned from jitted functions. Build
    instances with :meth:`create` rather than the raw dataclass constructor.
    """

    _data: Array
    _qdims: Qdims = struct.field(pytree_node=False)
    _bdims: tuple[int] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Create a Qarray from raw data, normalizing its shape and dims.

        Args:
            data: array-like data. The last two axes are the matrix axes and
                any leading axes are batch axes. A non-empty 1-D array is
                treated as a column vector of shape ``(N, 1)``.
            dims: quantum dimensions ``(row_dims, col_dims)``. Defaults to
                ``((rows,), (cols,))`` taken from the last two axes.
            bdims: batch dimensions. Defaults to every axis before the last
                two. If ``data`` has exactly one axis more than ``bdims``, a
                trailing singleton axis is appended (the data is read as a
                batch of column vectors).

        Returns:
            Qarray: the new quantum array. Entries are passed through
            ``tidy_up`` with ``SETTINGS["auto_tidyup_atol"]``.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        if len(data.shape) >= 2:
            # A trailing 2-D block that is neither square nor a row/column
            # vector: interpret the last axis as a column vector.
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        if bdims is not None:
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a Qarray from a list of Qarrays.

        The list index becomes a new leading batch dimension.

        Args:
            qarr_list: Qarrays that all share the same dims and bdims.

        Returns:
            Qarray: batched Qarray with ``bdims == (len(qarr_list),) + bdims``.

        Raises:
            ValueError: if the Qarrays do not all have the same dims and bdims.
        """
        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        # Validate before stacking, so mismatched inputs raise the documented
        # ValueError instead of an opaque error from stacking ragged data.
        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        data = jnp.array([qarr.data for qarr in qarr_list])
        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list of Qarrays.

        The nesting shape becomes the leading batch dimensions. For example, a
        2x3 nested list of unbatched Qarrays yields ``bdims == (2, 3)``. Batch
        dims already carried by the leaf Qarrays are kept after those.

        Args:
            qarr_arr (list): nested list (or tuple) of Qarrays. If it is
                already a Qarray, it is returned unchanged.

        Returns:
            Qarray: Qarray object
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        # Walk down the first element of each nesting level to recover the
        # batch shape implied by the nesting.
        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        # Leaf Qarrays may be batched themselves. Keep those axes in the
        # reshape below, otherwise they would be folded into the matrix axes.
        leaf_bdims = lvl.bdims if isinstance(lvl, Qarray) else ()

        def flat(lis):
            """Depth-first flatten of nested lists/tuples into a flat list."""
            flat_list = []
            for element in lis:
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims, *leaf_bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        """Quantum type (oper / ket / bra) derived from the qdims."""
        return self._qdims.qtype

    @property
    def dtype(self):
        """dtype of the underlying data."""
        return self._data.dtype

    @property
    def dims(self):
        """Quantum dimensions ``(row_dims, col_dims)``."""
        return self._qdims.dims

    @property
    def bdims(self):
        """Batch dimensions (leading axes of the data)."""
        return self._bdims

    @property
    def qdims(self):
        """The Qdims metadata object."""
        return self._qdims

    @property
    def space_dims(self):
        """Hilbert-space dims: row dims for opers/kets, column dims for bras."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # TODO: not reached for some reason
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        """Underlying data, shaped ``bdims + (rows, cols)``."""
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``bdims + dims[0] + dims[1]`` (one axis per subsystem)."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        """Shape of the underlying data."""
        return self.data.shape

    @property
    def is_batched(self):
        """True if the Qarray has at least one batch dimension."""
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index into the batch axes. dims are kept; bdims are re-inferred.

        Raises:
            ValueError: if the Qarray is not batched.
        """
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray.

        Args:
            *args: new batch dimensions. If their product is 0 (an empty
                batch), the data is reshaped to exactly ``args``.

        Returns:
            Qarray: Qarray with the same dims and the new bdims.
        """
        new_bdims = tuple(args)

        if prod(new_bdims) == 0:
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Convert Hilbert-space dims into ``(row_dims, col_dims)``.

        Input that is already nested (its first entry is a list or tuple) is
        returned unchanged. Otherwise, for opers/kets the space dims become
        the row dims and the column dims are all 1; for bras it is the
        reverse.

        Raises:
            ValueError: for an unsupported qtype.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        ones = tuple(1 for _ in space_dims)
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return (tuple(space_dims), ones)
        elif self.qtype == Qtypes.bra:
            return (ones, tuple(space_dims))
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray. Their product must
                equal the product of the current space dims.

        Returns:
            Qarray: reshaped Qarray.
        """
        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        assert prod(new_space_dims) == prod(current_space_dims)

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape.

        Uses ``jnp.resize`` on the data, keeps dims and re-infers bdims.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the Qarray (size of the first batch axis).

        Raises:
            ValueError: if the Qarray is not batched.
        """
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Equality of dims, bdims and all data entries.

        Returns Python ``False`` if dims or bdims differ; otherwise returns the
        result of ``jnp.all`` over the element-wise data comparison.

        Raises:
            ValueError: if ``other`` is not a Qarray.
        """
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        return not self.__eq__(other)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        """Matrix product; dims are combined through ``Qdims.__matmul__``."""
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: no __rmatmul__: Qarray @ Qarray is always handled by __matmul__
    # of the left operand, so a reflected version is never reached.

    def __mul__(self, other):
        """``*`` with a Qarray is a matrix product; otherwise a scalar product.

        A non-scalar ``other`` is treated as one scalar per batch element and
        gets two trailing singleton axes so it broadcasts over the matrices.
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        # NOTE: other is never a Qarray here (Qarray * Qarray goes to __mul__).
        return self.__mul__(other)

    def __neg__(self):
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar."""

        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def _scalar_to_identity(self, other):
        """Promote a scalar (or per-batch array of scalars) to ``other * I``.

        Only called for square data. A non-scalar ``other`` gets two trailing
        singleton axes so it broadcasts over the identity matrix.
        """
        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))
        return Qarray.create(
            other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
            dims=self.dims,
        )

    def _check_same_dims(self, other):
        """Raise ValueError if ``other`` (a Qarray) has different dims."""
        if self.dims != other.dims:
            msg = (
                "Dimensions are incompatible: "
                + repr(self.dims)
                + " and "
                + repr(other.dims)
            )
            raise ValueError(msg)

    def __add__(self, other):
        """Add a Qarray (same dims) or a scalar.

        Adding the scalar 0 returns a copy. Other scalars are added as
        multiples of the identity, which needs square data; otherwise
        NotImplemented is returned.
        """
        if isinstance(other, Qarray):
            self._check_same_dims(other)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            return self.__add__(self._scalar_to_identity(other))

        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract a Qarray (same dims) or a scalar (see ``__add__``)."""
        if isinstance(other, Qarray):
            self._check_same_dims(other)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            return self.__sub__(self._scalar_to_identity(other))

        return NotImplemented

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """``a ^ b`` is the tensor product."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """``a ** n`` is the matrix power for integer ``n``."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Print the header of the Qarray."""
        return self._str_header()

    def __repr__(self):
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Deep copy of the Qarray (data, qdims and bdims)."""
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Need to override this when defininig __getattr__."""

        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Fall back to jax.numpy / jax.scipy functions for unknown attributes.

        ``q.foo(*args, **kwargs)`` calls ``foo(q.data, *args, **kwargs)``. The
        function comes from the first module in
        ``[jnp, jnp.linalg, jsp, jsp.linalg]`` that defines ``foo``. A result
        with the same shape as ``q.data`` is wrapped back into a Qarray with
        the same dims; anything else is returned as-is.

        Unknown dunder names return a callable that yields NotImplemented.

        Raises:
            NotImplementedError: if no module provides ``method_name``.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        return dag(self)

    def to_dm(self):
        return ket2dm(self)

    def is_dm(self):
        return self.qtype == Qtypes.oper

    def is_vec(self):
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        return to_ket(self)

    def transpose(self, *args):
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        return unit(self)

    def norm(self):
        return norm(self)

    def expm(self):
        return expm(self)

    def powm(self, n):
        return powm(self, n)

    def cosm(self):
        return cosm(self)

    def sinm(self):
        return sinm(self)

    def tr(self, **kwargs):
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        return tr(self, **kwargs)

    def ptrace(self, indx):
        return ptrace(self, indx)

    def eigenstates(self):
        return eigenstates(self)

    def eigenenergies(self):
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        return collapse(self, mode=mode)

    # ----
# Array-like types: JAX arrays, NumPy arrays and Qarrays.
# NOTE(review): not referenced in this module; presumably used for
# isinstance checks elsewhere in the package — verify before changing.
ARRAY_TYPES = (Array, ndarray, Qarray)
540# Qarray operations ---------------------------------------------------------------------
def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse (reduce) the batch dimensions of a Qarray.

    Args:
        qarr (Qarray): quantum array
        mode (str): reduction applied over all batch axes. Only ``"sum"`` is
            supported.

    Returns:
        Qarray: unbatched quantum array with the same dims. If ``qarr`` has
        no batch dimensions, it is returned unchanged.

    Raises:
        ValueError: if ``mode`` is not supported. Previously such a call
            silently returned ``None``.
    """
    if mode != "sum":
        raise ValueError(f"Unsupported collapse mode: {mode!r}. Only 'sum' is supported.")

    if len(qarr.bdims) == 0:
        return qarr

    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)
def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems (tensor factors) of a quantum array.

    Args:
        qarr (Qarray): quantum array
        indices (List[int]): new subsystem order, i.e. a permutation of
            ``range(len(qarr.dims[0]))``. The same permutation is applied to
            the row and the column subsystems. Batch axes stay in place.

    Returns:
        Qarray: transposed quantum array with correspondingly permuted dims.
    """
    order = list(indices)
    row_dims, col_dims = qarr.dims[0], qarr.dims[1]
    num_batch = len(qarr.bdims)

    # shaped_data axes are laid out as: batch axes, row subsystems, column
    # subsystems. Permute row and column subsystems identically.
    subsystem_axes = order + [j + len(row_dims) for j in order]
    axes = list(range(num_batch)) + [j + num_batch for j in subsystem_axes]

    permuted = qarr.shaped_data.transpose(axes)
    new_dims = (
        tuple(row_dims[j] for j in order),
        tuple(col_dims[j] for j in order),
    )

    full_data = permuted.reshape(*qarr.bdims, prod(row_dims), -1)
    return Qarray.create(full_data, dims=new_dims)
def unit(qarr: Qarray) -> Qarray:
    """Normalize the quantum array.

    Each element of a batched Qarray is normalized independently.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Normalized quantum array
    """
    norms = jnp.asarray(qarr.norm())
    # Append two singleton axes so per-batch norms broadcast over the matrix
    # axes. An unbatched (0-d) norm becomes shape (1, 1), which is equivalent.
    data = qarr.data / norms.reshape(norms.shape + (1, 1))
    return Qarray.create(data, dims=qarr.dims)


def norm(qarr: Qarray) -> Array:
    """Norm of a quantum array, one value per batch element.

    Operators use the trace norm: the sum of singular values, computed as
    ``sum(sqrt(|eig(A @ A^dagger)|))``. Kets and bras use the Euclidean norm.
    The reduction runs only over the matrix axes, so a batched Qarray yields
    an array of shape ``bdims``.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Array: norm(s) of ``qarr``; a 0-d array if unbatched.

    Raises:
        ValueError: for an unsupported qtype.
    """
    data = qarr.data

    if qarr.qtype == Qtypes.oper:
        data_dag = qarr.dag().data
        evals, _ = jnp.linalg.eigh(data @ data_dag)
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1)
    elif qarr.qtype in [Qtypes.ket, Qtypes.bra]:
        # Frobenius norm over the last two axes == 2-norm of each vector.
        return jnp.linalg.norm(data, axis=(-2, -1))

    raise ValueError(f"Unsupported qtype for norm: {qarr.qtype}.")
def tensor(*args, **kwargs) -> Qarray:
    """Tensor (Kronecker) product of Qarrays.

    Args:
        *args (Qarray): factors, in order.
        parallel (bool): keyword passed through ``**kwargs`` (default False).
            Controls how batched factors are combined:
            true: [A,B] ^ [C,D] = [A^C, B^D]
            false: [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D]
        **kwargs: remaining keyword arguments are forwarded to ``jnp.kron``
            in the non-parallel mode.

    Returns:
        Qarray: the product, with row/column dims concatenated across factors.
    """
    parallel = kwargs.pop("parallel", False)

    def batched_kron(left, right):
        # Use the batch shape of the operand with more batch axes; at equal
        # rank, the one with more batch elements (ties go to `right`).
        if len(left.shape) > len(right.shape):
            batch_shape = left.shape[:-2]
        elif len(left.shape) < len(right.shape):
            batch_shape = right.shape[:-2]
        elif prod(left.shape[:-2]) > prod(right.shape[:-2]):
            batch_shape = left.shape[:-2]
        else:
            batch_shape = right.shape[:-2]
        product = jnp.einsum("...ij,...kl->...ikjl", left, right)
        return product.reshape(*batch_shape, left.shape[-2] * right.shape[-2], -1)

    head = args[0]
    result = head.data
    head_dims = deepcopy(head.dims)
    row_dims, col_dims = head_dims[0], head_dims[1]

    for factor in args[1:]:
        if parallel:
            result = batched_kron(result, factor.data)
        else:
            result = jnp.kron(result, factor.data, **kwargs)
        row_dims = row_dims + factor.dims[0]
        col_dims = col_dims + factor.dims[1]

    return Qarray.create(result, dims=(row_dims, col_dims))
def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``jnp.trace``. ``axis1``/``axis2`` default to
            the last two axes, so a batched Qarray gives one trace per batch
            element.

    Returns:
        Full trace.
    """
    # pop (not get): otherwise an explicit axis1/axis2 would also stay in
    # kwargs and be passed to jnp.trace twice (TypeError).
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)
def trace(qarr: Qarray, **kwargs) -> Array:
    """Full trace; an alias of :func:`tr`.

    Args:
        qarr (Qarray): quantum array
        **kwargs: passed unchanged to :func:`tr`.

    Returns:
        The trace computed by :func:`tr`.
    """
    result = tr(qarr, **kwargs)
    return result
def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of raw array data.

    Thin wrapper around ``jax.scipy.linalg.expm``; keyword arguments are
    forwarded unchanged.

    Returns:
        matrix exponential
    """
    matrix_exp = jsp.linalg.expm
    return matrix_exp(data, **kwargs)
def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a Qarray.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to :func:`expm_data`.

    Returns:
        Qarray: matrix exponential with the same dims as ``qarr``.
    """
    return Qarray.create(expm_data(qarr.data, **kwargs), dims=qarr.dims)
def powm(qarr: Qarray, n: Union[int, float]) -> Qarray:
    """Matrix power.

    Integer ``n`` uses ``jnp.linalg.matrix_power``. Any other ``n`` uses the
    eigendecomposition ``V @ diag(lambda**n) @ inv(V)``. Batched Qarrays are
    supported in both cases.

    Args:
        qarr (Qarray): quantum array
        n (int | float): power

    Returns:
        matrix power

    Raises:
        ValueError: for a non-integer ``n`` when an eigenvalue compares as
            negative.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if not (evalues >= 0).all():
            raise ValueError(
                "Non-integer power of a matrix can only be "
                "computed if the matrix is positive semi-definite. "
                "Got a matrix with a negative eigenvalue."
            )
        # Scale column i of V by lambda_i**n, i.e. V @ diag(lambda**n). The
        # [..., None, :] aligns eigenvalues with columns for batched input too
        # (plain broadcasting would misalign them against the batch axis).
        scaled = evectors * jnp.power(evalues, n)[..., None, :]
        data_res = scaled @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)
def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of raw array data, via ``(e^{iA} + e^{-iA}) / 2``.

    Extra keyword arguments are accepted but not used.

    Returns:
        matrix cosine
    """
    exp_plus = expm_data(1j * data)
    exp_minus = expm_data(-1j * data)
    return (exp_plus + exp_minus) / 2
def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: matrix cosine with the same dims as ``qarr``.
    """
    return Qarray.create(cosm_data(qarr.data), dims=qarr.dims)
def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of raw array data, via ``(e^{iA} - e^{-iA}) / (2i)``.

    Extra keyword arguments are accepted but not used.

    Args:
        data: matrix

    Returns:
        matrix sine
    """
    exp_plus = expm_data(1j * data)
    exp_minus = expm_data(-1j * data)
    return (exp_plus - exp_minus) / (2j)
def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: matrix sine with the same dims as ``qarr``.
    """
    return Qarray.create(sinm_data(qarr.data), dims=qarr.dims)
def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Zero every off-diagonal element of an unbatched Qarray.

    Args:
        qarr (Qarray): unbatched quantum array.

    Returns:
        Qarray: matrix holding only the diagonal of ``qarr``, same dims.

    Raises:
        ValueError: if ``qarr`` is batched.
    """
    if qarr.is_batched:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    diagonal = jnp.diag(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)
def to_ket(qarr: Qarray) -> Qarray:
    """Return ``qarr`` as a ket.

    A ket is returned unchanged and a bra is conjugate-transposed.

    Raises:
        ValueError: for any other qtype (e.g. an operator).
    """
    qtype = qarr.qtype
    if qtype == Qtypes.bra:
        return qarr.dag()
    if qtype != Qtypes.ket:
        raise ValueError("Can only get ket from a ket or bra.")
    return qarr
def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenvalues and eigenstates of a Hermitian quantum array.

    Args:
        qarr (Qarray): quantum array (uses ``jnp.linalg.eigh``).

    Returns:
        tuple: ``(evals, evecs)``. ``evals`` holds the eigenvalues sorted in
        ascending order along the last axis. ``evecs`` is a Qarray of kets
        with bdims ``qarr.bdims + (N,)``, where ``evecs[..., i]`` is the
        eigenvector for ``evals[..., i]``.
    """
    evals, evecs = jnp.linalg.eigh(qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # eigh stores eigenvectors as columns (evecs[..., :, i]), but
    # Qarray.create with bdims=shape[:-1] splits the data by rows. Swap the
    # last two axes so that row i is the i-th eigenvector.
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs
def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a Hermitian quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues from ``jnp.linalg.eigvalsh`` (batched along leading axes).
    """
    return jnp.linalg.eigvalsh(qarr.data)
835# More quantum specific -----------------------------------------------------
def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial Trace.

    Args:
        qarr (Qarray): ket, bra or density matrix. Kets and bras are first
            converted with ``ket2dm``.
        indx (int): index of the subsystem to keep; all other subsystems are
            traced out. Negative values count from the end.

    Returns:
        Qarray: reduced density matrix of subsystem ``indx``. Batch
        dimensions are preserved.

    Raises:
        ValueError: if ``indx`` is out of range.

    TODO: Fix weird tracing errors that arise with reshape
    """
    qarr = ket2dm(qarr)
    rho = qarr.shaped_data
    dims = qarr.dims

    Nq = len(dims[0])

    if indx < 0:
        indx += Nq
    if not 0 <= indx < Nq:
        raise ValueError(f"Subsystem index out of range for {Nq} subsystems.")

    # Order axes as: (row, col) of the kept subsystem, followed by the
    # (row, col) pair of every subsystem to be traced out.
    indxs = [indx, indx + Nq]
    for j in range(Nq):
        if j == indx:
            continue
        indxs.extend([j, j + Nq])

    len_bdims = len(qarr.bdims)
    indxs = list(range(len_bdims)) + [j + len_bdims for j in indxs]
    rho = rho.transpose(indxs)

    # Each trace removes the first traced-out (row, col) pair, which always
    # sits right after the batch axes and the kept subsystem's two axes.
    for _ in range(Nq - 1):
        rho = jnp.trace(rho, axis1=2 + len_bdims, axis2=3 + len_bdims)

    return Qarray.create(rho)
def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose (Hermitian adjoint) of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Qarray: conjugate transpose, with row and column dims swapped.
    """
    row_dims, col_dims = qarr.dims[0], qarr.dims[1]
    adjoint = dag_data(qarr.data)
    return Qarray.create(adjoint, dims=(col_dims, row_dims))
def dag_data(arr: Array) -> Array:
    """Conjugate transpose of raw array data.

    Args:
        arr: operator (or batch of operators along leading axes)

    Returns:
        Complex conjugate with the last two axes swapped. A 1-D input is
        only conjugated.
    """
    conjugated = jnp.conj(arr)
    # TODO: revisit this case...
    if len(arr.shape) == 1:
        return conjugated
    # Swapping only the last two axes keeps leading batch axes in place.
    return jnp.swapaxes(conjugated, -1, -2)
def ket2dm(qarr: Qarray) -> Qarray:
    """Turn a ket (or bra) into a density matrix ``|psi><psi|``.

    Operators are returned unchanged.

    Args:
        qarr (Qarray): qarr

    Returns:
        Density matrix
    """
    if qarr.qtype == Qtypes.oper:
        return qarr

    ket = qarr.dag() if qarr.qtype == Qtypes.bra else qarr
    return ket @ ket.dag()
930# Data level operations ----
def is_dm_data(data: Array) -> bool:
    """Return True when the last two axes of ``data`` are equal (square).

    Only the shape is inspected; hermiticity, positivity and trace are not
    checked.

    Args:
        data: matrix (or batch of matrices)
    Returns:
        True if data is a density matrix
    """
    num_rows = data.shape[-2]
    num_cols = data.shape[-1]
    return num_cols == num_rows
def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of raw array data.

    Args:
        data: matrix (or batch of matrices)
        n: integer power, passed to ``jnp.linalg.matrix_power``

    Returns:
        matrix power
    """
    matrix_power = jnp.linalg.matrix_power
    return matrix_power(data, n)