Coverage for jaxquantum/core/qarray.py: 76%
461 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 19:55 +0000
1"""QArray."""
3from __future__ import annotations
5from flax import struct
6from jax import Array, config, vmap
7from typing import List, Union
10from math import prod
11from copy import deepcopy
12from numpy import ndarray
13import jax.numpy as jnp
14import jax.scipy as jsp
16from jaxquantum.core.settings import SETTINGS
17from jaxquantum.utils.utils import robust_isscalar
18from jaxquantum.core.dims import Qtypes, Qdims, check_dims, ket_from_op_dims
config.update("jax_enable_x64", True)


def tidy_up(data, atol):
    """Zero out real and imaginary components whose magnitude is at most ``atol``.

    The real and imaginary parts are thresholded independently and recombined
    as ``re + 1j * im``, so the result is always complex.

    Args:
        data: array-like input.
        atol: absolute tolerance; a component ``x`` is kept only if ``abs(x) > atol``.

    Returns:
        Array with the small components set to zero.
    """
    real_part, imag_part = jnp.real(data), jnp.imag(data)
    return (
        real_part * (jnp.abs(real_part) > atol)
        + 1j * imag_part * (jnp.abs(imag_part) > atol)
    )
@struct.dataclass  # this allows us to send in and return Qarray from jitted functions
class Qarray:
    """Quantum array: data plus Hilbert-space dims (qdims) and batch dims (bdims).

    ``_data`` is the pytree leaf; ``_qdims`` and ``_bdims`` are declared with
    ``pytree_node=False`` so they are treated as static metadata.
    """

    _data: Array
    _qdims: Qdims = struct.field(pytree_node=False)
    _bdims: tuple[int] = struct.field(pytree_node=False)

    # Initialization ----
    @classmethod
    def create(cls, data, dims=None, bdims=None):
        """Build a Qarray from raw data, normalizing its shape and dimensions.

        Args:
            data: array-like data. A non-empty 1D array becomes a column vector.
                If the last two axes are neither square nor vector-shaped, the
                last axis is turned into a column vector (the axis before it
                becomes a batch axis).
            dims: either a ``(row_dims, col_dims)`` pair, or a flat sequence of
                subsystem dimensions from which the pair is inferred using the
                data shape (ket, bra or operator).
            bdims: batch dimensions; inferred from the leading axes when None.
                If given and the data has exactly one axis more than ``bdims``,
                the last axis is turned into a column vector.

        Returns:
            Qarray: new Qarray whose data has been passed through ``tidy_up``
            with ``SETTINGS["auto_tidyup_atol"]``.
        """
        # Step 1: Prepare data ----
        data = jnp.asarray(data)

        if len(data.shape) == 1 and data.shape[0] > 0:
            data = data.reshape(data.shape[0], 1)

        if len(data.shape) >= 2:
            if data.shape[-2] != data.shape[-1] and not (
                data.shape[-2] == 1 or data.shape[-1] == 1
            ):
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)

        if bdims is not None:
            if len(data.shape) - len(bdims) == 1:
                data = data.reshape(*data.shape[:-1], data.shape[-1], 1)
        # ----

        # Step 2: Prepare dimensions ----
        if bdims is None:
            bdims = tuple(data.shape[:-2])

        if dims is None:
            dims = ((data.shape[-2],), (data.shape[-1],))

        if not isinstance(dims[0], (list, tuple)):
            # This handles the case where only the hilbert space dimensions are sent in.
            if data.shape[-1] == 1:
                dims = (tuple(dims), tuple([1 for _ in dims]))
            elif data.shape[-2] == 1:
                dims = (tuple([1 for _ in dims]), tuple(dims))
            else:
                dims = (tuple(dims), tuple(dims))
        else:
            dims = (tuple(dims[0]), tuple(dims[1]))

        check_dims(dims, bdims, data.shape)

        qdims = Qdims(dims)

        # NOTE: Constantly tidying up on Qarray creation might be a bit overkill.
        # It increases the compilation time, but only very slightly
        # increased the runtime of the jit compiled function.
        # We could instead use this tidy_up where we think we need it.
        data = tidy_up(data, SETTINGS["auto_tidyup_atol"])

        return cls(data, qdims, bdims)

    # ----

    @classmethod
    def from_list(cls, qarr_list: List[Qarray]) -> Qarray:
        """Create a Qarray from a list of Qarrays.

        The list position becomes a new leading batch dimension.

        Raises:
            ValueError: if the Qarrays do not all share the same dims and bdims.
        """
        if len(qarr_list) == 0:
            dims = ((), ())
            bdims = ()
        else:
            dims = qarr_list[0].dims
            bdims = qarr_list[0].bdims

        # Validate before stacking, so mismatched inputs get this explicit error
        # instead of whatever stacking arrays of different shapes would raise.
        if not all(qarr.dims == dims and qarr.bdims == bdims for qarr in qarr_list):
            raise ValueError("All Qarrays in the list must have the same dimensions.")

        data = jnp.array([qarr.data for qarr in qarr_list])
        bdims = (len(qarr_list),) + bdims

        return cls.create(data, dims=dims, bdims=bdims)

    @classmethod
    def from_array(cls, qarr_arr) -> Qarray:
        """Create a Qarray from a nested list (or tuple) of Qarrays.

        The nesting becomes the leading batch dimensions. Their sizes are read
        along the first element of every level, so the nesting is assumed to be
        rectangular.

        Args:
            qarr_arr (list | tuple): nested sequence of Qarrays; a Qarray is
                returned unchanged.

        Returns:
            Qarray: Qarray object
        """
        if isinstance(qarr_arr, Qarray):
            return qarr_arr

        bdims = ()
        lvl = qarr_arr
        while not isinstance(lvl, Qarray):
            bdims = bdims + (len(lvl),)
            if len(lvl) > 0:
                lvl = lvl[0]
            else:
                break

        def flat(lis):
            # Flatten tuples as well as lists, matching the len()/[0] walk above.
            flat_list = []
            for element in lis:
                if isinstance(element, (list, tuple)):
                    flat_list += flat(element)
                else:
                    flat_list.append(element)
            return flat_list

        qarr_list = flat(qarr_arr)
        qarr = cls.from_list(qarr_list)
        qarr = qarr.reshape_bdims(*bdims)
        return qarr

    # Properties ----
    @property
    def qtype(self):
        """Quantum type (ket, bra or operator) as reported by the qdims."""
        return self._qdims.qtype

    @property
    def dtype(self):
        """dtype of the underlying data."""
        return self._data.dtype

    @property
    def dims(self):
        """``(row_dims, col_dims)`` Hilbert-space dimensions."""
        return self._qdims.dims

    @property
    def bdims(self):
        """Batch dimensions."""
        return self._bdims

    @property
    def qdims(self):
        """The Qdims object holding the Hilbert-space dimensions."""
        return self._qdims

    @property
    def space_dims(self):
        """Subsystem dimensions of the Hilbert space (rows for kets/operators, columns for bras)."""
        if self.qtype in [Qtypes.oper, Qtypes.ket]:
            return self.dims[0]
        elif self.qtype == Qtypes.bra:
            return self.dims[1]
        else:
            # TODO: not reached for some reason
            raise ValueError("Unsupported qtype.")

    @property
    def data(self):
        """Raw data array of shape ``(*bdims, rows, cols)``."""
        return self._data

    @property
    def shaped_data(self):
        """Data reshaped to ``(*bdims, *row_dims, *col_dims)``, one axis per subsystem."""
        return self._data.reshape(self.bdims + self.dims[0] + self.dims[1])

    @property
    def shape(self):
        """Shape of the raw data."""
        return self.data.shape

    @property
    def is_batched(self):
        """True if the Qarray has at least one batch dimension."""
        return len(self.bdims) > 0

    def __getitem__(self, index):
        """Index the batch; dims are kept and bdims are re-inferred from the result."""
        if len(self.bdims) > 0:
            return Qarray.create(
                self.data[index],
                dims=self.dims,
            )
        else:
            raise ValueError("Cannot index a non-batched Qarray.")

    def reshape_bdims(self, *args):
        """Reshape the batch dimensions of the Qarray."""
        new_bdims = tuple(args)

        if prod(new_bdims) == 0:
            # Empty batch: there is no data to keep a Hilbert-space shape for.
            new_shape = new_bdims
        else:
            new_shape = new_bdims + (prod(self.dims[0]),) + (-1,)
        return Qarray.create(
            self.data.reshape(new_shape),
            dims=self.dims,
            bdims=new_bdims,
        )

    def space_to_qdims(self, space_dims: List[int]):
        """Convert Hilbert-space subsystem dimensions into a ``(row, col)`` dims pair.

        The pair is shaped after this Qarray's qtype. Operators get the same
        dimensions on both sides, kets get ones on the column side, and bras get
        ones on the row side. An already-nested pair is returned unchanged.

        Raises:
            ValueError: for an unsupported qtype.
        """
        if isinstance(space_dims[0], (list, tuple)):
            return space_dims

        space_dims = tuple(space_dims)
        ones = tuple(1 for _ in space_dims)
        if self.qtype == Qtypes.oper:
            # An operator acts on the same space on both sides; ones on the
            # column side would describe a ket and not match the square data.
            return (space_dims, space_dims)
        elif self.qtype == Qtypes.ket:
            return (space_dims, ones)
        elif self.qtype == Qtypes.bra:
            return (ones, space_dims)
        else:
            raise ValueError("Unsupported qtype for space_to_qdims conversion.")

    def reshape_qdims(self, *args):
        """Reshape the quantum dimensions of the Qarray.

        Note that this does not take in qdims but rather the new Hilbert space dimensions.

        Args:
            *args: new Hilbert dimensions for the Qarray.

        Returns:
            Qarray: reshaped Qarray.
        """
        new_space_dims = tuple(args)
        current_space_dims = self.space_dims
        assert prod(new_space_dims) == prod(current_space_dims), (
            f"Cannot reshape space dims {current_space_dims} into {new_space_dims}."
        )

        new_qdims = self.space_to_qdims(new_space_dims)
        new_bdims = self.bdims

        return Qarray.create(self.data, dims=new_qdims, bdims=new_bdims)

    def resize(self, new_shape):
        """Resize the Qarray to a new shape.

        TODO: review and maybe deprecate this method.
        """
        dims = self.dims
        data = jnp.resize(self.data, new_shape)
        return Qarray.create(
            data,
            dims=dims,
        )

    def __len__(self):
        """Length of the Qarray (size of the first batch dimension)."""
        if len(self.bdims) > 0:
            return self.data.shape[0]
        else:
            raise ValueError("Cannot get length of a non-batched Qarray.")

    def __eq__(self, other):
        """Equal dims, bdims and data; the data comparison is returned as a JAX boolean."""
        if not isinstance(other, Qarray):
            raise ValueError("Cannot calculate equality of a Qarray with a non-Qarray.")

        if self.dims != other.dims:
            return False

        if self.bdims != other.bdims:
            return False

        return jnp.all(self.data == other.data)

    def __ne__(self, other):
        return not self.__eq__(other)

    # ----

    # Elementary Math ----
    def __matmul__(self, other):
        """Matrix product; dims are combined by the Qdims ``@`` operator."""
        if not isinstance(other, Qarray):
            return NotImplemented
        _qdims_new = self._qdims @ other._qdims
        return Qarray.create(
            self.data @ other.data,
            dims=_qdims_new.dims,
        )

    # NOTE: no __rmatmul__: when the left operand is a Qarray its __matmul__
    # handles the product, and a non-Qarray left operand is not supported.

    def __mul__(self, other):
        """``*``: matrix product with a Qarray, scaling by a scalar/array otherwise.

        A non-scalar array ``other`` gets two trailing singleton axes, so it
        scales each batch element by its own factor.
        """
        if isinstance(other, Qarray):
            return self.__matmul__(other)

        other = other + 0.0j
        if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
            other = other.reshape(other.shape + (1, 1))

        return Qarray.create(
            other * self.data,
            dims=self._qdims.dims,
        )

    def __rmul__(self, other):
        # A Qarray left operand is always handled by its own __mul__.
        return self.__mul__(other)

    def __neg__(self):
        return self.__mul__(-1)

    def __truediv__(self, other):
        """For Qarray's, this only really makes sense in the context of division by a scalar."""
        if isinstance(other, Qarray):
            raise ValueError("Cannot divide a Qarray by another Qarray.")

        return self.__mul__(1 / other)

    def __add__(self, other):
        """Add a Qarray with matching dims, or a scalar times the identity.

        Adding the scalar 0 returns a copy. Scalars (or per-batch arrays) are
        only supported for square data; otherwise NotImplemented is returned.
        """
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data + other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__add__(other)

        return NotImplemented

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract a Qarray with matching dims, or a scalar times the identity."""
        if isinstance(other, Qarray):
            if self.dims != other.dims:
                msg = (
                    "Dimensions are incompatible: "
                    + repr(self.dims)
                    + " and "
                    + repr(other.dims)
                )
                raise ValueError(msg)
            return Qarray.create(self.data - other.data, dims=self.dims)

        if robust_isscalar(other) and other == 0:
            return self.copy()

        if self.data.shape[-2] == self.data.shape[-1]:
            other = other + 0.0j
            if not robust_isscalar(other) and len(other.shape) > 0:  # not a scalar
                other = other.reshape(other.shape + (1, 1))
            other = Qarray.create(
                other * jnp.eye(self.data.shape[-2], dtype=self.data.dtype),
                dims=self.dims,
            )
            return self.__sub__(other)

        return NotImplemented

    def __rsub__(self, other):
        # other - self == (-self) + other
        return self.__neg__().__add__(other)

    def __xor__(self, other):
        """``^``: tensor product."""
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(self, other)

    def __rxor__(self, other):
        if not isinstance(other, Qarray):
            return NotImplemented
        return tensor(other, self)

    def __pow__(self, other):
        """``**``: integer matrix power; non-int exponents return NotImplemented."""
        if not isinstance(other, int):
            return NotImplemented

        return powm(self, other)

    # ----

    # String Representation ----
    def _str_header(self):
        out = ", ".join(
            [
                "Quantum array: dims = " + str(self.dims),
                "bdims = " + str(self.bdims),
                "shape = " + str(self._data.shape),
                "type = " + str(self.qtype),
            ]
        )
        return out

    def __str__(self):
        return self._str_header() + "\nQarray data =\n" + str(self.data)

    @property
    def header(self):
        """Print the header of the Qarray."""
        return self._str_header()

    def __repr__(self):
        return self.__str__()

    # ----

    # Utilities ----
    def copy(self, memo=None):
        """Deep copy of this Qarray (``memo`` is forwarded to ``__deepcopy__``)."""
        return self.__deepcopy__(memo)

    def __deepcopy__(self, memo):
        """Need to override this when defining __getattr__.

        Otherwise ``copy.deepcopy`` would find ``__deepcopy__`` through
        ``__getattr__``, which returns a callable yielding NotImplemented.
        """
        return Qarray(
            _data=deepcopy(self._data, memo=memo),
            _qdims=deepcopy(self._qdims, memo=memo),
            _bdims=deepcopy(self._bdims, memo=memo),
        )

    def __getattr__(self, method_name):
        """Fall back to jnp / jnp.linalg / jsp / jsp.linalg functions applied to the data.

        The result is re-wrapped as a Qarray (same dims) only when its shape
        equals the data shape; otherwise it is returned as-is.
        """
        if "__" == method_name[:2]:
            # NOTE: we return NotImplemented for binary special methods logic in python, plus things like __jax_array__
            return lambda *args, **kwargs: NotImplemented

        modules = [jnp, jnp.linalg, jsp, jsp.linalg]

        method_f = None
        for mod in modules:
            method_f = getattr(mod, method_name, None)
            if method_f is not None:
                break

        if method_f is None:
            raise NotImplementedError(
                f"Method {method_name} does not exist. No backup method found in {modules}."
            )

        def func(*args, **kwargs):
            res = method_f(self.data, *args, **kwargs)

            if getattr(res, "shape", None) is None or res.shape != self.data.shape:
                return res
            else:
                return Qarray.create(res, dims=self._qdims.dims)

        return func

    # ----

    # Conversions / Reshaping ----
    def dag(self):
        return dag(self)

    def to_dm(self):
        return ket2dm(self)

    def is_dm(self):
        return self.qtype == Qtypes.oper

    def is_vec(self):
        return self.qtype == Qtypes.ket or self.qtype == Qtypes.bra

    def to_ket(self):
        return to_ket(self)

    def transpose(self, *args):
        """Permute subsystems: ``q.transpose([1, 0])`` or ``q.transpose(1, 0)``."""
        if len(args) > 1:
            # Individual subsystem indices were given; gather them into one list.
            return transpose(self, list(args))
        return transpose(self, *args)

    def keep_only_diag_elements(self):
        return keep_only_diag_elements(self)

    # ----

    # Math Functions ----
    def unit(self):
        return unit(self)

    def norm(self):
        return norm(self)

    def expm(self):
        return expm(self)

    def powm(self, n):
        return powm(self, n)

    def cosm(self):
        return cosm(self)

    def sinm(self):
        return sinm(self)

    def tr(self, **kwargs):
        return tr(self, **kwargs)

    def trace(self, **kwargs):
        return tr(self, **kwargs)

    def ptrace(self, indx):
        return ptrace(self, indx)

    def eigenstates(self):
        return eigenstates(self)

    def eigenenergies(self):
        return eigenenergies(self)

    def collapse(self, mode="sum"):
        return collapse(self, mode=mode)

    # ----


ARRAY_TYPES = (Array, ndarray, Qarray)
549# Qarray operations ---------------------------------------------------------------------
def concatenate(qarr_list: List[Qarray], axis: int = 0) -> Qarray:
    """Join Qarrays end to end along one axis of their data.

    Qarrays whose data is empty are skipped. The dims of the result come from
    the first non-empty Qarray.

    Args:
        qarr_list (List[Qarray]): Qarrays to join.
        axis (int): data axis to join along. Default is 0.

    Returns:
        Qarray: joined Qarray, or ``Qarray.from_list([])`` if every input is empty.
    """
    kept = [q for q in qarr_list if len(q.data) != 0]
    if not kept:
        return Qarray.from_list([])

    joined = jnp.concatenate([q.data for q in kept], axis=axis)
    return Qarray.create(joined, dims=kept[0].dims)
def collapse(qarr: Qarray, mode="sum") -> Qarray:
    """Collapse the batch dimensions of a Qarray.

    Args:
        qarr (Qarray): quantum array
        mode (str): how to combine the batch elements. Only ``"sum"`` is
            supported; it sums over every batch axis.

    Returns:
        Collapsed quantum array (``qarr`` itself if it is not batched).

    Raises:
        ValueError: if ``mode`` is not supported.
    """
    if mode != "sum":
        # Fail loudly instead of falling through and returning None.
        raise ValueError(f"Unsupported collapse mode: {mode!r}. Supported modes: 'sum'.")

    if len(qarr.bdims) == 0:
        return qarr

    batch_axes = tuple(range(len(qarr.bdims)))
    return Qarray.create(jnp.sum(qarr.data, axis=batch_axes), dims=qarr.dims)
def transpose(qarr: Qarray, indices: List[int]) -> Qarray:
    """Permute the subsystems of a quantum array.

    The same permutation is applied to the row and column subsystems; batch
    axes stay in front and are not moved.

    Args:
        qarr (Qarray): quantum array
        indices (List[int]): new order of the subsystems

    Returns:
        transposed Qarray
    """
    order = list(indices)
    row_dims, col_dims = qarr.dims
    n_batch = len(qarr.bdims)
    n_sub = len(row_dims)

    # shaped_data axes are laid out as (*bdims, *row_dims, *col_dims).
    sub_axes = order + [k + n_sub for k in order]
    axes = list(range(n_batch)) + [k + n_batch for k in sub_axes]
    permuted = qarr.shaped_data.transpose(axes)

    new_dims = (
        tuple(row_dims[k] for k in order),
        tuple(col_dims[k] for k in order),
    )
    flat = permuted.reshape(*qarr.bdims, prod(row_dims), -1)
    return Qarray.create(flat, dims=new_dims)
def unit(qarr: Qarray) -> Qarray:
    """Normalize a quantum array by dividing it by its norm.

    Args:
        qarr (Qarray): quantum array

    Returns:
        ``qarr / qarr.norm()``
    """
    magnitude = qarr.norm()
    return qarr / magnitude
def norm(qarr: Qarray) -> float:
    """Norm of a quantum array.

    Operators use the trace norm, computed as the sum of the square roots of
    the eigenvalues of ``A @ A.dag()``. Kets and bras use the Euclidean vector
    norm. Batch dimensions are kept in the output.

    Args:
        qarr (Qarray): quantum array

    Returns:
        Norm as a scalar array, or an array over the batch dimensions if batched.

    Raises:
        ValueError: if ``qarr.qtype`` is not an operator, ket or bra.
    """
    qdata = qarr.data

    if qarr.qtype == Qtypes.oper:
        qdata_dag = qarr.dag().data
        # eigvalsh works on stacks of matrices directly, so no vmap/reshape is needed.
        evals = jnp.linalg.eigvalsh(qdata @ qdata_dag)
        # abs() guards against tiny negative eigenvalues from round-off.
        return jnp.sum(jnp.sqrt(jnp.abs(evals)), axis=-1)

    if qarr.qtype in [Qtypes.ket, Qtypes.bra]:
        # The Frobenius norm over the last two axes of a (d, 1) or (1, d) array
        # equals the vector 2-norm, and the leading batch axes are preserved.
        return jnp.linalg.norm(qdata, axis=(-2, -1))

    raise ValueError(f"Unsupported qtype for norm: {qarr.qtype}.")
def tensor(*args, **kwargs) -> Qarray:
    """Tensor product.

    Args:
        *args (Qarray): tensors to take the product of
        parallel (bool): if True, use parallel einsum for tensor product
            true: [A,B] ^ [C,D] = [A^C, B^D]
            false (default): [A,B] ^ [C,D] = [A^C, A^D, B^C, B^D]
        **kwargs: remaining keyword arguments are forwarded to ``jnp.kron``
            in the non-parallel mode.

    Returns:
        Tensor product of given tensors
    """
    parallel = kwargs.pop("parallel", False)

    data = args[0].data
    dims_0, dims_1 = args[0].dims
    for arg in args[1:]:
        if parallel:
            a = data
            b = arg.data
            prod_data = jnp.einsum("...ij,...kl->...ikjl", a, b)
            # The einsum output already has the broadcast batch shape of a and b.
            # Reusing it keeps batches such as (3, 1) x (1, 4) -> (3, 4) correct,
            # which picking one operand's batch shape would get wrong.
            data = prod_data.reshape(
                *prod_data.shape[:-4], a.shape[-2] * b.shape[-2], -1
            )
        else:
            data = jnp.kron(data, arg.data, **kwargs)

        dims_0 = dims_0 + arg.dims[0]
        dims_1 = dims_1 + arg.dims[1]

    return Qarray.create(data, dims=(dims_0, dims_1))
def tr(qarr: Qarray, **kwargs) -> Array:
    """Full trace.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``jnp.trace``. ``axis1``/``axis2`` default to
            the last two axes, so batch dimensions are preserved.

    Returns:
        Full trace.
    """
    # pop (not get): an explicitly passed axis1/axis2 must not also be
    # forwarded through **kwargs, which would raise a duplicate-keyword TypeError.
    axis1 = kwargs.pop("axis1", -2)
    axis2 = kwargs.pop("axis2", -1)
    return jnp.trace(qarr.data, axis1=axis1, axis2=axis2, **kwargs)
def trace(qarr: Qarray, **kwargs) -> Array:
    """Full trace; alias of ``tr``.

    Args:
        qarr (Qarray): quantum array
        **kwargs: same keyword arguments as ``tr``

    Returns:
        Full trace.
    """
    result = tr(qarr, **kwargs)
    return result
def expm_data(data: Array, **kwargs) -> Array:
    """Matrix exponential of raw array data.

    Args:
        data: matrix
        **kwargs: forwarded to ``jax.scipy.linalg.expm``

    Returns:
        matrix exponential
    """
    expm_fn = jsp.linalg.expm
    return expm_fn(data, **kwargs)
def expm(qarr: Qarray, **kwargs) -> Qarray:
    """Matrix exponential of a Qarray.

    Args:
        qarr (Qarray): quantum array
        **kwargs: forwarded to ``expm_data``

    Returns:
        matrix exponential, with the same dims as ``qarr``
    """
    exponentiated = expm_data(qarr.data, **kwargs)
    return Qarray.create(exponentiated, dims=qarr.dims)
def powm(qarr: Qarray, n: Union[int, float], clip_eigvals=False) -> Qarray:
    """Matrix power.

    Integer powers use ``jnp.linalg.matrix_power``. Other powers use an
    eigendecomposition ``V @ diag(w**n) @ inv(V)``, which requires the
    eigenvalues to pass the ``>= 0`` check unless they are clipped.

    Args:
        qarr (Qarray): quantum array
        n (int | float): power
        clip_eigvals (bool): clip eigenvalues to always be able to compute
            non-integer powers

    Returns:
        matrix power

    Raises:
        ValueError: for a non-integer power when an eigenvalue fails the
            ``>= 0`` check and ``clip_eigvals`` is False.
    """
    if isinstance(n, int):
        data_res = jnp.linalg.matrix_power(qarr.data, n)
    else:
        evalues, evectors = jnp.linalg.eig(qarr.data)
        if clip_eigvals:
            evalues = jnp.maximum(evalues, 0)
        else:
            if not (evalues >= 0).all():
                raise ValueError(
                    "Non-integer power of a matrix can only be "
                    "computed if the matrix is positive semi-definite. "
                    "Got a matrix with a negative eigenvalue."
                )
        # evectors * w**n scales column i of V by w_i**n, i.e. V @ diag(w**n).
        data_res = evectors * jnp.pow(evalues, n) @ jnp.linalg.inv(evectors)
    return Qarray.create(data_res, dims=qarr.dims)
def cosm_data(data: Array, **kwargs) -> Array:
    """Matrix cosine of raw array data, via cos(A) = (e^{iA} + e^{-iA}) / 2.

    Extra keyword arguments are accepted but not used.

    Returns:
        matrix cosine
    """
    forward = jsp.linalg.expm(1j * data)
    backward = jsp.linalg.expm(-1j * data)
    return (forward + backward) / 2
def cosm(qarr: Qarray) -> Qarray:
    """Matrix cosine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix cosine, with the same dims as ``qarr``
    """
    cosine = cosm_data(qarr.data)
    return Qarray.create(cosine, dims=qarr.dims)
def sinm_data(data: Array, **kwargs) -> Array:
    """Matrix sine of raw array data, via sin(A) = (e^{iA} - e^{-iA}) / 2i.

    Extra keyword arguments are accepted but not used.

    Args:
        data: matrix

    Returns:
        matrix sine
    """
    forward = jsp.linalg.expm(1j * data)
    backward = jsp.linalg.expm(-1j * data)
    return (forward - backward) / (2j)
def sinm(qarr: Qarray) -> Qarray:
    """Matrix sine of a Qarray.

    Args:
        qarr (Qarray): quantum array

    Returns:
        matrix sine, with the same dims as ``qarr``
    """
    sine = sinm_data(qarr.data)
    return Qarray.create(sine, dims=qarr.dims)
def keep_only_diag_elements(qarr: Qarray) -> Qarray:
    """Keep only the main diagonal of an unbatched Qarray, zeroing the rest.

    Presumably intended for square (operator) data; ``jnp.diag`` of a
    non-square matrix extracts its shorter diagonal.

    Raises:
        ValueError: if ``qarr`` is batched.
    """
    if qarr.bdims:
        raise ValueError("Cannot keep only diagonal elements of a batched Qarray.")

    diagonal = jnp.diag(qarr.data)
    return Qarray.create(jnp.diag(diagonal), dims=qarr.dims)
def to_ket(qarr: Qarray) -> Qarray:
    """Return ``qarr`` as a ket: kets unchanged, bras conjugate-transposed.

    Raises:
        ValueError: for any other qtype.
    """
    kind = qarr.qtype
    if kind == Qtypes.ket:
        return qarr
    if kind == Qtypes.bra:
        return qarr.dag()
    raise ValueError("Can only get ket from a ket or bra.")
def eigenstates(qarr: Qarray) -> tuple[Array, Qarray]:
    """Eigenstates of a (Hermitian) quantum array.

    Args:
        qarr (Qarray): quantum array (operator), possibly batched

    Returns:
        ``(evals, evecs)``. ``evals`` are sorted in ascending order along the
        last axis. ``evecs`` is a Qarray of kets whose last batch axis enumerates
        the eigenvectors, so for an unbatched operator ``evecs[i]`` is the
        eigenvector belonging to ``evals[i]``.
    """
    evals, evecs = jnp.linalg.eigh(qarr.data)
    idxs_sorted = jnp.argsort(evals, axis=-1)

    dims = ket_from_op_dims(qarr.dims)

    evals = jnp.take_along_axis(evals, idxs_sorted, axis=-1)
    # eigh returns eigenvectors as columns: evecs[..., :, i] pairs with evals[..., i].
    evecs = jnp.take_along_axis(evecs, idxs_sorted[..., None, :], axis=-1)
    # Put the eigenvector index before the component index. Qarray.create below
    # turns the last axis into the ket column and treats the axis before it as a
    # batch axis, so without this swap evecs[i] would be row i of V rather than
    # the i-th eigenvector.
    evecs = jnp.swapaxes(evecs, -1, -2)

    evecs = Qarray.create(
        evecs,
        dims=dims,
        bdims=evecs.shape[:-1],
    )

    return evals, evecs
def eigenenergies(qarr: Qarray) -> Array:
    """Eigenvalues of a (Hermitian) quantum array.

    Args:
        qarr (Qarray): quantum array

    Returns:
        eigenvalues as computed by ``jnp.linalg.eigvalsh``
    """
    return jnp.linalg.eigvalsh(qarr.data)
886# More quantum specific -----------------------------------------------------
def ptrace(qarr: Qarray, indx) -> Qarray:
    """Partial trace.

    Args:
        qarr (Qarray): ket, bra or density matrix; it is first converted with
            ``ket2dm`` (operators pass through unchanged).
        indx (int): index of the subsystem to keep; all other subsystems are
            traced out. Negative values count from the last subsystem.

    Returns:
        Reduced density matrix of subsystem ``indx``, with batch dimensions kept.

    TODO: Fix weird tracing errors that arise with reshape
    """
    qarr = ket2dm(qarr)
    rho = qarr.shaped_data
    dims = qarr.dims

    Nq = len(dims[0])
    if indx < 0:
        # Python-style negative index; otherwise the kept axes would be listed
        # again among the traced-out ones and the transpose would fail.
        indx += Nq

    # Axis order after transposing: kept (row, col) pair first, followed by the
    # (row, col) pair of every subsystem to trace out.
    indxs = [indx, indx + Nq]
    for j in range(Nq):
        if j == indx:
            continue
        indxs.append(j)
        indxs.append(j + Nq)

    bdims = qarr.bdims
    len_bdims = len(bdims)
    bdims_indxs = list(range(len_bdims))
    indxs = bdims_indxs + [j + len_bdims for j in indxs]
    rho = rho.transpose(indxs)

    # Each trace removes the first traced-out (row, col) pair, which always sits
    # right after the batch axes and the kept pair.
    for _ in range(Nq - 1):
        rho = jnp.trace(rho, axis1=2 + len_bdims, axis2=3 + len_bdims)

    return Qarray.create(rho)
def dag(qarr: Qarray) -> Qarray:
    """Conjugate transpose.

    The data is conjugate-transposed over its last two axes, and the row and
    column dims are swapped to match.

    Args:
        qarr (Qarray): quantum array

    Returns:
        conjugate transpose of qarr
    """
    flipped_dims = qarr.dims[::-1]
    return Qarray.create(dag_data(qarr.data), dims=flipped_dims)
def dag_data(arr: Array) -> Array:
    """Conjugate transpose of raw array data.

    Args:
        arr: operator (or 1D array)

    Returns:
        complex conjugate of ``arr`` with its last two axes swapped; a 1D
        array is only conjugated.
    """
    conjugated = jnp.conj(arr)
    # TODO: revisit this case...
    if len(arr.shape) == 1:
        return conjugated
    # Swapping only the last two axes leaves leading batch axes untouched.
    return jnp.swapaxes(conjugated, -1, -2)
def ket2dm(qarr: Qarray) -> Qarray:
    """Turn a ket (or bra) into the density matrix ``|psi><psi|``.

    Operators are returned unchanged; a bra is first turned into a ket.

    Args:
        qarr (Qarray): qarr

    Returns:
        Density matrix
    """
    if qarr.qtype == Qtypes.oper:
        return qarr

    ket = qarr.dag() if qarr.qtype == Qtypes.bra else qarr
    return ket @ ket.dag()
981# Data level operations ----
def is_dm_data(data: Array) -> bool:
    """Check whether data has the square trailing shape of a density matrix.

    Only the shape is inspected: the last two axes must have equal length.

    Args:
        data: matrix (or batch of matrices)

    Returns:
        True if the last two axes of data have the same size
    """
    shape = data.shape
    return shape[-1] == shape[-2]
def powm_data(data: Array, n: int) -> Array:
    """Integer matrix power of raw array data.

    Args:
        data: matrix
        n: integer power, passed to ``jnp.linalg.matrix_power``

    Returns:
        matrix power
    """
    matrix_power = jnp.linalg.matrix_power
    return matrix_power(data, n)