Coverage for jaxquantum/codes/cat.py: 50%

22 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2Cat Code Qubit 

3""" 

4 

5from typing import Tuple 

6 

7from jaxquantum.codes.base import BosonicQubit 

8import jaxquantum as jqt 

9 

10from jax import config 

11 

# Turn on JAX's 64-bit mode ("jax_enable_x64") as an import-time side effect.
# NOTE(review): this is a global JAX setting, so importing this module changes
# precision for other JAX code too — presumably intended so the states below
# are built in double precision; confirm.
config.update("jax_enable_x64", True)


class CatQubit(BosonicQubit):
    """
    Cat Qubit Class.

    A bosonic qubit whose logical Z-basis states are normalized
    superpositions of coherent states (see ``_get_basis_z``). The
    coherent-state amplitude is read from ``self.params["alpha"]``, which
    defaults to 2 when not supplied (see ``_params_validation``).
    """

    name = "cat"

    @property
    def _non_device_params(self):
        """
        Return the parent's non-device parameter names plus ``"alpha"``.

        A new list is built instead of appending to the parent's return
        value. If the base class returns a shared list, it is therefore never
        mutated, and repeated accesses do not pile up duplicate ``"alpha"``
        entries.
        """
        base_params = super()._non_device_params
        return list(base_params) + ["alpha"]

    def _params_validation(self):
        """
        Run the parent's parameter validation, then default ``alpha`` to 2.

        An ``alpha`` already present in ``self.params`` is left untouched.
        """
        super()._params_validation()
        if "alpha" not in self.params:
            self.params["alpha"] = 2

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z-basis states |+z> and |-z>.

        With ``N = self.params["N"]`` (passed as the first argument of
        ``jqt.coherent``, presumably the Fock-space truncation) and
        ``alpha = self.params["alpha"]``:

            |+z> = unit(|alpha> + |-alpha>)
            |-z> = unit(|i*alpha> + |-i*alpha>)

        Each state is normalized individually. In general the two are not
        exactly orthogonal for finite ``alpha``, and this method does not
        orthogonalize them.

        Returns:
            A ``(plus_z, minus_z)`` tuple of the normalized states.
        """
        N = self.params["N"]
        alpha = self.params["alpha"]
        # Multiplying by -1.0 / 1.0j (rather than plain negation) promotes an
        # integer alpha to float/complex before it reaches jqt.coherent.
        plus_z = jqt.unit(jqt.coherent(N, alpha) + jqt.coherent(N, -1.0 * alpha))
        minus_z = jqt.unit(
            jqt.coherent(N, 1.0j * alpha) + jqt.coherent(N, -1.0j * alpha)
        )
        return plus_z, minus_z