Coverage for jaxquantum/devices/analysis/sweeps.py: 0%
85 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-17 21:51 +0000
1"""Sweeping tools."""
3from copy import deepcopy
4import itertools
5from tempfile import NamedTemporaryFile
6import jax.numpy as jnp
7from tqdm import tqdm
8import os
def run_sweep(
    params,
    sweep_params,
    metrics_func,
    fixed_kwargs=None,
    data=None,
    is_parallel=False,
    save_file=None,
    data_save_mode="end",
    return_errors=False,
):
    """Run a sweep over a single parameter, or multiple parameters.

    Every sweep point is stored in ``data`` under a new integer run index
    (continuing after ``max(data)`` when existing ``data`` is supplied) as a
    dict with the keys ``"params"``, ``"results"`` and ``"sweep_point_info"``.

    Args:
        params (dict): The base parameters. A deep copy is made for every
            sweep point and the swept entries are overwritten in that copy.
        sweep_params (dict): The parameters to sweep over.
            key: The parameter name.
            value: The list of values to sweep over.
        metrics_func (function): Called as
            ``metrics_func(point_params, **fixed_kwargs)``; its return value
            is stored under ``"results"``.
        fixed_kwargs (dict, optional): Extra keyword arguments passed to
            metrics_func. Defaults to None (no extra arguments).
        data (dict, optional): Existing sweep data to append to. Defaults to
            None.
        is_parallel (bool, optional): Whether to sweep through the
            sweep_params lists in parallel (lockstep) or through their
            cartesian product. Defaults to False.
        save_file (str | os.PathLike | file-like, optional): Where to save the
            data with ``jnp.savez``. For a path, missing parent directories are
            created. Any other truthy object is passed to ``jnp.savez`` as an
            open file. Defaults to None, in which case data is saved to a
            temporary file, which will be deleted upon closing (e.g. during
            garbage collection).
        data_save_mode (str, optional): The mode to save the data. Defaults
            to "end". Options are:
            "no" - don't save data
            "end" - save data at the end of the sweep
            "during" - save data after every sweep point and at the end
        return_errors (bool, optional): If True, also return the list of
            error messages collected during the sweep. Defaults to False.

    Returns:
        dict: The data after the sweep, or ``(data, errors)`` when
        return_errors is True.

    Raises:
        AssertionError: If data_save_mode is not one of the options above.
        ValueError: If is_parallel is True and the sweep lists differ in
            length.

    Exceptions raised by metrics_func (or by a "during" save) are not
    propagated. The sweep stops at the first one, and its message is printed
    and recorded in ``errors``. The failing run's entry keeps only
    ``"params"``. Data gathered so far is still saved at the end.
    """
    if data is None:
        data = {}
    # Continue numbering after existing runs; ``default`` covers an empty dict.
    run = max(data, default=-1)

    assert data_save_mode in ["no", "end", "during"], "Invalid data_save_mode."

    save_file = _prepare_save_file(save_file, data_save_mode)
    fixed_kwargs = fixed_kwargs or {}

    labels = list(sweep_params.keys())
    if is_parallel:
        points, indices = _parallel_sweep_points(sweep_params)
        # The parallel sweep's final archive has always carried an ``error`` field.
        final_save_extra = {"error": None}
    else:
        points, indices = _product_sweep_points(sweep_params)
        final_save_extra = {}

    errors = []
    try:
        with tqdm(total=len(points)) as pbar:
            for point_values, point_indices in zip(points, indices):
                run += 1
                data[run] = {}
                data[run]["params"] = deepcopy(params)
                for key, value in zip(labels, point_values):
                    data[run]["params"][key] = value
                data[run]["results"] = metrics_func(
                    data[run]["params"], **fixed_kwargs
                )
                data[run]["sweep_point_info"] = {
                    "labels": list(labels),
                    "values": list(point_values),
                    "indices": list(point_indices),
                }
                pbar.update(1)
                if data_save_mode == "during":
                    _save_sweep(save_file, data, sweep_params, params)
    except Exception as e:
        # Best effort: stop at the first failure but keep the partial results.
        errors.append(str(e))
        print("Error during run: ", errors[-1])

    try:
        if data_save_mode in ["during", "end"]:
            _save_sweep(save_file, data, sweep_params, params, **final_save_extra)
    except Exception as e:
        errors.append(str(e))
        print("Error during saving: ", errors[-1])

    if return_errors:
        return data, errors
    return data


def _prepare_save_file(save_file, data_save_mode):
    """Resolve the save target, creating parent directories for path targets."""
    if data_save_mode not in ["during", "end"]:
        return save_file
    if isinstance(save_file, (str, os.PathLike)):
        print("Saving data to: ", save_file)
        dirname = os.path.dirname(save_file)
        # dirname is "" for a bare filename, and os.makedirs("") would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        return save_file
    save_file = save_file or NamedTemporaryFile()
    print("Saving data to a temporary file: ", save_file.name)
    return save_file


def _save_sweep(save_file, data, sweep_params, params, **extra):
    """Write a snapshot of the sweep to save_file with ``jnp.savez``.

    Extra keyword arguments are stored as additional entries in the archive.
    """
    if not isinstance(save_file, (str, os.PathLike)):
        # savez opens a zip writer at an open handle's current position, so
        # repeated snapshots to one handle would pile up. Rewind and truncate
        # seekable handles so the file holds only the latest snapshot.
        seekable = getattr(save_file, "seekable", None)
        if seekable is not None and seekable():
            save_file.seek(0)
            save_file.truncate()
    jnp.savez(save_file, data=data, sweep_params=sweep_params, params=params, **extra)


def _parallel_sweep_points(sweep_params):
    """Return (values, indices) lists stepping through all sweep lists in lockstep.

    Point ``j`` takes element ``j`` of every list and records index ``j`` for
    each parameter.

    Raises:
        ValueError: If the sweep lists do not all have the same length.
    """
    value_lists = list(sweep_params.values())
    sweep_length = len(value_lists[0])
    # ``all`` is required here: asserting a non-empty list is always true.
    if not all(len(vals) == sweep_length for vals in value_lists):
        raise ValueError("Parallel sweep parameters must have the same length.")
    points = [tuple(vals[j] for vals in value_lists) for j in range(sweep_length)]
    indices = [(j,) * len(value_lists) for j in range(sweep_length)]
    return points, indices


def _product_sweep_points(sweep_params):
    """Return (values, indices) lists covering the cartesian product of the sweep lists."""
    value_lists = list(sweep_params.values())
    points = list(itertools.product(*value_lists))
    indices = list(itertools.product(*(range(len(vals)) for vals in value_lists)))
    return points, indices