@@ -201,17 +201,24 @@ def init_model(self, kvargs):
201201 )
202202
203203 # 用于协同读取 ShmObjsIOBuffer 中的请求信息的通信tensor和通信组对象。
204- self .node_broadcast_tensor = torch .tensor ([0 ], dtype = torch .int32 , device = "cuda" , requires_grad = False )
205- self .node_nccl_group = create_new_group_for_current_node ("nccl" )
204+ # nccl频繁小数据通信会导致内存泄露, 这里需要换gloo
205+ self .node_broadcast_tensor_cpu = torch .empty ((1 ,), dtype = torch .int32 , device = "cpu" , pin_memory = True )
206+ self .node_gloo_group = create_new_group_for_current_node ("gloo" )
206207
207208 # 用于在多节点tp模式下协同读取 ShmObjsIOBuffer 中的请求信息的通信tensor和通信组对象。
208209 if self .is_multinode_tp :
209210 self .multinode_tp_gather_item_tensor = torch .tensor ([0 ], dtype = torch .int32 , device = "cuda" )
210211 self .multinode_tp_all_gather_tensor = torch .tensor (
211212 [0 for _ in range (self .global_world_size )], dtype = torch .int32 , device = "cuda" , requires_grad = False
212213 )
213- self .multinode_tp_nccl_group = dist .new_group (
214- [rank for rank in range (self .global_world_size )], backend = "nccl"
214+ self .multinode_tp_gather_item_tensor_cpu = torch .empty (
215+ (1 ,), dtype = torch .int32 , device = "cpu" , pin_memory = True
216+ )
217+ self .multinode_tp_all_gather_tensor_cpu = torch .empty (
218+ (self .global_world_size ,), dtype = torch .int32 , device = "cpu" , pin_memory = True
219+ )
220+ self .multinode_tp_gloo_group = dist .new_group (
221+ [rank for rank in range (self .global_world_size )], backend = "gloo"
215222 )
216223
217224 if (
@@ -221,7 +228,7 @@ def init_model(self, kvargs):
221228 # 如果存在需要跨进程使用mem manger的特性,则将mem manager写入到 shm中,方便
222229 # 读取
223230 self .model .mem_manager .write_to_shm (req_manager = self .model .req_manager )
224- dist .barrier (group = self .node_nccl_group )
231+ dist .barrier (group = self .node_gloo_group )
225232
226233 self .init_custom ()
227234
@@ -362,28 +369,22 @@ def _try_read_new_reqs(self):
362369
363370 def _try_read_new_reqs_normal (self ):
364371 if self .is_master_in_node :
365- if self .shm_reqs_io_buffer .is_ready ():
366- self .node_broadcast_tensor .fill_ (1 )
367- else :
368- self .node_broadcast_tensor .fill_ (0 )
372+ self .node_broadcast_tensor_cpu [0 ] = 1 if self .shm_reqs_io_buffer .is_ready () else 0
369373
370374 src_rank_id = self .args .node_rank * self .node_world_size
371- dist .broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
372- new_buffer_is_ready = self .node_broadcast_tensor . detach (). item ()
375+ dist .broadcast (self .node_broadcast_tensor_cpu , src = src_rank_id , group = self .node_gloo_group , async_op = False )
376+ new_buffer_is_ready = int ( self .node_broadcast_tensor_cpu [ 0 ]. item () )
373377 if new_buffer_is_ready :
374378 self ._read_reqs_buffer_and_init_reqs ()
375379
376380 # nixl pd mode 从 shm_nixl_trans_io_buffer 读取分块传输的完成进度。
377381 if self .is_nixl_pd_mode :
378382 if self .is_master_in_node :
379- if self .shm_nixl_trans_io_buffer .is_ready ():
380- self .node_broadcast_tensor .fill_ (1 )
381- else :
382- self .node_broadcast_tensor .fill_ (0 )
383+ self .node_broadcast_tensor_cpu [0 ] = 1 if self .shm_nixl_trans_io_buffer .is_ready () else 0
383384
384385 src_rank_id = self .args .node_rank * self .node_world_size
385- dist .broadcast (self .node_broadcast_tensor , src = src_rank_id , group = self .node_nccl_group , async_op = False )
386- new_buffer_is_ready = self .node_broadcast_tensor . detach (). item ()
386+ dist .broadcast (self .node_broadcast_tensor_cpu , src = src_rank_id , group = self .node_gloo_group , async_op = False )
387+ new_buffer_is_ready = int ( self .node_broadcast_tensor_cpu [ 0 ]. item () )
387388 if new_buffer_is_ready :
388389 self ._read_nixl_trans_io_buffer_and_update_req_status ()
389390 return
@@ -392,17 +393,14 @@ def _try_read_new_reqs_multinode_tp(self):
392393 """
393394 多节点tp模式下,需要协调所有rank的行为同步。
394395 """
395- if self .shm_reqs_io_buffer .is_ready ():
396- self .multinode_tp_gather_item_tensor .fill_ (1 )
397- else :
398- self .multinode_tp_gather_item_tensor .fill_ (0 )
396+ self .multinode_tp_gather_item_tensor_cpu [0 ] = 1 if self .shm_reqs_io_buffer .is_ready () else 0
399397 dist .all_gather_into_tensor (
400- self .multinode_tp_all_gather_tensor ,
401- self .multinode_tp_gather_item_tensor ,
402- group = self .multinode_tp_nccl_group ,
398+ self .multinode_tp_all_gather_tensor_cpu ,
399+ self .multinode_tp_gather_item_tensor_cpu ,
400+ group = self .multinode_tp_gloo_group ,
403401 async_op = False ,
404402 )
405- new_buffer_is_readys = self .multinode_tp_all_gather_tensor . detach (). cpu () .numpy ()
403+ new_buffer_is_readys = self .multinode_tp_all_gather_tensor_cpu .numpy ()
406404 new_buffer_is_ready = np .all (new_buffer_is_readys == 1 )
407405
408406 if new_buffer_is_ready :
0 commit comments