Coverage for jaxquantum / codes / cat.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 21:51 +0000

1""" 

Cat code qubit: a bosonic qubit encoded in superpositions of coherent states.

3""" 

4 

5from typing import Tuple 

6 

7from jaxquantum.codes.base import BosonicQubit 

8import jaxquantum as jqt 

9 

10from jax import config 

11 

# Enable 64-bit (double-precision) dtypes in JAX. Note: this is a global,
# process-wide setting applied as a side effect of importing this module.
config.update("jax_enable_x64", True)

13 

14 

class CatQubit(BosonicQubit):
    """
    Cat Qubit Class.

    A bosonic qubit whose logical Z-basis states are normalized
    superpositions of coherent states (see ``_get_basis_z``).

    Parameters read from ``self.params``:
        N: dimension passed to ``jqt.coherent`` (presumably the Fock-space
            truncation; expected to be provided/validated by the base class).
        alpha: coherent-state amplitude; defaults to 2 when not supplied.
    """

    PARAMETERS = ["alpha"]

    name = "cat"

    @property
    def _non_device_params(self):
        """
        Names of the base class's non-device parameters, plus "alpha".

        A new list is returned instead of appending to the list obtained
        from ``super()``. If the base class ever hands back a list it keeps
        a reference to, an in-place ``append`` would mutate it and add
        another "alpha" on every access.
        """
        return [*super()._non_device_params, "alpha"]

    def _params_validation(self):
        """
        Run the base-class validation, then default ``alpha`` to 2 if unset.
        """
        super()._params_validation()
        if "alpha" not in self.params:
            self.params["alpha"] = 2

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z-basis states |+z> and |-z>.

        With N = self.params["N"] and alpha = self.params["alpha"]:
            |+z> = unit(|alpha> + |-alpha>)
            |-z> = unit(|i*alpha> + |-i*alpha>)

        Both are even cat states (odd Fock components cancel in each sum):
        the even cats along the real and the imaginary axis. They are only
        approximately orthogonal. Before truncation, their overlap is
        cos(|alpha|^2) / cosh(|alpha|^2), which vanishes as |alpha| grows.

        Returns:
            (plus_z, minus_z), each normalized with ``jqt.unit``.
        """
        N = self.params["N"]
        a = self.params["alpha"]
        plus_z = jqt.unit(jqt.coherent(N, a) + jqt.coherent(N, -1.0 * a))
        minus_z = jqt.unit(jqt.coherent(N, 1.0j * a) + jqt.coherent(N, -1.0j * a))
        return plus_z, minus_z