Coverage for jaxquantum/codes/binomial.py: 33%

36 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

"""
Binomial Code Qubit.

Defines ``BinomialQubit``, a bosonic qubit whose logical codewords are
superpositions of evenly spaced Fock states weighted by square roots of
binomial coefficients.
"""

4 

5from typing import Tuple 

6 

7 

8from jaxquantum.utils.utils import comb 

9from jaxquantum.codes.base import BosonicQubit 

10import jaxquantum as jqt 

11 

12from jax import vmap 

13from jax import config 

14import jax.numpy as jnp 

15 

# Enable 64-bit dtypes (float64 / int64) in JAX. NOTE: this is a global JAX
# configuration switch, so merely importing this module changes the default
# precision of every JAX computation in the process, not only this file's.
config.update("jax_enable_x64", True)

17 

18 

class BinomialQubit(BosonicQubit):
    """
    Binomial code qubit.

    Parameters are read from ``self.params`` and use the notation of
    https://arxiv.org/pdf/2010.08699.pdf:

    - ``N``: number of Fock levels in the truncated oscillator space.
    - ``L``, ``G``, ``D``: non-negative integer code parameters (defaults
      1, 0, 0). Presumably these are the correctable loss / gain / dephasing
      orders of the binomial code -- confirm against the reference above.

    With ``S = L + G`` and ``M = max(L, G, 2 * D)``, the codewords populate
    Fock levels that are multiples of ``S + 1``, up to ``(M + 1) * (S + 1)``.
    """

    name = "binomial"

    def _params_validation(self):
        """
        Validate ``self.params`` and fill in defaults for ``L``, ``G``, ``D``.

        Raises:
            ValueError: if any of ``L``, ``G``, ``D`` is negative.
        """
        super()._params_validation()

        # notation https://arxiv.org/pdf/2010.08699.pdf
        defaults = {"L": 1, "G": 0, "D": 0}
        for key, default in defaults.items():
            if key not in self.params:
                self.params[key] = default

        # Negative values would give a non-positive Fock spacing S + 1 (or a
        # meaningless M), producing negative or degenerate basis indices.
        for key in defaults:
            if self.params[key] < 0:
                raise ValueError(
                    f"BinomialQubit parameter {key!r} must be non-negative, "
                    f"got {self.params[key]}."
                )

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z-basis codewords ``(|+z>, |-z>)``.

        With ``S = L + G`` and ``M = max(L, G, 2 * D)``::

            |+z> ∝ sum_{p even, 0 <= p <= M+1} sqrt(C(M+1, p)) |p (S+1)>
            |-z> ∝ sum_{p odd,  0 <= p <= M+1} sqrt(C(M+1, p)) |p (S+1)>

        Each sum is wrapped in a ``jqt.Qarray`` and normalized with
        ``jqt.unit``.

        Returns:
            Tuple ``(plus_z, minus_z)`` of ``jqt.Qarray`` states.

        Raises:
            ValueError: if the truncation ``N`` is too small to contain the
                highest Fock level ``(M + 1) * (S + 1)`` used by the codewords.
        """
        N = self.params["N"]

        L = self.params["L"]
        G = self.params["G"]
        D = self.params["D"]

        S = L + G

        M = jnp.max(jnp.array([L, G, 2 * D]))

        # The largest p in either sum is M + 1, so the highest Fock level
        # touched is (M + 1) * (S + 1). It must be a valid index (< N);
        # otherwise jqt.basis would be asked for a level outside the
        # truncated space and the codeword would be wrong.
        max_level = (M + 1) * (S + 1)
        if max_level >= N:
            raise ValueError(
                f"BinomialQubit needs N > {int(max_level)} to represent its "
                f"codewords (L={L}, G={G}, D={D}), got N={N}."
            )

        def weighted_fock(p):
            # Data array of sqrt(C(M+1, p)) |p (S+1)>.
            return jnp.sqrt(comb(M + 1, p)) * jqt.basis(N, p * (S + 1)).data

        def codeword(first_p):
            # Sum the weighted Fock states for p = first_p, first_p + 2, ...
            # up to M + 1 (inclusive), then normalize into a Qarray.
            terms = vmap(weighted_fock)(jnp.arange(first_p, M + 2, 2))
            return jqt.unit(jqt.Qarray.create(jnp.sum(terms, axis=0)))

        plus_z = codeword(0)  # even p
        minus_z = codeword(1)  # odd p

        return plus_z, minus_z