
Commit 868b1e1

LiYuRio authored and HermitSun committed
Move group and all reduce from collective to communication (#45848)
1 parent 9641b93 commit 868b1e1

11 files changed

Lines changed: 276 additions & 184 deletions


paddle/fluid/distributed/collective/ProcessGroupGloo.cc

Lines changed: 8 additions & 0 deletions
@@ -293,6 +293,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
     std::vector<phi::DenseTensor>& inputs,
     std::vector<phi::DenseTensor>& outputs,
     const AllreduceOptions& opts) {
+  return AllReduce(inputs, outputs, opts, true);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
+    std::vector<phi::DenseTensor>& inputs,
+    std::vector<phi::DenseTensor>& outputs,
+    const AllreduceOptions& opts,
+    bool sync_op) {
   auto tag = next_tag();
   std::shared_ptr<GlooTask> task;
   auto context = get_context();

paddle/fluid/distributed/collective/ProcessGroupGloo.h

Lines changed: 6 additions & 0 deletions
@@ -120,6 +120,12 @@ class ProcessGroupGloo : public ProcessGroup {
       std::vector<phi::DenseTensor>& outputs,
       const AllreduceOptions& opts = AllreduceOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      std::vector<phi::DenseTensor>& inputs,
+      std::vector<phi::DenseTensor>& outputs,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
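The new sync_op parameter threads an explicit synchronous/asynchronous switch through the Gloo process group. Below is a minimal usage sketch, not part of this commit, assuming the Python-level all_reduce keeps the task-returning behavior its docstring (further down in this diff) describes for sync_op=False:

    # Hypothetical usage sketch: sync_op=False is assumed to return a task that
    # must be waited on before the reduced result is safe to read.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    data = paddle.to_tensor([1, 2, 3])
    task = dist.all_reduce(data, sync_op=False)  # assumed asynchronous path
    task.wait()                                  # block until the all-reduce completes
    print(data)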

python/paddle/distributed/collective.py

Lines changed: 11 additions & 162 deletions
@@ -52,54 +52,12 @@
 from .fleet.layers.mpu.mp_ops import _linear
 from .fleet.layers.mpu.mp_ops import _parallel_linear
 from .fleet.layers.mpu.mp_ops import _parallel_embedding
-from .communication.comm_utils import ReduceOp
+from .communication.group import Group, _add_new_group
+from .communication.all_reduce import all_reduce
+from .communication.reduce import _get_reduce_op, ReduceOp
 
 __all__ = []
 
-
-class Group():
-    """
-    The abstract representation of group.
-    """
-
-    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
-        self.rank = rank
-        self.nranks = rank_num
-        self.id = id
-        self.ranks = ranks
-        self.pg = pg
-        self.name = name
-
-    def is_member(self):
-        if self.rank < 0:
-            return False
-        if self.nranks < 2:
-            return False
-        return True
-
-    def get_group_rank(self, rank):
-        if self.is_member() and rank in self.ranks:
-            return self.ranks.index(rank)
-        else:
-            return -1
-
-    @property
-    def process_group(self):
-        return self.pg
-
-    @property
-    def world_size(self):
-        return self.nranks if self.rank >= 0 else -1
-
-    def __repr__(self):
-        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
-            self.rank, self.nranks, self.id)
-        debug_str += ", ".join(map(str, self.ranks))
-        debug_str += "; name: "
-        debug_str += self.name if self.name else "None"
-        return debug_str
-
-
 _global_env = None
 
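The Group class removed above now lives in paddle.distributed.communication.group, whose implementation is not part of this diff. The sketch below is only a reconstruction: the constructor signature Group(rank, id, ranks, pg=None, name=None) is inferred from the new call sites in this file (for example Group(genv.rank, 0, list(range(genv.world_size))) and Group(rank, gid, ranks, pg=pg, name=group_name)); the rest mirrors the removed class.

    # Hypothetical reconstruction of the relocated Group class; only the constructor
    # signature is inferred from this diff, and nranks-from-ranks is an assumption.
    class Group():

        def __init__(self, rank, id=0, ranks=[], pg=None, name=None):
            self.rank = rank
            self.nranks = len(ranks)  # assumed: derived from ranks, no rank_num argument
            self.id = id
            self.ranks = ranks
            self.pg = pg
            self.name = name

        def is_member(self):
            return self.rank >= 0 and self.nranks >= 2

        @property
        def process_group(self):
            return self.pg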

@@ -147,9 +105,8 @@ def _get_group_map():
     global _group_map
     if _global_env_gid not in _group_map:
         genv = _get_global_env()
-        _group_map[_global_env_gid] = Group(genv.rank,
-                                            genv.world_size,
-                                            ranks=list(range(genv.world_size)))
+        _group_map[_global_env_gid] = Group(genv.rank, 0,
+                                            list(range(genv.world_size)))
     return _group_map
 
@@ -197,19 +154,6 @@ def _new_ring_id():
     return len(_get_group_map()) + max(_get_global_env().nrings, 9)
 
 
-def _get_reduce_op(reduce_op, func_name):
-    if reduce_op == ReduceOp.SUM:
-        return core.ReduceOp.SUM
-    elif reduce_op == ReduceOp.MAX:
-        return core.ReduceOp.MAX
-    elif reduce_op == ReduceOp.MIN:
-        return core.ReduceOp.MIN
-    elif reduce_op == ReduceOp.PROD:
-        return core.ReduceOp.PRODUCT
-    else:
-        raise ValueError("Unknown reduce_op type for {}.".format(func_name))
-
-
 def get_group(id=0):
     """
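_get_reduce_op is now imported from .communication.reduce instead of being defined here. Its relocated body is not shown in this commit, but it presumably keeps the same ReduceOp-to-core mapping as the code removed above, roughly:

    # Hypothetical sketch of the relocated helper; the mapping is copied from the
    # removed code above, and ReduceOp is assumed to be defined in the same module.
    import paddle.fluid.core as core

    def _get_reduce_op(reduce_op, func_name):
        # map the public ReduceOp enum onto the core reduce op used by process groups
        if reduce_op == ReduceOp.SUM:
            return core.ReduceOp.SUM
        elif reduce_op == ReduceOp.MAX:
            return core.ReduceOp.MAX
        elif reduce_op == ReduceOp.MIN:
            return core.ReduceOp.MIN
        elif reduce_op == ReduceOp.PROD:
            return core.ReduceOp.PRODUCT
        raise ValueError("Unknown reduce_op type for {}.".format(func_name))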
@@ -451,10 +395,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
     else:
         rank = -1
         pg = None
-    group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+    group = Group(rank, gid, ranks, pg=pg, name=group_name)
     _group_map_by_name[group_name] = group
     _group_map[gid] = group
     _group_map_backend[group] = backend
+    #TODO: The method below is a new method for group management, will replace the previous
+    # three in the future.
+    _add_new_group(group)
 
     # TODO(shenliang03): This is a temporary solution to solve the problem of
     # hang caused by tcp
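_add_new_group is imported from .communication.group and is not shown in this commit; only its call signature, _add_new_group(group), is visible here. A purely hypothetical sketch of the role the TODO describes, assuming it simply registers the group in a registry owned by the new module:

    # Purely hypothetical: the real implementation lives in communication/group.py and
    # may differ; this only illustrates "one registration call instead of three maps".
    _group_registry = {}

    def _add_new_group(group):
        _group_registry[group.id] = group  # assumed behavior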
@@ -476,13 +423,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
         ring_id = _new_ring_id()
 
         if global_rank not in ranks:
-            gp = Group(-1, -1, ring_id, ranks)
+            gp = Group(-1, ring_id, ranks)
             _group_map[ring_id] = gp
         else:
             ranks = sorted(ranks)
             group_rank = ranks.index(global_rank)
             group_size = len(ranks)
-            gp = Group(group_rank, group_size, ring_id, ranks)
+            gp = Group(group_rank, ring_id, ranks)
             _group_map[ring_id] = gp
 
         if group_size >= 2:
@@ -748,104 +695,6 @@ def broadcast(tensor, src, group=None, sync_op=True):
                      })
 
 
-def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
-    """
-
-    Reduce a tensor over all ranks so that all get the result.
-    As shown below, one process is started with a GPU and the data of this process is represented
-    by its group rank. The reduce operator is sum. Through all_reduce operator,
-    each GPU will have the sum of the data from all GPUs.
-
-    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
-        :width: 800
-        :alt: all_reduce
-        :align: center
-
-    Args:
-        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
-        group (Group, optional): The group instance return by new_group or None for global default group.
-        sync_op (bool, optional): Wether this op is a sync op. Default value is True.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            # required: distributed
-            import paddle
-            import paddle.distributed as dist
-
-            dist.init_parallel_env()
-            if dist.get_rank() == 0:
-                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-            else:
-                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-            dist.all_reduce(data)
-            print(data)
-            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
-    """
-    if group is not None and not group.is_member():
-        return
-
-    if in_dygraph_mode():
-        op_type = _get_reduce_op(op, "all_reduce")
-        group = _get_default_group() if group is None else group
-        task = group.process_group.allreduce(tensor, op_type)
-        if sync_op:
-            task.wait()
-            return None
-        else:
-            return task
-
-    use_calc_stream = sync_op
-    ring_id = 0 if group is None else group.id
-    if _non_static_mode():
-        if op == ReduceOp.SUM:
-            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MAX:
-            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MIN:
-            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.PROD:
-            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
-                                                   use_calc_stream, 'ring_id',
-                                                   ring_id)
-        else:
-            raise ValueError("Unknown parameter: {}.".format(op))
-
-    check_variable_and_dtype(tensor, 'tensor', [
-        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-        'bool'
-    ], 'all_reduce')
-    if op == ReduceOp.SUM:
-        op_type = 'c_allreduce_sum'
-    elif op == ReduceOp.MAX:
-        op_type = 'c_allreduce_max'
-    elif op == ReduceOp.MIN:
-        op_type = 'c_allreduce_min'
-    elif op == ReduceOp.PROD:
-        op_type = 'c_allreduce_prod'
-    if not isinstance(ring_id, int):
-        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(type=op_type,
-                     inputs={'X': [tensor]},
-                     outputs={'Out': [tensor]},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': use_calc_stream
-                     })
-
-
 def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
     """
python/paddle/distributed/communication/all_reduce.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid.framework as framework
+from paddle.distributed.communication import stream as stream
+from paddle.distributed.communication.reduce import ReduceOp
+
+
+def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
+    """
+
+    Reduce a tensor over all ranks so that all get the result.
+    As shown below, one process is started with a GPU and the data of this process is represented
+    by its group rank. The reduce operator is sum. Through all_reduce operator,
+    each GPU will have the sum of the data from all GPUs.
+
+    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
+        :width: 800
+        :alt: all_reduce
+        :align: center
+
+    Args:
+        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
+            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
+        group (Group, optional): The group instance return by new_group or None for global default group.
+        sync_op (bool, optional): Wether this op is a sync op. Default value is True.
+
+    Returns:
+        Return a task object.
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            import paddle.distributed as dist
+
+            dist.init_parallel_env()
+            if dist.get_rank() == 0:
+                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
+            else:
+                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
+            dist.all_reduce(data)
+            print(data)
+            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
+    """
+    if not framework._in_legacy_dygraph():
+        return stream.all_reduce(tensor,
+                                 op=op,
+                                 group=group,
+                                 sync_op=sync_op,
+                                 use_calc_stream=False)
+
+    # code below will be removed after we remove the old dygraph
+    use_calc_stream = sync_op
+    ring_id = 0 if group is None else group.id
+    if op == ReduceOp.SUM:
+        return paddle._legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
+                                                     use_calc_stream, 'ring_id',
+                                                     ring_id)
+    elif op == ReduceOp.MAX:
+        return paddle._legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
+                                                     use_calc_stream, 'ring_id',
+                                                     ring_id)
+    elif op == ReduceOp.MIN:
+        return paddle._legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
+                                                     use_calc_stream, 'ring_id',
+                                                     ring_id)
+    elif op == ReduceOp.PROD:
+        return paddle._legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
+                                                      use_calc_stream,
+                                                      'ring_id', ring_id)
+    else:
+        raise ValueError("Unknown parameter: {}.".format(op))
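
Because collective.py now re-exports the relocated function (from .communication.all_reduce import all_reduce), existing call sites are unaffected, and the new module can also be imported directly. A short usage sketch, assumed rather than taken from this commit:

    # Assumed-equivalent entry points after this commit: the public API path and the
    # relocated module resolve to the same function.
    import paddle
    import paddle.distributed as dist
    from paddle.distributed.communication.all_reduce import all_reduce

    dist.init_parallel_env()
    data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
    dist.all_reduce(data, op=dist.ReduceOp.MAX)  # via the re-export in collective.py
    all_reduce(data, op=dist.ReduceOp.MAX)       # via the relocated module directly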
