 from .fleet.layers.mpu.mp_ops import _linear
 from .fleet.layers.mpu.mp_ops import _parallel_linear
 from .fleet.layers.mpu.mp_ops import _parallel_embedding
-from .communication.comm_utils import ReduceOp
+from .communication.group import Group, _add_new_group
+from .communication.all_reduce import all_reduce
+from .communication.reduce import _get_reduce_op, ReduceOp

 __all__ = []

-
-class Group():
-    """
-    The abstract representation of group.
-    """
-
-    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
-        self.rank = rank
-        self.nranks = rank_num
-        self.id = id
-        self.ranks = ranks
-        self.pg = pg
-        self.name = name
-
-    def is_member(self):
-        if self.rank < 0:
-            return False
-        if self.nranks < 2:
-            return False
-        return True
-
-    def get_group_rank(self, rank):
-        if self.is_member() and rank in self.ranks:
-            return self.ranks.index(rank)
-        else:
-            return -1
-
-    @property
-    def process_group(self):
-        return self.pg
-
-    @property
-    def world_size(self):
-        return self.nranks if self.rank >= 0 else -1
-
-    def __repr__(self):
-        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
-            self.rank, self.nranks, self.id)
-        debug_str += ", ".join(map(str, self.ranks))
-        debug_str += "; name: "
-        debug_str += self.name if self.name else "None"
-        return debug_str
-
-
 _global_env = None

@@ -147,9 +105,8 @@ def _get_group_map():
     global _group_map
     if _global_env_gid not in _group_map:
         genv = _get_global_env()
-        _group_map[_global_env_gid] = Group(genv.rank,
-                                            genv.world_size,
-                                            ranks=list(range(genv.world_size)))
+        _group_map[_global_env_gid] = Group(genv.rank, 0,
+                                            list(range(genv.world_size)))
     return _group_map

@@ -197,19 +154,6 @@ def _new_ring_id():
     return len(_get_group_map()) + max(_get_global_env().nrings, 9)


-def _get_reduce_op(reduce_op, func_name):
-    if reduce_op == ReduceOp.SUM:
-        return core.ReduceOp.SUM
-    elif reduce_op == ReduceOp.MAX:
-        return core.ReduceOp.MAX
-    elif reduce_op == ReduceOp.MIN:
-        return core.ReduceOp.MIN
-    elif reduce_op == ReduceOp.PROD:
-        return core.ReduceOp.PRODUCT
-    else:
-        raise ValueError("Unknown reduce_op type for {}.".format(func_name))
-
-
 def get_group(id=0):
     """

@@ -451,10 +395,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
     else:
         rank = -1
         pg = None
-    group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+    group = Group(rank, gid, ranks, pg=pg, name=group_name)
     _group_map_by_name[group_name] = group
     _group_map[gid] = group
     _group_map_backend[group] = backend
+    # TODO: The method below is a new method for group management, will replace the previous
+    # three in the future.
+    _add_new_group(group)

     # TODO(shenliang03): This is a temporary solution to solve the problem of
     # hang caused by tcp
@@ -476,13 +423,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
     ring_id = _new_ring_id()

     if global_rank not in ranks:
-        gp = Group(-1, -1, ring_id, ranks)
+        gp = Group(-1, ring_id, ranks)
         _group_map[ring_id] = gp
     else:
         ranks = sorted(ranks)
         group_rank = ranks.index(global_rank)
         group_size = len(ranks)
-        gp = Group(group_rank, group_size, ring_id, ranks)
+        gp = Group(group_rank, ring_id, ranks)
         _group_map[ring_id] = gp

     if group_size >= 2:
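Editor's note: the updated Group(...) call sites in this hunk drop the explicit rank_num/size argument, which suggests the relocated class imported from .communication.group takes (rank, id, ranks) positionally and derives the world size from the ranks list. A minimal sketch of that assumed constructor follows; the real implementation lives in the communication/group.py module and may differ.

    # Sketch only: inferred from the call sites above, not the actual Paddle source.
    class Group:
        """Assumed shape of the relocated Group: positional (rank, id, ranks)."""

        def __init__(self, rank, id=0, ranks=None, pg=None, name=None):
            self.rank = rank
            self.id = id
            self.ranks = ranks or []
            self.nranks = len(self.ranks)  # world size derived from ranks, no rank_num argument
            self.pg = pg
            self.name = name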
@@ -748,104 +695,6 @@ def broadcast(tensor, src, group=None, sync_op=True):
                      })


-def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
-    """
-
-    Reduce a tensor over all ranks so that all get the result.
-    As shown below, one process is started with a GPU and the data of this process is represented
-    by its group rank. The reduce operator is sum. Through all_reduce operator,
-    each GPU will have the sum of the data from all GPUs.
-
-    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
-        :width: 800
-        :alt: all_reduce
-        :align: center
-
-    Args:
-        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
-        group (Group, optional): The group instance return by new_group or None for global default group.
-        sync_op (bool, optional): Wether this op is a sync op. Default value is True.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            # required: distributed
-            import paddle
-            import paddle.distributed as dist
-
-            dist.init_parallel_env()
-            if dist.get_rank() == 0:
-                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-            else:
-                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-            dist.all_reduce(data)
-            print(data)
-            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
-    """
-    if group is not None and not group.is_member():
-        return
-
-    if in_dygraph_mode():
-        op_type = _get_reduce_op(op, "all_reduce")
-        group = _get_default_group() if group is None else group
-        task = group.process_group.allreduce(tensor, op_type)
-        if sync_op:
-            task.wait()
-            return None
-        else:
-            return task
-
-    use_calc_stream = sync_op
-    ring_id = 0 if group is None else group.id
-    if _non_static_mode():
-        if op == ReduceOp.SUM:
-            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MAX:
-            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MIN:
-            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.PROD:
-            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
-                                                   use_calc_stream, 'ring_id',
-                                                   ring_id)
-        else:
-            raise ValueError("Unknown parameter: {}.".format(op))
-
-    check_variable_and_dtype(tensor, 'tensor', [
-        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-        'bool'
-    ], 'all_reduce')
-    if op == ReduceOp.SUM:
-        op_type = 'c_allreduce_sum'
-    elif op == ReduceOp.MAX:
-        op_type = 'c_allreduce_max'
-    elif op == ReduceOp.MIN:
-        op_type = 'c_allreduce_min'
-    elif op == ReduceOp.PROD:
-        op_type = 'c_allreduce_prod'
-    if not isinstance(ring_id, int):
-        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(type=op_type,
-                     inputs={'X': [tensor]},
-                     outputs={'Out': [tensor]},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': use_calc_stream
-                     })
-
-
 def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
     """

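Editor's note: this hunk only relocates the all_reduce implementation; collective.py still imports all_reduce and ReduceOp (now from the communication.* submodules), so the public paddle.distributed API is expected to behave as before. A minimal usage sketch, adapted from the docstring example deleted above; the launcher invocation and script name are illustrative.

    # Run under a distributed launcher, e.g.:
    #   python -m paddle.distributed.launch --gpus 0,1 demo.py   (demo.py is a placeholder name)
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    if dist.get_rank() == 0:
        data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
    else:
        data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    dist.all_reduce(data, op=dist.ReduceOp.SUM)  # now dispatched through communication.all_reduce
    print(data)
    # expected on both ranks: [[5, 7, 9], [5, 7, 9]] (2 GPUs)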