Skip to content

Commit 4d6ebeb

Browse files
authored
fix weight sync (#128)
1 parent 3ea0e88 commit 4d6ebeb

File tree

1 file changed

+18
-27
lines changed

1 file changed

+18
-27
lines changed

src/twinkle/model/megatron/megatron.py

Lines changed: 18 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1645,36 +1645,27 @@ def weight_generator():
16451645
pass
16461646
return
16471647

1648-
import queue
1649-
buf: queue.Queue = queue.Queue(maxsize=4)
1650-
error: list = []
1648+
async def _send():
1649+
await engine.send_weights(weight_generator())
16511650

1652-
def _send():
1651+
result_container = {'error': None}
16531652

1654-
def _iter():
1655-
while (item := buf.get()) is not None:
1656-
yield item
1657-
1658-
loop = asyncio.new_event_loop()
1653+
def _run():
16591654
try:
1660-
loop.run_until_complete(engine.send_weights(_iter()))
1661-
except Exception as exc:
1662-
error.append(exc)
1663-
finally:
1664-
loop.close()
1665-
1666-
sender = threading.Thread(target=_send, name='ce-broadcast', daemon=True)
1667-
sender.start()
1668-
try:
1669-
for name, tensor in weight_generator():
1670-
buf.put((name, tensor.clone()))
1671-
if error:
1672-
break
1673-
finally:
1674-
buf.put(None) # sentinel
1675-
sender.join()
1676-
if error:
1677-
raise error[0]
1655+
loop = asyncio.new_event_loop()
1656+
asyncio.set_event_loop(loop)
1657+
try:
1658+
loop.run_until_complete(_send())
1659+
finally:
1660+
loop.close()
1661+
except Exception as e:
1662+
result_container['error'] = e
1663+
1664+
thread = threading.Thread(target=_run)
1665+
thread.start()
1666+
thread.join()
1667+
if result_container['error'] is not None:
1668+
raise result_container['error']
16781669

16791670
@remote_function(collect='first')
16801671
def get_peft_config_dict(self, adapter_name: str = None) -> dict:

0 commit comments

Comments
 (0)