|
17 | 17 | import concurrent.futures |
18 | 18 | from collections import defaultdict |
19 | 19 | from concurrent.futures import ThreadPoolExecutor |
20 | | -import time |
21 | 20 |
|
22 | 21 | import ray |
23 | 22 | import ray.experimental.state.api |
@@ -175,24 +174,18 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False): |
175 | 174 | episode_offset % sync_group.frequency == 0: |
176 | 175 | sync_group: ParameterSyncGroup = sync_group |
177 | 176 |
|
178 | | - start = time.perf_counter() |
179 | 177 | src_model, dst_model = sync_group.src_model, sync_group.dst_model |
180 | 178 | refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) |
181 | 179 | future.wait(refs) |
182 | 180 | refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) |
183 | 181 | future.wait(refs) |
184 | | - logger.info(f"============In sync_parameters, onload {sync_group} elapsed {time.perf_counter() - start} s") |
185 | 182 |
|
186 | | - start = time.perf_counter() |
187 | 183 | sync_group.sync(requires_grad, validate) |
188 | | - logger.info(f"============In sync_parameters, synchronizing {sync_group} elapsed {time.perf_counter() - start} s") |
189 | 184 |
|
190 | | - start = time.perf_counter() |
191 | 185 | refs = src_model.offload() |
192 | 186 | future.wait(refs) |
193 | 187 | refs = dst_model.offload() |
194 | 188 | future.wait(refs) |
195 | | - logger.info(f"============In sync_parameters, offload {sync_group} elapsed {time.perf_counter() - start} s") |
196 | 189 |
|
197 | 190 | def set_func_decorator(self, model): |
198 | 191 | if is_decorated(model.name): |
|
0 commit comments