Coverage for jaxquantum/core/dims.py: 100%
80 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""dims."""
3from typing import List, Tuple
4from copy import deepcopy
5from math import prod
6from jax import Array
8from enum import Enum
# Type alias used to annotate a dims specification: a two-level nesting of
# ints that the helpers in this module read as dims[0] and dims[1].
11DIMS_TYPE = List[List[int]]
def isket_dims(dims: DIMS_TYPE) -> bool:
    """Tell whether ``dims`` has the shape of a ket.

    Args:
        dims: Dims specification; only ``dims[1]`` is inspected.

    Returns:
        True when the entries of ``dims[1]`` multiply to exactly 1
        (an empty ``dims[1]`` also multiplies to 1).
    """
    second_total = prod(dims[1])
    return second_total == 1
def isbra_dims(dims: DIMS_TYPE) -> bool:
    """Tell whether ``dims`` has the shape of a bra.

    Args:
        dims: Dims specification; only ``dims[0]`` is inspected.

    Returns:
        True when the entries of ``dims[0]`` multiply to exactly 1
        (an empty ``dims[0]`` also multiplies to 1).
    """
    first_total = prod(dims[0])
    return first_total == 1
def isop_dims(dims: DIMS_TYPE) -> bool:
    """Tell whether ``dims`` has the shape of an operator.

    Args:
        dims: Dims specification; ``dims[0]`` and ``dims[1]`` are inspected.

    Returns:
        True when the product of ``dims[1]`` equals the product of
        ``dims[0]``.
    """
    second_total = prod(dims[1])
    first_total = prod(dims[0])
    return second_total == first_total
def ket_from_op_dims(dims: DIMS_TYPE) -> DIMS_TYPE:
    """Build ket-shaped dims from the given dims.

    Args:
        dims: Dims specification; ``dims[0]`` is reused as-is and
            ``dims[1]`` only determines how many ones are produced.

    Returns:
        A 2-tuple ``(dims[0], ones)`` where ``ones`` is a tuple holding one
        ``1`` per entry of ``dims[1]``. Note that the result is a tuple, and
        its first element is the very same object as ``dims[0]`` (no copy).
    """
    ones = tuple(1 for _ in dims[1])
    return (dims[0], ones)
def check_dims(
    dims: Tuple[Tuple[int, ...], Tuple[int, ...]],
    bdims: Tuple[int, ...],
    data_shape: Tuple[int, ...],
) -> None:
    """Validate that ``dims`` and ``bdims`` agree with ``data_shape``.

    The last two entries of ``data_shape`` must equal the products of
    ``dims[0]`` and ``dims[1]`` respectively, and everything before them must
    equal ``bdims``. A shape of exactly ``(0,)`` is treated as the special
    "empty" case and requires ``bdims == (0,)`` and ``dims == ((), ())``.

    Args:
        dims: Pair of integer tuples describing the last two axes.
        bdims: Expected leading (all but the last two) entries of the shape.
        data_shape: Shape to validate; it is measured with ``len`` and
            indexed/sliced, so it must be a sequence of ints.

    Returns:
        None. The function only signals failure by raising.

    Raises:
        AssertionError: If any consistency check fails. The checks are
            raised explicitly rather than written as ``assert`` statements so
            they stay active when Python runs with ``-O``.
    """
    mismatch = "Data shape should be consistent with dimensions."

    if len(data_shape) == 1 and data_shape[0] == 0:
        # E.g. empty list of operators
        if bdims != (0,):
            raise AssertionError("Empty data requires bdims == (0,).")
        if dims != ((), ()):
            raise AssertionError("Empty data requires dims == ((), ()).")
        return

    # Without at least two axes the indexing below would fail with an
    # uninformative IndexError; report it as a validation failure instead.
    if len(data_shape) < 2:
        raise AssertionError(mismatch)

    if bdims != data_shape[:-2]:
        raise AssertionError(mismatch)
    if data_shape[-2] != prod(dims[0]):
        raise AssertionError(mismatch)
    if data_shape[-1] != prod(dims[1]):
        raise AssertionError(mismatch)
class Qdims:
    """Immutable-style wrapper around a dims pair and its derived qtype.

    The dims are stored as a 2-tuple of tuples; ``to_`` exposes ``dims[0]``
    and ``from_`` exposes ``dims[1]``. The qtype is computed once at
    construction time via ``Qtypes.from_dims``.
    """

    def __init__(self, dims):
        """Store a tuple-normalised copy of ``dims`` and derive its qtype.

        Args:
            dims: Indexable with at least two iterable entries; only
                ``dims[0]`` and ``dims[1]`` are kept.
        """
        # Copy first so later mutation of the caller's object cannot leak in.
        self._dims = deepcopy(dims)
        self._dims = (tuple(self._dims[0]), tuple(self._dims[1]))
        self._qtype = Qtypes.from_dims(self._dims)

    @property
    def dims(self):
        """The stored ``(dims[0], dims[1])`` pair of tuples."""
        return self._dims

    @property
    def from_(self):
        """The second entry of the stored dims (``dims[1]``)."""
        return self._dims[1]

    @property
    def to_(self):
        """The first entry of the stored dims (``dims[0]``)."""
        return self._dims[0]

    @property
    def qtype(self):
        """The qtype derived from the dims at construction time."""
        return self._qtype

    def __str__(self):
        return str(self.dims)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Defer to Python's fallback for foreign types instead of raising
        # AttributeError on a missing ``dims``/``qtype`` attribute.
        if not isinstance(other, Qdims):
            return NotImplemented
        return (self.dims == other.dims) and (self.qtype == other.qtype)

    def __ne__(self, other):
        if not isinstance(other, Qdims):
            return NotImplemented
        return (self.dims != other.dims) or (self.qtype != other.qtype)

    def __hash__(self):
        return hash(self.dims)

    def __matmul__(self, other):
        """Compose dims: ``self.from_`` must match ``other.to_``.

        Returns:
            A new ``Qdims`` built from ``[self.to_, other.from_]``.

        Raises:
            TypeError: If ``self.from_ != other.to_``.
        """
        if self.from_ != other.to_:
            raise TypeError(f"incompatible dimensions {self} and {other}")

        new_dims = [self.to_, other.from_]
        return Qdims(new_dims)
class Qtypes(str, Enum):
    """String-valued enumeration of the three supported qtypes.

    Equality and hashing are defined on the member's string value, so a
    member compares equal to (and hashes like) its plain string value.
    """

    ket = "ket"
    bra = "bra"
    oper = "oper"

    @classmethod
    def from_dims(cls, dims: Array):
        """Classify ``dims`` as ket, bra or oper.

        The checks run in that order, so dims satisfying more than one
        predicate (e.g. both products equal to 1) are reported as ``ket``.

        Raises:
            ValueError: If none of the three predicates accepts ``dims``.
        """
        if isket_dims(dims):
            return cls.ket
        if isbra_dims(dims):
            return cls.bra
        if isop_dims(dims):
            return cls.oper
        raise ValueError("Invalid data shape")

    @classmethod
    def from_str(cls, string: str):
        """Look up the member whose value is ``string``."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Objects exposing ``.value`` (e.g. other members) are compared by
        # that value; anything else (plain strings, None, ...) is compared
        # directly instead of failing with AttributeError.
        return self.value == getattr(other, "value", other)

    def __ne__(self, other):
        return self.value != getattr(other, "value", other)

    def __hash__(self):
        return hash(self.value)