Coverage for jaxquantum/devices/analysis/sweeps.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-17 21:51 +0000

1"""Sweeping tools.""" 

2 

3from copy import deepcopy 

4import itertools 

5from tempfile import NamedTemporaryFile 

6import jax.numpy as jnp 

7from tqdm import tqdm 

8import os 

9 

10 

def run_sweep(
    params,
    sweep_params,
    metrics_func,
    fixed_kwargs=None,
    data=None,
    is_parallel=False,
    save_file=None,
    data_save_mode="end",
    return_errors=False,
):
    """Run a sweep over a single parameter, or multiple parameters.

    For every sweep point, a deep copy of ``params`` is updated with that
    point's values and passed to ``metrics_func``. The sweep stops at the
    first exception (raised by ``metrics_func`` or by an intermediate save).
    The exception message is recorded, and the data gathered so far is still
    saved (if saving is enabled) and returned.

    Args:
        params (dict): The base parameters. Each sweep point overrides the
            entries named in ``sweep_params``.
        sweep_params (dict): The parameters to sweep over.
            key: The parameter name.
            value: The sequence of values to sweep over.
        metrics_func (function): Called as
            ``metrics_func(point_params, **fixed_kwargs)``. Its return value
            is stored under ``"results"`` for that run.
        fixed_kwargs (dict, optional): Extra keyword arguments passed to
            ``metrics_func`` at every sweep point. Defaults to None (no
            extra arguments).
        data (dict, optional): Existing sweep data, keyed by integer run
            index, to append to (it is modified in place). New runs are
            numbered after its largest key. Defaults to None (start a new
            dict).
        is_parallel (bool, optional): If True, step through the
            ``sweep_params`` sequences in lockstep (they must all have the
            same length). Otherwise, sweep through their cartesian product.
            Defaults to False.
        save_file (str | os.PathLike | file object, optional): Where to save
            the data. Missing parent directories of a path are created.
            Defaults to None, in which case data is saved to a temporary
            file, which will be deleted upon closing (e.g. during garbage
            collection).
        data_save_mode (str, optional): When to save the data.
            Defaults to "end". Options are:
                "no" - don't save data
                "end" - save data at the end of the sweep
                "during" - save data after every sweep point and at the end
        return_errors (bool, optional): If True, also return the list of
            recorded error messages. Defaults to False.

    Returns:
        dict: The data after the sweep, or ``(data, errors)`` if
        ``return_errors`` is True. Each run entry holds ``"params"``,
        ``"results"`` and ``"sweep_point_info"`` (``labels``, ``values``,
        ``indices``). A run whose ``metrics_func`` call raised holds only
        ``"params"``.

    Raises:
        AssertionError: If ``data_save_mode`` is invalid, or if
            ``is_parallel`` is True and the ``sweep_params`` sequences differ
            in length.
    """
    assert data_save_mode in ["no", "end", "during"], "Invalid data_save_mode."

    fixed_kwargs = fixed_kwargs or {}
    if data is None:
        data = {}
    # New runs are numbered after the largest existing run index. An empty
    # ``data`` dict starts at run 0, just like ``data=None``.
    run = max(data.keys(), default=-1)

    labels = list(sweep_params.keys())
    value_lists = list(sweep_params.values())

    if is_parallel:
        lengths = {len(vals) for vals in value_lists}
        # Every swept sequence must have exactly the same length.
        assert len(lengths) <= 1, (
            "Parallel sweep parameters must have the same length."
        )
        # No swept parameters -> no points (zip-like semantics).
        n_points = lengths.pop() if lengths else 0
        # Lazily take the j-th value of every sequence. The indexing happens
        # inside the guarded loop below, so indexing errors are recorded.
        sweep_points = (
            (tuple(vals[j] for vals in value_lists), (j,) * len(labels))
            for j in range(n_points)
        )
    else:
        # Cartesian product of the values, each paired with the tuple of
        # that value's index within its own sequence.
        sweep_points = list(
            zip(
                itertools.product(*value_lists),
                itertools.product(*[range(len(vals)) for vals in value_lists]),
            )
        )
        n_points = len(sweep_points)

    # Offset at which a seekable file object is (re)written, or None for paths.
    save_start = None
    if data_save_mode in ["during", "end"]:
        if isinstance(save_file, (str, os.PathLike)):
            print("Saving data to: ", save_file)
            dirname = os.path.dirname(os.fspath(save_file))
            # A bare filename has dirname "", which os.makedirs rejects.
            if dirname:
                os.makedirs(dirname, exist_ok=True)
        else:
            save_file = save_file or NamedTemporaryFile()
            print("Saving data to a temporary file: ", save_file.name)
            seekable = getattr(save_file, "seekable", None)
            if seekable is not None and seekable():
                save_start = save_file.tell()

    def _save(**extra):
        """Write the current data snapshot to ``save_file``."""
        if save_start is not None:
            # Without rewinding, each save on a file object would append a
            # new archive after the previous one, so the file keeps growing.
            save_file.seek(save_start)
            save_file.truncate()
        jnp.savez(
            save_file, data=data, sweep_params=sweep_params, params=params, **extra
        )

    errors = []
    try:
        with tqdm(total=n_points) as pbar:
            for values, indices in sweep_points:
                run += 1
                point_params = deepcopy(params)
                data[run] = {"params": point_params}
                for key, value in zip(labels, values):
                    point_params[key] = value
                data[run]["results"] = metrics_func(point_params, **fixed_kwargs)
                data[run]["sweep_point_info"] = {
                    "labels": list(labels),
                    "values": list(values),
                    "indices": list(indices),
                }
                pbar.update(1)
                if data_save_mode == "during":
                    _save()
    except Exception as e:
        # Best-effort sweep: stop at the first failure but keep the results.
        errors.append(str(e))
        print("Error during run: ", errors[-1])

    try:
        if data_save_mode in ["during", "end"]:
            _save(error=None)
    except Exception as e:
        errors.append(str(e))
        print("Error during saving: ", errors[-1])

    if return_errors:
        return data, errors
    return data