File tree Expand file tree Collapse file tree 1 file changed +18
-27
lines changed
src/twinkle/model/megatron Expand file tree Collapse file tree 1 file changed +18
-27
lines changed Original file line number Diff line number Diff line change @@ -1645,36 +1645,27 @@ def weight_generator():
16451645 pass
16461646 return
16471647
1648- import queue
1649- buf : queue .Queue = queue .Queue (maxsize = 4 )
1650- error : list = []
1648+ async def _send ():
1649+ await engine .send_weights (weight_generator ())
16511650
1652- def _send ():
1651+ result_container = { 'error' : None }
16531652
1654- def _iter ():
1655- while (item := buf .get ()) is not None :
1656- yield item
1657-
1658- loop = asyncio .new_event_loop ()
1653+ def _run ():
16591654 try :
1660- loop .run_until_complete (engine .send_weights (_iter ()))
1661- except Exception as exc :
1662- error .append (exc )
1663- finally :
1664- loop .close ()
1665-
1666- sender = threading .Thread (target = _send , name = 'ce-broadcast' , daemon = True )
1667- sender .start ()
1668- try :
1669- for name , tensor in weight_generator ():
1670- buf .put ((name , tensor .clone ()))
1671- if error :
1672- break
1673- finally :
1674- buf .put (None ) # sentinel
1675- sender .join ()
1676- if error :
1677- raise error [0 ]
1655+ loop = asyncio .new_event_loop ()
1656+ asyncio .set_event_loop (loop )
1657+ try :
1658+ loop .run_until_complete (_send ())
1659+ finally :
1660+ loop .close ()
1661+ except Exception as e :
1662+ result_container ['error' ] = e
1663+
1664+ thread = threading .Thread (target = _run )
1665+ thread .start ()
1666+ thread .join ()
1667+ if result_container ['error' ] is not None :
1668+ raise result_container ['error' ]
16781669
16791670 @remote_function (collect = 'first' )
16801671 def get_peft_config_dict (self , adapter_name : str = None ) -> dict :
You can’t perform that action at this time.
0 commit comments