diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py index 2e2fc64c..4b37973c 100644 --- a/src/twinkle/model/megatron/megatron.py +++ b/src/twinkle/model/megatron/megatron.py @@ -1645,36 +1645,27 @@ def weight_generator(): pass return - import queue - buf: queue.Queue = queue.Queue(maxsize=4) - error: list = [] + async def _send(): + await engine.send_weights(weight_generator()) - def _send(): + result_container = {'error': None} - def _iter(): - while (item := buf.get()) is not None: - yield item - - loop = asyncio.new_event_loop() + def _run(): try: - loop.run_until_complete(engine.send_weights(_iter())) - except Exception as exc: - error.append(exc) - finally: - loop.close() - - sender = threading.Thread(target=_send, name='ce-broadcast', daemon=True) - sender.start() - try: - for name, tensor in weight_generator(): - buf.put((name, tensor.clone())) - if error: - break - finally: - buf.put(None) # sentinel - sender.join() - if error: - raise error[0] + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(_send()) + finally: + loop.close() + except Exception as e: + result_container['error'] = e + + thread = threading.Thread(target=_run) + thread.start() + thread.join() + if result_container['error'] is not None: + raise result_container['error'] @remote_function(collect='first') def get_peft_config_dict(self, adapter_name: str = None) -> dict: