Coverage for jaxquantum/core/qarray.py: 75%
452 statements
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
« prev ^ index » next coverage.py v7.10.4, created at 2025-08-19 17:34 +0000
1"""QArray."""
3from __future__ import annotations
5from flax import struct
6from jax import Array, config
7from typing import List, Union
10from math import prod
11from copy import deepcopy
12from numpy import ndarray
13import jax.numpy as jnp
14import jax.scipy as jsp
16from jaxquantum.core.settings import SETTINGS
17from jaxquantum.utils.utils import robust_isscalar
18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims
# Turn on JAX's 64-bit mode ("jax_enable_x64") globally; this runs as a side
# effect of importing this module.
config.update("jax_enable_x64", True)
def tidy_up(data, atol):
    """Zero out real and imaginary components whose magnitude is <= ``atol``.

    The real and imaginary parts are thresholded independently and then
    recombined as ``re + 1j * im``, so the result is always complex.

    Args:
        data: array-like data to clean.
        atol: absolute tolerance below which a component is set to zero.

    Returns:
        Complex array with the same shape as ``data``.
    """
    re, im = jnp.real(data), jnp.imag(data)
    keep_re = jnp.abs(re) > atol
    keep_im = jnp.abs(im) > atol
    return re * keep_re + 1j * im * keep_im
@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Quantum array: JAX data tagged with quantum (Hilbert space) and batch dims.

    Build instances with :meth:`create`, :meth:`from_list` or :meth:`from_array`
    rather than the raw dataclass constructor, so that data shapes and dims are
    normalized and validated.
    """

    # Raw data: leading axes are batch axes, the trailing two axes hold the
    # operator / vector.
    _data: Array
    # Quantum dimensions; static metadata (not a pytree leaf).
    _qdims: Qdims = struct.field(pytree_node=False)
    # Batch dimensions; static metadata (not a pytree leaf).
    _bdims: tuple[int] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Create a Qarray from array-like data.

        Args:
            data: array-like data. A non-empty 1-D array is treated as a column
                vector. If the trailing two axes are neither square nor a
                row/column vector, the last axis is reinterpreted as a column
                vector (a singleton axis is appended).
            dims: quantum dimensions, either ``(row_dims, col_dims)`` or only
                the Hilbert space dimensions (expanded to ket, bra or operator
                dims depending on the data shape). Defaults to a single
                subsystem matching the data shape.
            bdims: batch dimensions. Defaults to ``data.shape[:-2]``. If the
                data has exactly one axis more than ``bdims``, the last axis is
                reshaped into a column vector.

        Returns:
            Qarray: new quantum array whose data has been passed through
            ``tidy_up``.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        # A non-empty 1-D array is promoted to a column vector.
        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        if len(data.shape) >= 2:
            # Trailing axes that are neither square nor a row/column vector are
            # read as a stack of column vectors: the last axis becomes the
            # vector and a singleton column axis is appended.
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        if bdims is not None:
            # Exactly one axis beyond the batch axes means a (batched) vector:
            # append a singleton column axis.
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        if not isinstance(dims[0], (list, tuple)):
            # This handles the case where only the hilbert space dimensions are sent in.
            if data.shape[-1] == 1:
                dims = (tuple(dims), tuple([1 for _ in dims]))
            elif data.shape[-2] == 1:
                dims = (tuple([1 for _ in dims]), tuple(dims))
            else:
                dims = (tuple(dims), tuple(dims))
        else:
            dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a batched Qarray by stacking a list of Qarrays.

        The new leading batch axis has length ``len(qarr_list)``; an empty list
        gives an empty Qarray with dims ``((), ())``.

        Args:
            qarr_list (List[Qarray]): Qarrays sharing the same dims and bdims.

        Returns:
            Qarray: stacked quantum array.

        Raises:
            ValueError: if the Qarrays do not all share the same dims and bdims.
        """
        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        # Validate before stacking, so inconsistent inputs get this error
        # instead of a lower-level shape error from jnp.array.
        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        data = jnp.array([qarr.data for qarr in qarr_list])
        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list (or tuple) of Qarrays.

        The nesting structure becomes the leading batch dimensions.

        Args:
            qarr_arr (list): nested list of Qarrays

        Returns:
            Qarray: Qarray object
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        # Walk the first element at each nesting level to infer the batch shape.
        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        def flat(lis):
            # Flatten lists *and* tuples, matching the bdims walk above.
            flat_list = []
            for element in lis:
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        return self._qdims.qtype

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def dims(self):
        return self._qdims.dims

    @property
    def bdims(self):
        return self._bdims

    @property
    def qdims(self):
        return self._qdims

    @property
    def space_dims(self):
        """Hilbert space dims: row dims for operators/kets, column dims for bras."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # NOTE(review): presumably unreachable if Qdims only produces the
            # oper/ket/bra qtypes -- confirm in jaxquantum.core.dims.
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``bdims + row_dims + col_dims`` (one axis per subsystem)."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        return self.data.shape

    @property
    def is_batched(self):
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index into the batch axes; the result keeps this Qarray's dims."""
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray."""
        new_bdims = tuple(args)

        if prod(new_bdims) == 0:
            # Empty batch: there is no matrix data whose shape must be kept.
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Expand Hilbert space dims into ``(row_dims, col_dims)`` for this qtype.

        Already-expanded dims (first entry is a list/tuple) are returned as-is.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return (tuple(space_dims), tuple([1 for _ in range(len(space_dims))]))
        elif self.qtype == Qtypes.bra:
            return (tuple([1 for _ in range(len(space_dims))]), tuple(space_dims))
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray.

        Returns:
            Qarray: reshaped Qarray.
        """

        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        assert prod(new_space_dims) == prod(current_space_dims)

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the Qarray (size of the first batch axis)."""
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Equal when dims, bdims and every data element match."""
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        return not self.__eq__(other)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        """Matrix product; dims are combined by the Qdims ``@`` operator."""
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: no __rmatmul__ is defined: for Qarray @ Qarray the left operand's
    # __matmul__ always handles the operation.

    def __mul__(self, other):
        """Qarray * Qarray is a matrix product; otherwise scalar multiplication.

        A non-scalar array ``other`` gets two trailing singleton axes, so that it
        scales each batch element by its own factor.
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        # Only reached for non-Qarray left operands (Qarray * Qarray is handled
        # by __mul__), so this is always scalar multiplication.
        return self.__mul__(other)

    def __neg__(self):
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar."""

        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def __add__(self, other):
        """Add a Qarray, or a scalar/array (times identity) for square Qarrays.

        Adding the scalar 0 returns a copy; other non-Qarray operands on
        non-square Qarrays return NotImplemented.
        """
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__add__(other)

        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract a Qarray, or a scalar/array (times identity) for square Qarrays."""
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__sub__(other)

        return NotImplemented

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """``a ^ b`` is the tensor product of two Qarrays."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """``a ** n`` is the matrix power for integer ``n``."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Header string of the Qarray (dims, bdims, shape and type)."""
        return self._str_header()

    def __repr__(self):
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Deep copy of the Qarray."""
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Need to override this when defining __getattr__."""

        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Forward unknown attributes to jax.numpy / jax.scipy functions.

        The first function named ``method_name`` found in jnp, jnp.linalg, jsp
        or jsp.linalg is returned as a callable applied to ``self.data``. If its
        result has the same shape as the data, it is re-wrapped as a Qarray
        with the same dims; otherwise the raw result is returned.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        return dag(self)

    def to_dm(self):
        return ket2dm(self)

    def is_dm(self):
        return self.qtype == Qtypes.oper

    def is_vec(self):
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        return to_ket(self)

    def transpose(self, *args):
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        return unit(self)

    def norm(self):
        return norm(self)

    def expm(self):
        return expm(self)

    def powm(self, n):
        return powm(self, n)

    def cosm(self):
        return cosm(self)

    def sinm(self):
        return sinm(self)

    def tr(self, **kwargs):
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        return tr(self, **kwargs)

    def ptrace(self, indx):
        return ptrace(self, indx)

    def eigenstates(self):
        return eigenstates(self)

    def eigenenergies(self):
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        return collapse(self, mode=mode)

    # ----
# Tuple of array-like types: JAX Array, NumPy ndarray and Qarray.
# NOTE(review): presumably meant for isinstance checks elsewhere in the
# package -- it is not referenced in this module's visible code.
ARRAY_TYPES = (Array, ndarray, Qarray)
549# Qarray operations ---------------------------------------------------------------------
def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join several Qarrays along one data axis.

    Qarrays whose data has zero length are skipped. If nothing is left, an
    empty Qarray (``Qarray.from_list([])``) is returned. The result takes the
    dims of the first non-empty input.

    Args:
        qarr_list (List[Qarray]): Qarrays to join.
        axis (int): data axis to join along (default 0).

    Returns:
        Qarray: the joined quantum array.
    """
    kept = [q for q in qarr_list if len(q.data) != 0]

    if not kept:
        return Qarray.from_list([])

    joined = jnp.concatenate([q.data for q in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)
def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse the batch dimensions of a Qarray.

    Args:
        qarr (Qarray): quantum array
        mode (str): reduction to apply over all batch axes. Only ``"sum"`` is
            supported.

    Returns:
        Collapsed (non-batched) quantum array. A non-batched input is returned
        unchanged.

    Raises:
        ValueError: if ``mode`` is not supported (previously this silently
            returned None).
    """
    if mode != "sum":
        raise ValueError(f"Unsupported collapse mode: {mode!r}. Only 'sum' is supported.")

    if len(qarr.bdims) == 0:
        return qarr

    # Batch axes are always the leading axes of the data.
    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)
def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Reorder the subsystems of a quantum array.

    Row and column subsystem axes are permuted together; batch axes stay in
    place.

    Args:
        qarr (Qarray): quantum array
        indices (List[int]): new order of the subsystems

    Returns:
        transposed Qarray
    """
    order = list(indices)
    dims = qarr.dims
    n_sub = len(dims[0])
    n_batch = len(qarr.bdims)

    # Row-subsystem permutation followed by the matching column-subsystem one.
    sub_perm = order + [j + n_sub for j in order]
    # Leave the batch axes untouched and shift subsystem axes past them.
    full_perm = list(range(n_batch)) + [n_batch + j for j in sub_perm]

    permuted = qarr.shaped_data.transpose(full_perm)

    row_dims = tuple([dims[0][j] for j in order])
    col_dims = tuple([dims[1][j] for j in order])

    flat = permuted.reshape(*qarr.bdims, prod(dims[0]), -1)
    return Qarray.create(flat, dims=(row_dims, col_dims))
def unit(qarr: Qarray) -> Qarray:
    """Normalize the quantum array.

    Each batch element is divided by its own norm (see ``norm``), so batched
    Qarrays are normalized element-wise rather than by one global norm.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Normalized quantum array
    """
    norms = jnp.asarray(qarr.norm())
    # Append two singleton axes so each batch element's norm broadcasts over
    # its own trailing (row, col) axes; a scalar norm becomes shape (1, 1).
    data = qarr.data / norms[..., None, None]
    return Qarray.create(data, dims=qarr.dims)


def norm(qarr: Qarray) -> Array:
    """Norm of a quantum array, computed per batch element.

    For operators this is the trace norm: the sum of the square roots of the
    eigenvalues of ``A @ A^dagger``. For kets and bras it is the Euclidean norm
    of the vector.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Array of shape ``qarr.bdims`` (a scalar for a non-batched Qarray).

    Raises:
        ValueError: for an unsupported qtype.
    """
    data = qarr.data

    if qarr.qtype == Qtypes.oper:
        data_dag = qarr.dag().data
        evals, _ = jnp.linalg.eigh(data @ data_dag)
        # Reduce only the eigenvalue axis so batch axes are preserved.
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1)
    elif qarr.qtype in [Qtypes.ket, Qtypes.bra]:
        # Reduce only the trailing (row, col) axes so batch axes are preserved.
        return jnp.linalg.norm(data, axis=(-2, -1))

    raise ValueError("Unsupported qtype.")
def tensor(*args, **kwargs) -> Qarray:
    """Tensor product of one or more Qarrays.

    Args:
        *args (Qarray): factors, in order.
        parallel (bool): keyword-only. If True, factors are combined
            batch-element-wise with einsum: [A,B] ^ [C,D] = [A^C, B^D].
            If False (default), jnp.kron is used:
            [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D].
        **kwargs: any remaining keywords are forwarded to jnp.kron.

    Returns:
        Tensor product of the given Qarrays; row and column dims are the
        concatenation of the factors' dims.
    """
    parallel = kwargs.pop("parallel", False)

    def _batch_shape(lhs, rhs):
        # The operand with more axes decides the batch shape; with equal ndim,
        # the one with strictly more batch elements wins (rhs on a tie).
        if len(lhs.shape) != len(rhs.shape):
            winner = lhs if len(lhs.shape) > len(rhs.shape) else rhs
            return winner.shape[:-2]
        if prod(lhs.shape[:-2]) > prod(rhs.shape[:-2]):
            return lhs.shape[:-2]
        return rhs.shape[:-2]

    result = args[0].data
    row_dims = args[0].dims[0]
    col_dims = args[0].dims[1]

    for factor in args[1:]:
        rhs = factor.data
        if not parallel:
            result = jnp.kron(result, rhs, **kwargs)
        else:
            batch_shape = _batch_shape(result, rhs)
            n_rows = result.shape[-2] * rhs.shape[-2]
            outer = jnp.einsum("...ij,...kl->...ikjl", result, rhs)
            result = outer.reshape(*batch_shape, n_rows, -1)

        row_dims = row_dims + factor.dims[0]
        col_dims = col_dims + factor.dims[1]

    return Qarray.create(result, dims=(row_dims, col_dims))
def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``jnp.trace``. ``axis1`` / ``axis2`` default to
            the last two axes (-2, -1), so batched Qarrays are traced per batch
            element.

    Returns:
        Full trace.
    """
    # pop (not get): leaving axis1/axis2 in kwargs would pass them to
    # jnp.trace twice and raise a TypeError whenever a caller supplies them.
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)
def trace(qarr: Qarray, **kwargs) -> Array:
    """Full trace; alias of ``tr`` with identical keyword handling.

    Args:
        qarr (Qarray): quantum array
        **kwargs: passed through to ``tr``.

    Returns:
        Full trace.
    """
    result = tr(qarr, **kwargs)
    return result
def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of raw array data.

    Thin wrapper around ``jax.scipy.linalg.expm``; keyword arguments are
    forwarded unchanged.

    Returns:
        matrix exponential
    """
    exp_fn = jsp.linalg.expm
    return exp_fn(data, **kwargs)
def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a Qarray.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``expm_data``.

    Returns:
        matrix exponential, with the same dims as ``qarr``
    """
    exponentiated = expm_data(qarr.data, **kwargs)
    return Qarray.create(exponentiated, dims=qarr.dims)
def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power.

    Integer powers use ``jnp.linalg.matrix_power``. Any other power goes through
    the eigendecomposition ``V diag(lambda**n) V^-1`` from ``jnp.linalg.eig``.

    Args:
        qarr (Qarray): quantum array
        n (int | float): power
        clip_eigvals (bool): for non-integer ``n``, clip eigenvalues at zero
            instead of raising when a negative eigenvalue is found

    Returns:
        matrix power, with the same dims as ``qarr``

    Raises:
        ValueError: for non-integer ``n`` when ``clip_eigvals`` is False and
            some eigenvalue does not satisfy ``>= 0``.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        else:
            # NOTE(review): this Python-level check needs concrete values, so
            # presumably it cannot run under jit -- confirm before jitting.
            if not (evalues >= 0).all():
                raise ValueError(
                    "Non-integer power of a matrix can only be "
                    "computed if the matrix is positive semi-definite. "
                    "Got a matrix with a negative eigenvalue."
                )
        # `evectors * lambda**n` scales column i by lambda_i**n (i.e.
        # V @ diag(lambda**n)); `*` and `@` associate left-to-right.
        data_res = evectors * jnp.pow(evalues, n) @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)
def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of raw array data, via cos(A) = (e^{iA} + e^{-iA}) / 2.

    Extra keyword arguments are accepted but not used.

    Returns:
        matrix cosine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward + backward) / 2
def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix cosine, with the same dims as ``qarr``
    """
    cosine = cosm_data(qarr.data)
    return Qarray.create(cosine, dims=qarr.dims)
def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of raw array data, via sin(A) = (e^{iA} - e^{-iA}) / (2i).

    Extra keyword arguments are accepted but not used.

    Args:
        data: matrix

    Returns:
        matrix sine
    """
    forward = expm_data(1j * data)
    backward = expm_data(-1j * data)
    return (forward - backward) / (2j)
def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix sine, with the same dims as ``qarr``
    """
    sine = sinm_data(qarr.data)
    return Qarray.create(sine, dims=qarr.dims)
def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Zero out every off-diagonal element of a non-batched Qarray.

    Args:
        qarr (Qarray): non-batched quantum array

    Returns:
        Qarray with only the diagonal kept, with the same dims as ``qarr``.

    Raises:
        ValueError: if ``qarr`` is batched.
    """
    if qarr.bdims:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    diagonal = jnp.diag(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)
def to_ket(qarr: Qarray) -> Qarray:
    """Return ``qarr`` as a ket: a ket is returned as-is, a bra is daggered.

    Raises:
        ValueError: if ``qarr`` is neither a ket nor a bra.
    """
    if qarr.qtype == Qtypes.bra:
        return qarr.dag()
    if qarr.qtype != Qtypes.ket:
        raise ValueError("Can only get ket from a ket or bra.")
    return qarr
def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenvalues and eigenstates of a Hermitian quantum array.

    Args:
        qarr (Qarray): quantum array (operator), possibly batched

    Returns:
        ``(evals, evecs)``: ``evals`` has shape ``(*bdims, N)`` and is sorted in
        ascending order. ``evecs`` is a ket Qarray with batch dims
        ``(*bdims, N)`` whose last batch index ``i`` is the eigenstate for
        ``evals[..., i]``.
    """

    evals, evecs = jnp.linalg.eigh(qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    # eigh stores eigenvector i in column i, so the sort permutes columns.
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)

    # Qarray.create below treats the second-to-last axis as a batch axis and
    # the last axis as the ket components, so eigenvectors must become rows.
    # Without this, evecs[i] would be row i of V, not the i-th eigenvector.
    evecs = jnp.swapaxes(evecs, -2, -1)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs
def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a Hermitian quantum array.

    Uses ``jnp.linalg.eigvalsh`` on the data, so batched inputs give one row of
    eigenvalues per batch element.

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues
    """
    return jnp.linalg.eigvalsh(qarr.data)
872# More quantum specific -----------------------------------------------------
def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial trace keeping a single subsystem.

    Kets and bras are first turned into density matrices with ``ket2dm``.

    Args:
        qarr: state (ket, bra or density matrix)
        indx: index of the subsystem to keep; all others are traced out

    Returns:
        reduced density matrix of subsystem ``indx``

    TODO: Fix weird tracing errors that arise with reshape
    """
    rho_q = ket2dm(qarr)
    n_sub = len(rho_q.dims[0])
    n_batch = len(rho_q.bdims)

    # Kept subsystem's (row, col) axes first, then the (row, col) pair of each
    # subsystem that will be traced out.
    traced = [j for j in range(n_sub) if j != indx]
    sub_order = [indx, indx + n_sub] + [ax for j in traced for ax in (j, j + n_sub)]
    perm = list(range(n_batch)) + [n_batch + ax for ax in sub_order]

    rho = rho_q.shaped_data.transpose(perm)

    # Each pass contracts the (row, col) pair sitting right after the kept pair.
    for _ in range(n_sub - 1):
        rho = jnp.trace(rho, axis1=n_batch + 2, axis2=n_batch + 3)

    return Qarray.create(rho)
def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose of a Qarray.

    Row and column dims are swapped along with the data.

    Args:
        qarr (Qarray): quantum array

    Returns:
        conjugate transpose of qarr
    """
    conj_t = dag_data(qarr.data)
    swapped_dims = tuple(reversed(qarr.dims))
    return Qarray.create(conj_t, dims=swapped_dims)
def dag_data(arr: Array) -> Array:
    """Conjugate transpose of raw array data.

    The last two axes are swapped, so leading batch axes are preserved.
    1-D input is only conjugated.

    Args:
        arr: operator

    Returns:
        conjugate of ``arr`` with its last two axes transposed
    """
    conjugated = jnp.conj(arr)
    # TODO: revisit this case...
    if len(arr.shape) == 1:
        return conjugated
    return jnp.swapaxes(conjugated, -1, -2)
def ket2dm(qarr: Qarray) -> Qarray:
    """Turn a ket (or bra) into the density matrix |psi><psi|.

    Operators are returned unchanged; a bra is daggered first.

    Args:
        qarr (Qarray): qarr

    Returns:
        Density matrix
    """
    if qarr.qtype == Qtypes.oper:
        return qarr

    ket = qarr.dag() if qarr.qtype == Qtypes.bra else qarr
    return ket @ ket.dag()
967# Data level operations ----
def is_dm_data(data: Array) -> bool:
    """Whether the trailing two axes of ``data`` form a square matrix.

    Args:
        data: matrix
    Returns:
        True if data is a density matrix (square in its last two axes)
    """
    n_rows = data.shape[-2]
    n_cols = data.shape[-1]
    return n_rows == n_cols
def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of raw array data.

    Args:
        data: matrix
        n: power

    Returns:
        matrix power
    """
    power_fn = jnp.linalg.matrix_power
    return power_fn(data, n)