Coverage for jaxquantum / codes / binomial.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-11 21:51 +0000

"""
Binomial Code Qubit
"""

4 

5from typing import Tuple 

6 

7 

8from jaxquantum.utils.utils import comb 

9from jaxquantum.codes.base import BosonicQubit 

10import jaxquantum as jqt 

11 

12from jax import vmap 

13from jax import config 

14import jax.numpy as jnp 

15 

# Turn on 64-bit (double-precision) types in JAX. This runs at import time at
# module level, so it changes the shared `jax.config` for the whole process,
# not just for this module.
config.update("jax_enable_x64", True)

17 

18 

class BinomialQubit(BosonicQubit):
    """
    Binomial code qubit.

    The logical Z codewords are built as (before normalization with
    ``jqt.unit``)::

        |+z> ~ sum_{p even, 0 <= p <= M + 1} sqrt(C(M + 1, p)) |p (S + 1)>
        |-z> ~ sum_{p odd,  0 <= p <= M + 1} sqrt(C(M + 1, p)) |p (S + 1)>

    where ``S = L + G`` and ``M = max(L, G, 2 * D)``.

    Parameters read from ``self.params`` (notation of
    https://arxiv.org/pdf/2010.08699.pdf; presumably the orders of photon
    loss / gain / dephasing the code targets -- confirm against the paper):
        L: non-negative, defaults to 1.
        G: non-negative, defaults to 0.
        D: non-negative, defaults to 0.
        N: Fock-space dimension; must be greater than ``(M + 1) * (S + 1)``
           so that every codeword component fits in the truncated space.
    """

    PARAMETERS = ["L", "G", "D"]

    name = "binomial"

    def _params_validation(self):
        """
        Fill in default values for L, G and D and reject negative values.

        Raises:
            ValueError: if any of L, G or D is negative.
        """
        super()._params_validation()

        # notation https://arxiv.org/pdf/2010.08699.pdf
        for key, default in (("L", 1), ("G", 0), ("D", 0)):
            if key not in self.params:
                self.params[key] = default

        # Negative values would make the Fock spacing S + 1 <= 0 and produce
        # negative / colliding Fock indices in _get_basis_z.
        for key in ("L", "G", "D"):
            if self.params[key] < 0:
                raise ValueError(
                    f"Binomial code parameter {key} must be non-negative, "
                    f"got {self.params[key]}."
                )

    def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]:
        """
        Construct the logical Z basis states |+z> and |-z>.

        Returns:
            (plus_z, minus_z): normalized codewords built from the even-p and
            odd-p terms of the binomial expansion, respectively.

        Raises:
            ValueError: if the Fock-space dimension N cannot hold the highest
                occupied Fock state (M + 1) * (S + 1).
        """
        N = self.params["N"]

        L = self.params["L"]
        G = self.params["G"]
        D = self.params["D"]

        # Occupied Fock levels are spaced by S + 1.
        S = L + G

        # Binomial coefficients C(M + 1, p) for p = 0 .. M + 1.
        M = jnp.max(jnp.array([L, G, 2 * D]))

        # The largest index passed to jqt.basis is (M + 1) * (S + 1); an
        # N-dimensional space only has indices 0 .. N - 1, so fail early with
        # a clear message instead of building a wrong codeword.
        max_fock = int((M + 1) * (S + 1))
        if max_fock >= N:
            raise ValueError(
                f"Truncation N={N} is too small for the binomial code with "
                f"L={L}, G={G}, D={D}: need N > {max_fock}."
            )

        def codeword_term(p):
            # sqrt(C(M + 1, p)) |p (S + 1)>
            weight = jnp.sqrt(comb(M + 1, p))
            return weight * jqt.basis(N, p * (S + 1)).data

        def codeword(start):
            # Sum the terms for p = start, start + 2, ..., up to M + 1.
            terms = vmap(codeword_term)(jnp.arange(start, M + 2, 2))
            return jqt.unit(jqt.Qarray.create(jnp.sum(terms, axis=0)))

        plus_z = codeword(0)  # even p
        minus_z = codeword(1)  # odd p

        return plus_z, minus_z