Coverage for jaxquantum/core/dims.py: 100%

80 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""dims.""" 

2 

3from typing import List, Tuple 

4from copy import deepcopy 

5from math import prod 

6from jax import Array 

7 

8from enum import Enum 

9 

10 

# Type alias for a dimension specification: a list of lists of ints.
DIMS_TYPE = List[List[int]]

12 

13 

def isket_dims(dims: DIMS_TYPE) -> bool:
    """Return True when the entries of ``dims[1]`` multiply to 1."""
    second_total = prod(dims[1])
    return second_total == 1

16 

17 

def isbra_dims(dims: DIMS_TYPE) -> bool:
    """Return True when the entries of ``dims[0]`` multiply to 1."""
    first_total = prod(dims[0])
    return first_total == 1

20 

21 

def isop_dims(dims: DIMS_TYPE) -> bool:
    """Return True when ``dims[1]`` and ``dims[0]`` have equal products."""
    second_total = prod(dims[1])
    first_total = prod(dims[0])
    return second_total == first_total

24 

25 

def ket_from_op_dims(dims: DIMS_TYPE) -> DIMS_TYPE:
    """Build ket dims from operator dims.

    The result pairs ``dims[0]`` (passed through unchanged) with a tuple
    holding one ``1`` for every entry of ``dims[1]``.
    """
    kept = dims[0]
    ones = tuple(1 for _ in dims[1])
    return (kept, ones)

28 

29 

def check_dims(
    dims: Tuple[Tuple[int, ...], Tuple[int, ...]],
    bdims: Tuple[int, ...],
    data_shape: Tuple[int, ...],
) -> None:
    """Verify that ``dims`` and ``bdims`` are consistent with ``data_shape``.

    The leading axes of ``data_shape`` (everything except the last two) must
    equal ``bdims``; the second-to-last axis must equal ``prod(dims[0])`` and
    the last axis must equal ``prod(dims[1])``.

    A ``data_shape`` of ``(0,)`` is treated as an empty collection and
    requires ``bdims == (0,)`` and ``dims == ((), ())``.

    Args:
        dims: Pair of dimension tuples.
        bdims: Expected leading (batch) axes of ``data_shape``.
        data_shape: Shape of the data being validated.

    Raises:
        AssertionError: If any consistency check fails. The checks are raised
            explicitly rather than via ``assert`` so they are not stripped
            when Python runs with ``-O``.
    """
    if len(data_shape) == 1 and data_shape[0] == 0:
        # E.g. empty list of operators
        if not (bdims == (0,)):
            raise AssertionError
        if not (dims == ((), ())):
            raise AssertionError
        return

    message = "Data shape should be consistent with dimensions."
    if not (bdims == data_shape[:-2]):
        raise AssertionError(message)
    if not (data_shape[-2] == prod(dims[0])):
        raise AssertionError(message)
    if not (data_shape[-1] == prod(dims[1])):
        raise AssertionError(message)

44 

45 

class Qdims:
    """Pair of dimension tuples together with the ``Qtypes`` derived from them.

    ``dims[0]`` is exposed as ``to_`` and ``dims[1]`` as ``from_``.
    """

    def __init__(self, dims):
        """Store a copy of ``dims`` as a pair of tuples and derive its qtype.

        Args:
            dims: Indexable with at least two entries; entries 0 and 1 are
                each converted to a tuple.
        """
        copied = deepcopy(dims)
        self._dims = (tuple(copied[0]), tuple(copied[1]))
        self._qtype = Qtypes.from_dims(self._dims)

    @property
    def dims(self):
        """The stored ``(dims[0], dims[1])`` pair of tuples."""
        return self._dims

    @property
    def from_(self):
        """The second entry of ``dims``."""
        return self._dims[1]

    @property
    def to_(self):
        """The first entry of ``dims``."""
        return self._dims[0]

    @property
    def qtype(self):
        """The qtype computed from ``dims`` at construction time."""
        return self._qtype

    def __str__(self):
        return str(self.dims)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        """Compare by ``dims`` and ``qtype``; defer for non-``Qdims`` operands."""
        if not isinstance(other, Qdims):
            # Let Python fall back to the other operand / identity comparison
            # instead of failing with AttributeError on ``other.dims``.
            return NotImplemented
        return (self.dims == other.dims) and (self.qtype == other.qtype)

    def __ne__(self, other):
        """Negation of ``__eq__``; defer for non-``Qdims`` operands."""
        if not isinstance(other, Qdims):
            return NotImplemented
        return (self.dims != other.dims) or (self.qtype != other.qtype)

    def __hash__(self):
        return hash(self.dims)

    def __matmul__(self, other):
        """Compose two ``Qdims``: result is ``[self.to_, other.from_]``.

        Raises:
            TypeError: If ``self.from_`` differs from ``other.to_``. A
                non-``Qdims`` operand yields ``NotImplemented``, which Python
                also reports as ``TypeError`` when no reflected operator
                handles it.
        """
        if not isinstance(other, Qdims):
            return NotImplemented
        if self.from_ != other.to_:
            raise TypeError(f"incompatible dimensions {self} and {other}")

        new_dims = [self.to_, other.from_]
        return Qdims(new_dims)

89 

90 

class Qtypes(str, Enum):
    """String-valued enumeration with members ``ket``, ``bra`` and ``oper``."""

    ket = "ket"
    bra = "bra"
    oper = "oper"

    @classmethod
    def from_dims(cls, dims: List[List[int]]):
        """Classify ``dims`` as ket, bra or oper.

        The predicates are tried in the order ket, bra, oper, so dims that
        satisfy more than one resolve to the earliest match.

        Raises:
            ValueError: If none of the predicates accepts ``dims``.
        """
        if isket_dims(dims):
            return cls.ket
        if isbra_dims(dims):
            return cls.bra
        if isop_dims(dims):
            return cls.oper
        raise ValueError("Invalid data shape")

    @classmethod
    def from_str(cls, string: str):
        """Look up the member whose value equals ``string``."""
        return cls(string)

    def __str__(self):
        return self.value

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        """Compare by value against another member or a plain string.

        Plain ``str`` operands are compared against ``self.value`` (the class
        is itself a ``str`` subclass and hashes like its value); any other
        type yields ``NotImplemented`` instead of an ``AttributeError``.
        """
        if isinstance(other, Qtypes):
            return self.value == other.value
        if isinstance(other, str):
            return self.value == other
        return NotImplemented

    def __ne__(self, other):
        """Negation of ``__eq__`` with the same operand handling."""
        if isinstance(other, Qtypes):
            return self.value != other.value
        if isinstance(other, str):
            return self.value != other
        return NotImplemented

    def __hash__(self):
        return hash(self.value)