Coverage for jaxquantum/codes/qubit.py: 50%

36 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1""" 

2Qubit 

3""" 

4 

5from typing import Tuple 

6import warnings 

7 

8from jaxquantum.codes.base import BosonicQubit 

9import jaxquantum as jqt 

10 

11from jax import config 

12import matplotlib.pyplot as plt 

13import qutip as qt 

14 

15config.update("jax_enable_x64", True) 

16 

17 

18class Qubit(BosonicQubit): 

19 """ 

20 FockQubit 

21 """ 

22 

23 def _params_validation(self): 

24 super()._params_validation() 

25 self.params["N"] = 2 

26 

27 def _get_basis_z(self) -> Tuple[jqt.Qarray, jqt.Qarray]: 

28 """ 

29 Construct basis states |+-x>, |+-y>, |+-z> 

30 """ 

31 N = int(self.params["N"]) 

32 plus_z = jqt.basis(N, 0) 

33 minus_z = jqt.basis(N, 1) 

34 return plus_z, minus_z 

35 

36 @property 

37 def x_U(self) -> jqt.Qarray: 

38 return jqt.sigmax() 

39 

40 @property 

41 def y_U(self) -> jqt.Qarray: 

42 return jqt.sigmay() 

43 

44 @property 

45 def z_U(self) -> jqt.Qarray: 

46 return jqt.sigmaz() 

47 

48 def plot(self, state, ax=None, qp_type="", **kwargs) -> None: 

49 state = self.jqt2qt(state) 

50 with warnings.catch_warnings(): 

51 # TODO: suppressing deprecation warnings, deal with this 

52 warnings.simplefilter("ignore") 

53 b = qt.Bloch() 

54 b.add_states(state) 

55 b.render() 

56 b.show() 

57 plt.tight_layout() 

58 plt.show()